Merge branch 'main' into elvischenv/update-flashinfer

This commit is contained in:
elvischenv 2025-12-23 09:33:11 +08:00 committed by GitHub
commit f5228915a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 350 additions and 116 deletions

View File

@ -28,3 +28,4 @@ The backends below live **outside** the main `vllm` repository and follow the
| Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> |
| Baidu Kunlun XPU | N/A, install from source | <https://github.com/baidu/vLLM-Kunlun> |
| Sophgo TPU | N/A, install from source | <https://github.com/sophgo/vllm-tpu> |
| Apple Silicon (Metal) | N/A, install from source | <https://github.com/vllm-project/vllm-metal> |

View File

@ -4,6 +4,9 @@ vLLM has experimental support for macOS with Apple Silicon. For now, users must
Currently the CPU implementation for macOS supports FP32 and FP16 datatypes.
!!! tip "GPU-Accelerated Inference with vLLM-Metal"
For GPU-accelerated inference on Apple Silicon using Metal, check out [vllm-metal](https://github.com/vllm-project/vllm-metal), a community-maintained hardware plugin that uses MLX as the compute backend.
# --8<-- [end:installation]
# --8<-- [start:requirements]

View File

@ -16,7 +16,7 @@ To run inference on a single or multiple GPUs, use `VLLM` class from `langchain`
from langchain_community.llms import VLLM
llm = VLLM(
model="mosaicml/mpt-7b",
model="Qwen/Qwen3-4B",
trust_remote_code=True, # mandatory for hf models
max_new_tokens=128,
top_k=10,

View File

@ -215,7 +215,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True,
),
"CwmForCausalLM": _HfExamplesInfo("facebook/cwm", min_transformers_version="4.58"),
"DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
# FIXME: databricks/dbrx-instruct has been deleted
"DbrxForCausalLM": _HfExamplesInfo(
"databricks/dbrx-instruct", is_available_online=False
),
"DeciLMForCausalLM": _HfExamplesInfo(
"nvidia/Llama-3_3-Nemotron-Super-49B-v1",
trust_remote_code=True,
@ -366,7 +369,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
{"tiny": "TitanML/tiny-mixtral"},
),
"MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
# FIXME: mosaicml/mpt-7b has been deleted
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b", is_available_online=False),
"NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
"NemotronHForCausalLM": _HfExamplesInfo(
"nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True

View File

@ -38,7 +38,8 @@ TOKENIZERS = [
"EleutherAI/gpt-j-6b",
"EleutherAI/pythia-70m",
"bigscience/bloom-560m",
"mosaicml/mpt-7b",
# FIXME: mosaicml/mpt-7b has been deleted
# "mosaicml/mpt-7b",
"tiiuae/falcon-7b",
"meta-llama/Llama-3.2-1B-Instruct",
"codellama/CodeLlama-7b-hf",

View File

@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.kv_cache_interface import MLAAttentionSpec
BACKENDS_TO_TEST = [
AttentionBackendEnum.CUTLASS_MLA,
@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def run_attention_backend(
backend: AttentionBackendEnum,
kv_cache_spec: FullAttentionSpec,
kv_cache_spec: MLAAttentionSpec,
layer_names: list[str],
vllm_config,
device: torch.device,
@ -740,7 +740,7 @@ def test_backend_correctness(
kv_cache = kv_cache_per_block_size[block_size]
# Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec = FullAttentionSpec(
backend_kv_cache_spec = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
@ -748,6 +748,7 @@ def test_backend_correctness(
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(),
cache_dtype_str=vllm_config.cache_config.cache_dtype,
)
backend_output = run_attention_backend(

View File

@ -5,6 +5,7 @@ import pytest
import torch.cuda
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
@ -14,6 +15,11 @@ MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
def test_preprocess_error_handling(monkeypatch: pytest.MonkeyPatch):
"""Test that preprocessing errors are handled gracefully."""
if current_platform.is_rocm():
pytest.skip(
"Skipped on ROCm: this test only works with 'fork', but ROCm uses 'spawn'."
)
assert not torch.cuda.is_initialized(), (
"fork needs to be used for the engine "
"core process and this isn't possible if cuda is already initialized"

View File

@ -306,10 +306,16 @@ def test_prepare_inputs_padded():
proposer = _create_proposer("eagle", num_speculative_tokens)
output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = (
proposer.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
)
)
# Verify num_rejected_tokens_gpu is calculated correctly
expected_num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32, device=device)
assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected)
assert output_metadata.max_query_len == 3
assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)

View File

@ -2132,6 +2132,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
torch.float16,
torch.bfloat16,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
]
E, num_tokens, N, K, top_k_num = self.moe_problem_size(
@ -2156,7 +2157,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
elif hidden_states.dtype == torch.float8_e4m3fn:
elif (
hidden_states.dtype == torch.float8_e4m3fn
or hidden_states.dtype == torch.float8_e4m3fnuz
):
compute_type = tl.bfloat16
else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

View File

@ -13,6 +13,10 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def __init__(self, defer_input_quant: bool = False) -> None:
super().__init__()
self.defer_input_quant = defer_input_quant
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@ -48,6 +52,11 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
# which use a single kernel call for quant + experts.
if self.defer_input_quant:
return a1, None, None, None, None
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_scale,

View File

@ -5,11 +5,15 @@ from functools import lru_cache
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
class QuantMethod(IntEnum):
@ -263,3 +267,78 @@ def rocm_aiter_fused_experts(
a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input,
)
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config):
super().__init__(quant_config)
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
def supports_expert_map(self):
return True
def supports_chunking(self):
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Workspaces are managed internally by AITER.
workspace1 = (0,)
workspace2 = (0,)
output = (M, K)
return (workspace1, workspace2, output)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert a1q_scale is None
assert a2_scale is None
assert expert_tokens_meta is None
result = rocm_aiter_fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.quant_config,
)
assert result.shape == output.shape
output.copy_(result)

View File

@ -117,6 +117,7 @@ class Fp8MoeBackend(Enum):
DEEPGEMM = 3
MARLIN = 4
TRITON = 5
AITER = 6
def get_fp8_moe_backend(
@ -189,6 +190,10 @@ def get_fp8_moe_backend(
logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
return Fp8MoeBackend.DEEPGEMM
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
return Fp8MoeBackend.AITER
# default to Triton
logger.info_once("Using Triton backend for FP8 MoE")
return Fp8MoeBackend.TRITON
@ -888,16 +893,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale = None
layer.w2_input_scale = None
self.rocm_aiter_moe_enabled = False
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# Lazy import to avoid importing triton too early.
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# TODO (rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
@ -932,7 +931,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
if self.rocm_aiter_moe_enabled:
if self.fp8_backend == Fp8MoeBackend.AITER:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
@ -1026,7 +1025,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
start += shard_size
if self.rocm_aiter_moe_enabled:
if self.fp8_backend == Fp8MoeBackend.AITER:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
@ -1072,6 +1071,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = config
self.kernel = mk.FusedMoEModularKernel(
# TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP
# with the changes to defer input quantization
FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=(self.moe.dp_size > 1),
use_deepseek_fp8_block_scale=self.block_quant,
@ -1093,6 +1094,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
Fp8MoeBackend.DEEPGEMM,
Fp8MoeBackend.TRITON,
Fp8MoeBackend.MARLIN,
Fp8MoeBackend.AITER,
]:
from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts,
@ -1103,24 +1105,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
config = self.get_fused_moe_quant_config(layer)
assert config is not None
self.moe_quant_config = config
use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
moe_kernel = (
MarlinExperts(quant_config=self.moe_quant_config)
if use_marlin
else TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=allow_deep_gemm,
)
)
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), moe_kernel
)
if self.fp8_backend == Fp8MoeBackend.AITER:
self.kernel = mk.FusedMoEModularKernel(
# TODO: make defer_input_quant an attr of the AiterExperts
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
AiterExperts(quant_config=self.moe_quant_config),
)
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config=self.moe_quant_config),
)
else:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
),
)
self.use_inplace = True
def maybe_make_prepare_finalize(
@ -1128,7 +1139,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if (
self.rocm_aiter_moe_enabled
self.fp8_backend == Fp8MoeBackend.AITER
or self.fp8_backend == Fp8MoeBackend.MARLIN
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
@ -1161,11 +1172,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
TritonOrDeepGemmExperts,
)
assert (
self.fp8_backend != Fp8MoeBackend.MARLIN
) and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet."
)
if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
raise NotImplementedError(
"Marlin and ROCm AITER are not supported with all2all yet."
)
assert self.moe_quant_config is not None
@ -1313,37 +1323,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
hidden_states=x,
router_logits=router_logits,
)
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_fused_experts,
)
# TODO(rob): convert this to MK.
result = rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
else:
result = self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
result = self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
return result
@ -1456,15 +1447,10 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
layer.w13_input_scale = None
layer.w2_input_scale = None
self.rocm_aiter_moe_enabled = False
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# Lazy import to avoid importing triton too early.
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
@ -1481,7 +1467,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
replace_parameter(layer, "w2_weight", w2_weight)
# Reshuffle weights for AITER if needed.
if self.rocm_aiter_moe_enabled:
if self.fp8_backend == Fp8MoeBackend.AITER:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
@ -1489,7 +1475,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
replace_parameter(layer, "w2_weight", shuffled_w2)
# Rushuffle weights for MARLIN if needed.
if self.fp8_backend == Fp8MoeBackend.MARLIN:
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)

