Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
parent 612d5ffdab
commit 3e10262356
@@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import MLAAttentionSpec
from vllm.v1.kv_cache_interface import FullAttentionSpec

BACKENDS_TO_TEST = [
AttentionBackendEnum.CUTLASS_MLA,
@@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):

def run_attention_backend(
backend: AttentionBackendEnum,
kv_cache_spec: MLAAttentionSpec,
kv_cache_spec: FullAttentionSpec,
layer_names: list[str],
vllm_config,
device: torch.device,
@@ -740,7 +740,7 @@ def test_backend_correctness(
kv_cache = kv_cache_per_block_size[block_size]

# Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec = MLAAttentionSpec(
backend_kv_cache_spec = FullAttentionSpec(
block_size=block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
@@ -748,7 +748,6 @@ def test_backend_correctness(
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(),
cache_dtype_str=vllm_config.cache_config.cache_dtype,
)

backend_output = run_attention_backend(
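Note: the spec constructed in the test hunk above takes exactly the fields listed in the call. A minimal standalone stand-in showing just that field set (KVCacheSpecSketch is a hypothetical name for illustration, not the real vllm.v1.kv_cache_interface classes):

    from dataclasses import dataclass
    import torch

    @dataclass
    class KVCacheSpecSketch:
        # Fields mirrored from the constructor call in the hunk above;
        # a standalone stand-in, not FullAttentionSpec/MLAAttentionSpec.
        block_size: int
        num_kv_heads: int
        head_size: int
        dtype: torch.dtype
        sliding_window: int | None = None
        cache_dtype_str: str | None = None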
@@ -325,6 +325,7 @@ def flashinfer_trtllm_fp4_moe(
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
tile_tokens_dim=None,
routing_method_type=routing_method_type,
do_finalize=True,
)[0]
@@ -541,11 +541,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
metadata_cls if metadata_cls is not None else MLACommonMetadata
)
self.kv_cache_spec = kv_cache_spec
self.q_data_type = (
current_platform.fp8_dtype()
if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
else vllm_config.model_config.dtype
)
scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
@@ -689,6 +684,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):

# For main run, qo_indptr == kv_indptr
kv_indptr = qo_indptr.clone()

# Prepare main prefill
self._fi_prefill_main.plan(
qo_indptr=qo_indptr,
@@ -701,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.q_data_type,
q_data_type=self.model_config.dtype,
)

# Prepare context prefills
@@ -720,7 +716,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.q_data_type,
q_data_type=self.model_config.dtype,
)

prefill.prefill_main = self._fi_prefill_main
@@ -973,7 +969,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
chunked_context=chunked_context_metadata,
q_data_type=self.q_data_type,
)

if self._use_cudnn_prefill:
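Note: the q_data_type logic touched above selects an FP8 query dtype only when the KV cache itself is FP8-quantized, and otherwise falls back to the model dtype. A minimal standalone sketch of that selection (select_q_dtype is a hypothetical helper; torch.float8_e4m3fn stands in for current_platform.fp8_dtype()):

    import torch

    def select_q_dtype(cache_dtype_str: str | None, model_dtype: torch.dtype) -> torch.dtype:
        # Use an FP8 query dtype only when the KV cache is FP8-quantized;
        # otherwise fall back to the model's compute dtype.
        if cache_dtype_str is not None and "fp8" in cache_dtype_str:
            return torch.float8_e4m3fn  # stand-in for current_platform.fp8_dtype()
        return model_dtype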
@@ -1384,15 +1379,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out

def _run_prefill_new_tokens_fa(
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
):
logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
@@ -1407,23 +1395,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)

def _run_prefill_new_tokens_fi(
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
):
logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None
if fp8_attention:
logger.debug_once("Running Flashinfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)

ret = prefill.prefill_main.run(
q=q,
k=k,
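Note: the FP8 branch shown above follows a cast-before-kernel pattern: when fp8_attention is set, q/k/v are converted to the platform FP8 dtype before the prefill kernel runs. A minimal standalone sketch under that assumption (maybe_cast_qkv_to_fp8 is a hypothetical helper name; torch.float8_e4m3fn stands in for current_platform.fp8_dtype()):

    import torch

    def maybe_cast_qkv_to_fp8(q, k, v, fp8_attention: bool):
        # When FP8 attention is requested, quantize q/k/v to the FP8 dtype
        # before handing them to the prefill kernel; otherwise pass through.
        if fp8_attention:
            fp8_dtype = torch.float8_e4m3fn  # stand-in for current_platform.fp8_dtype()
            q, k, v = q.to(fp8_dtype), k.to(fp8_dtype), v.to(fp8_dtype)
        return q, k, v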
@@ -1436,18 +1412,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret

def _run_prefill_new_tokens_cudnn(
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
):
logger.debug_once("Running Cudnn prefill new tokens", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
output, lse = cudnn_batch_prefill_with_kv_cache(
q=q,
k_cache=k,
@@ -1469,19 +1437,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return output

def _run_prefill_context_chunk_fa(
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
):
logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
assert prefill.chunked_context is not None
assert fp8_attention is False, (
"FlashAttention prefill does not support fp8 attention"
)
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
@@ -1496,22 +1454,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)

def _run_prefill_context_chunk_fi(
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
):
logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata)
if fp8_attention:
logger.debug_once("Running FlashInfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)

attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
q=q,
k=k,
@@ -1523,20 +1469,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out, lse.transpose(0, 1).contiguous()

def _run_prefill_context_chunk_cudnn(
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
):
logger.debug_once("Running Cudnn prefill context chunk", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
return cudnn_batch_prefill_with_kv_cache(
q=q,
k_cache=k,
@@ -1556,28 +1494,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)

def _run_prefill_new_tokens_trtllm_ragged(
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
):
logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
"""TRT-LLM ragged attention for new tokens (causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek

assert prefill.query_seq_lens is not None
assert prefill.workspace_buffer is not None

if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)

ret = trtllm_ragged_attention_deepseek(
query=q,
key=k,
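Note: the TRT-LLM ragged path imports trtllm_ragged_attention_deepseek inside the method, which keeps flashinfer an optional dependency for other attention backends. A sketch of that lazy-import guard (get_trtllm_ragged_kernel is a hypothetical wrapper, not part of the diff):

    def get_trtllm_ragged_kernel():
        # Import flashinfer only when the TRT-LLM ragged path is actually used,
        # so the dependency stays optional for other attention backends.
        try:
            from flashinfer.prefill import trtllm_ragged_attention_deepseek
        except ImportError as exc:
            raise RuntimeError(
                "flashinfer is required for the TRT-LLM ragged prefill path"
            ) from exc
        return trtllm_ragged_attention_deepseek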
@@ -1604,15 +1528,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret

def _run_prefill_context_chunk_trtllm_ragged(
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
):
logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
"""TRT-LLM ragged attention for context chunks (non-causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek

@@ -1629,13 +1546,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
prefill.workspace_buffer.fill_(0)

if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)

attn_out, lse = trtllm_ragged_attention_deepseek(
query=q,
key=k,
@@ -1788,7 +1698,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
fp8_attention: bool,
):
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
@@ -1827,7 +1736,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q,
k=k,
v=v,
fp8_attention=fp8_attention,
)

if output is None:
@@ -1856,7 +1764,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
dcp_world_size: int,
fp8_attention: bool,
):
assert k_scale is None, "DCP not support scaled kvcache now."
assert attn_metadata.prefill is not None
@@ -1933,7 +1840,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q,
k=k,
v=v,
fp8_attention=fp8_attention,
)

if output is None:
@@ -1964,7 +1870,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
output: torch.Tensor,
fp8_attention: bool = False,
) -> None:
# TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None
@@ -1984,7 +1889,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k=k,
v=v,
return_softmax_lse=has_context,
fp8_attention=fp8_attention,
)

if has_context:
@@ -1997,12 +1901,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata,
k_scale=None,
dcp_world_size=self.dcp_world_size,
fp8_attention=fp8_attention,
)
)
else:
context_output, context_lse = self._compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention
q, kv_c_and_k_pe_cache, attn_metadata, k_scale
)

# unpad if necessary
@@ -2123,7 +2026,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata,
layer._k_scale,
output=output[num_decode_tokens:],
fp8_attention=fp8_attention,
)

if has_decode:
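Note: taken together, the MLACommonImpl hunks show the fp8_attention flag being threaded from the prefill entry points into each backend helper, with the softmax LSE requested only when cached context has to be merged afterwards. A condensed standalone sketch of that call shape (forward_prefill_sketch, run_new_tokens, and compute_context are hypothetical stand-ins, not vLLM APIs):

    def forward_prefill_sketch(q, k, v, has_context, fp8_attention,
                               run_new_tokens, compute_context):
        # Mirrors the call shape in the hunks above: the new-token helper is
        # asked for the softmax LSE only when there is cached context to merge,
        # and the same fp8_attention flag is forwarded along both paths.
        result = run_new_tokens(q, k, v, return_softmax_lse=has_context,
                                fp8_attention=fp8_attention)
        if has_context:
            context_output, context_lse = compute_context(q, fp8_attention=fp8_attention)
            # ...merge `result` with (context_output, context_lse) here...
        return result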