View File

@ -143,7 +143,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
query_start_loc = m.query_start_loc
context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device)
context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True)
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if (

View File

@ -355,6 +355,8 @@ class MLACommonPrefillMetadata:
max_query_len: int
chunked_context: ChunkedContextMetadata | None = None
query_seq_lens: torch.Tensor | None = None
workspace_buffer: torch.Tensor | None = None
q_data_type: torch.dtype | None = None
@dataclass
@ -539,6 +541,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
metadata_cls if metadata_cls is not None else MLACommonMetadata
)
self.kv_cache_spec = kv_cache_spec
self.q_data_type = (
current_platform.fp8_dtype()
if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
else vllm_config.model_config.dtype
)
scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
@ -558,6 +565,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.dcp_rank = 0
self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
self.cp_kv_cache_interleave_size = parallel_config.cp_kv_cache_interleave_size
# Don't try to access the runner on AMD
if self.aot_schedule:
@ -681,7 +689,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# For main run, qo_indptr == kv_indptr
kv_indptr = qo_indptr.clone()
# Prepare main prefill
self._fi_prefill_main.plan(
qo_indptr=qo_indptr,
@ -694,7 +701,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype,
q_data_type=self.q_data_type,
)
# Prepare context prefills
@ -713,7 +720,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype,
q_data_type=self.q_data_type,
)
prefill.prefill_main = self._fi_prefill_main
@ -722,8 +729,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
@ -773,13 +780,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
@ -794,6 +795,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata = None
if num_prefills > 0:
num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
reqs_start = num_decodes # prefill_start
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
@ -970,6 +973,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
chunked_context=chunked_context_metadata,
q_data_type=self.q_data_type,
)
if self._use_cudnn_prefill:
@ -983,19 +987,29 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata.query_seq_lens = (
prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
)
prefill_metadata.workspace_buffer = self._workspace_buffer
decode_metadata = None
if num_decodes > 0:
dcp_tot_seq_lens_device = None
if self.dcp_world_size > 1:
dcp_tot_seq_lens_device = seq_lens[:num_decodes]
seq_lens_cpu = dcp_local_seq_lens_cpu
seq_lens = dcp_local_seq_lens
# After DCP distribution, the maximum number of tokens for any rank is
# ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
# and I is cp_kv_cache_interleave_size.
# This eliminates GPU->CPU sync while minimizing workspace
# over-allocation.
num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
max_seq_len = (
(max_seq_len + num_partitions - 1) // num_partitions
) * self.cp_kv_cache_interleave_size
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
max_seq_len=max_seq_len,
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
query_start_loc_device=query_start_loc[: num_decodes + 1],
num_decode_tokens=num_decode_tokens,
@ -1370,8 +1384,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out
def _run_prefill_new_tokens_fa(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
):
logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
@ -1386,11 +1407,23 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def _run_prefill_new_tokens_fi(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
):
logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None
if fp8_attention:
logger.debug_once("Running Flashinfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
ret = prefill.prefill_main.run(
q=q,
k=k,
@ -1403,10 +1436,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret
def _run_prefill_new_tokens_cudnn(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
):
logger.debug_once("Running Cudnn prefill new tokens", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
output, lse = cudnn_batch_prefill_with_kv_cache(
q=q,
k_cache=k,
@ -1428,9 +1469,19 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return output
def _run_prefill_context_chunk_fa(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
):
logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
assert prefill.chunked_context is not None
assert fp8_attention is False, (
"FlashAttention prefill does not support fp8 attention"
)
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
@ -1445,10 +1496,22 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def _run_prefill_context_chunk_fi(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
):
logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata)
if fp8_attention:
logger.debug_once("Running FlashInfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
q=q,
k=k,
@ -1460,12 +1523,20 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out, lse.transpose(0, 1).contiguous()
def _run_prefill_context_chunk_cudnn(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
):
logger.debug_once("Running Cudnn prefill context chunk", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
return cudnn_batch_prefill_with_kv_cache(
q=q,
k_cache=k,
@ -1485,18 +1556,33 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def _run_prefill_new_tokens_trtllm_ragged(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
):
logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
"""TRT-LLM ragged attention for new tokens (causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.query_seq_lens is not None
assert prefill.workspace_buffer is not None
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
ret = trtllm_ragged_attention_deepseek(
query=q,
key=k,
value=v,
workspace_buffer=self._workspace_buffer,
workspace_buffer=prefill.workspace_buffer,
seq_lens=prefill.query_seq_lens,
max_q_len=prefill.max_query_len,
max_kv_len=prefill.max_query_len,
@ -1518,13 +1604,21 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret
def _run_prefill_context_chunk_trtllm_ragged(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
):
logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
"""TRT-LLM ragged attention for context chunks (non-causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.workspace_buffer is not None
out = torch.zeros(
q.shape[0],
@ -1533,13 +1627,20 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
device=q.device,
dtype=q.dtype,
)
self._workspace_buffer.fill_(0)
prefill.workspace_buffer.fill_(0)
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
attn_out, lse = trtllm_ragged_attention_deepseek(
query=q,
key=k,
value=v,
workspace_buffer=self._workspace_buffer,
workspace_buffer=prefill.workspace_buffer,
seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
max_q_len=prefill.max_query_len,
max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
@ -1687,6 +1788,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
fp8_attention: bool,
):
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
@ -1725,6 +1827,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q,
k=k,
v=v,
fp8_attention=fp8_attention,
)
if output is None:
@ -1753,6 +1856,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
dcp_world_size: int,
fp8_attention: bool,
):
assert k_scale is None, "DCP not support scaled kvcache now."
assert attn_metadata.prefill is not None
@ -1829,6 +1933,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q,
k=k,
v=v,
fp8_attention=fp8_attention,
)
if output is None:
@ -1859,6 +1964,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
output: torch.Tensor,
fp8_attention: bool = False,
) -> None:
# TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None
@ -1878,6 +1984,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k=k,
v=v,
return_softmax_lse=has_context,
fp8_attention=fp8_attention,
)
if has_context:
@ -1890,11 +1997,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata,
k_scale=None,
dcp_world_size=self.dcp_world_size,
fp8_attention=fp8_attention,
)
)
else:
context_output, context_lse = self._compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata, k_scale
q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention
)
# unpad if necessary
@ -2015,6 +2123,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata,
layer._k_scale,
output=output[num_decode_tokens:],
fp8_attention=fp8_attention,
)
if has_decode:

View File

@ -169,8 +169,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
@ -178,7 +178,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
max_seq_len = seq_lens_cpu.max().item()
# For Flash Attention MLA + full cudagraph
max_num_splits = 0
@ -193,7 +192,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
max_num_splits = 1
scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(),
num_reqs=seq_lens_device.shape[0],
cu_query_lens=query_start_loc_device,
max_query_len=max_query_len,
seqlens=seq_lens_device,

View File

@ -143,8 +143,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,

View File

@ -106,8 +106,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,

View File

@ -236,6 +236,7 @@ class EagleProposer:
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
@ -414,6 +415,17 @@ class EagleProposer:
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
self.token_arange_np[: batch_size + 1]
).clone()
# In padded drafter batch, we need to adjust the sequence lengths
# to remove the "padding" (i.e. rejected tokens).
# Only apply this adjustment when we have rejected tokens
# (i.e., not the first proposal).
if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
# Invalidate the CPU-side shadows to avoid H<>D sync.
common_attn_metadata._seq_lens_cpu = None
common_attn_metadata._num_computed_tokens_cpu = None
for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
@ -628,13 +640,14 @@ class EagleProposer:
common_attn_metadata: CommonAttentionMetadata,
spec_decode_metadata: SpecDecodeMetadata,
valid_sampled_tokens_count: torch.Tensor,
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding
It updates the common_attn_metadata for speculative decoding,
but does not consider the rejected tokens. Instead, all tokens
are included as inputs to the speculator, with the rejected tokens
used as padding and filtered out later by `token_indices_to_sample`.
No blocking CPU operations should be introduced in this function.
"""
num_reqs = common_attn_metadata.num_reqs
device = valid_sampled_tokens_count.device
@ -642,14 +655,17 @@ class EagleProposer:
token_indices_to_sample = torch.empty(
(num_reqs,), dtype=torch.int32, device=device
)
num_rejected_tokens_gpu = torch.empty(
(num_reqs,), dtype=torch.int32, device=device
)
# Kernel grid: one program per request (row)
grid = (num_reqs,)
eagle_prepare_inputs_padded_kernel[grid](
spec_decode_metadata.cu_num_draft_tokens,
valid_sampled_tokens_count,
common_attn_metadata.query_start_loc,
token_indices_to_sample,
num_rejected_tokens_gpu,
num_reqs,
)
@ -674,7 +690,11 @@ class EagleProposer:
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
)
return spec_common_attn_metadata, token_indices_to_sample
return (
spec_common_attn_metadata,
token_indices_to_sample,
num_rejected_tokens_gpu,
)
def propose_tree(
self,

View File

@ -23,6 +23,7 @@ def eagle_prepare_inputs_padded_kernel(
valid_sampled_tokens_count_ptr, # [num_reqs]
query_start_loc_gpu_ptr, # [num_reqs + 1]
token_indices_to_sample_ptr, # [num_reqs] (output)
num_rejected_tokens_gpu_ptr, # [num_reqs] (output)
num_reqs, # tl.int32
):
"""
@ -56,6 +57,7 @@ def eagle_prepare_inputs_padded_kernel(
index_to_sample = q_last_tok_idx - num_rejected_tokens
tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
tl.store(num_rejected_tokens_gpu_ptr + req_idx, num_rejected_tokens)
@triton.jit

View File

@ -3534,6 +3534,7 @@ class GPUModelRunner(
next_token_ids, valid_sampled_tokens_count
)
num_rejected_tokens_gpu = None
if spec_decode_metadata is None:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
@ -3564,12 +3565,14 @@ class GPUModelRunner(
else:
target_hidden_states = hidden_states[token_indices]
else:
common_attn_metadata, token_indices_to_sample = (
self.drafter.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count,
)
(
common_attn_metadata,
token_indices_to_sample,
num_rejected_tokens_gpu,
) = self.drafter.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count,
)
total_num_tokens = common_attn_metadata.num_actual_tokens
# When padding the batch, token_indices is just a range
@ -3600,6 +3603,7 @@ class GPUModelRunner(
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
mm_embed_inputs=mm_embed_inputs,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
)
return draft_token_ids