Merge branch 'main' into elvischenv/update-flashinfer

This commit is contained in:
elvischenv 2025-12-23 09:33:11 +08:00 committed by GitHub
commit f5228915a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 350 additions and 116 deletions

View File

@ -28,3 +28,4 @@ The backends below live **outside** the main `vllm` repository and follow the
| Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> | | Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> |
| Baidu Kunlun XPU | N/A, install from source | <https://github.com/baidu/vLLM-Kunlun> | | Baidu Kunlun XPU | N/A, install from source | <https://github.com/baidu/vLLM-Kunlun> |
| Sophgo TPU | N/A, install from source | <https://github.com/sophgo/vllm-tpu> | | Sophgo TPU | N/A, install from source | <https://github.com/sophgo/vllm-tpu> |
| Apple Silicon (Metal) | N/A, install from source | <https://github.com/vllm-project/vllm-metal> |

View File

@ -4,6 +4,9 @@ vLLM has experimental support for macOS with Apple Silicon. For now, users must
Currently the CPU implementation for macOS supports FP32 and FP16 datatypes. Currently the CPU implementation for macOS supports FP32 and FP16 datatypes.
!!! tip "GPU-Accelerated Inference with vLLM-Metal"
For GPU-accelerated inference on Apple Silicon using Metal, check out [vllm-metal](https://github.com/vllm-project/vllm-metal), a community-maintained hardware plugin that uses MLX as the compute backend.
# --8<-- [end:installation] # --8<-- [end:installation]
# --8<-- [start:requirements] # --8<-- [start:requirements]

View File

@ -16,7 +16,7 @@ To run inference on a single or multiple GPUs, use `VLLM` class from `langchain`
from langchain_community.llms import VLLM from langchain_community.llms import VLLM
llm = VLLM( llm = VLLM(
model="mosaicml/mpt-7b", model="Qwen/Qwen3-4B",
trust_remote_code=True, # mandatory for hf models trust_remote_code=True, # mandatory for hf models
max_new_tokens=128, max_new_tokens=128,
top_k=10, top_k=10,

View File

@ -215,7 +215,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True, trust_remote_code=True,
), ),
"CwmForCausalLM": _HfExamplesInfo("facebook/cwm", min_transformers_version="4.58"), "CwmForCausalLM": _HfExamplesInfo("facebook/cwm", min_transformers_version="4.58"),
"DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"), # FIXME: databricks/dbrx-instruct has been deleted
"DbrxForCausalLM": _HfExamplesInfo(
"databricks/dbrx-instruct", is_available_online=False
),
"DeciLMForCausalLM": _HfExamplesInfo( "DeciLMForCausalLM": _HfExamplesInfo(
"nvidia/Llama-3_3-Nemotron-Super-49B-v1", "nvidia/Llama-3_3-Nemotron-Super-49B-v1",
trust_remote_code=True, trust_remote_code=True,
@ -366,7 +369,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
{"tiny": "TitanML/tiny-mixtral"}, {"tiny": "TitanML/tiny-mixtral"},
), ),
"MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False), "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), # FIXME: mosaicml/mpt-7b has been deleted
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b", is_available_online=False),
"NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"), "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
"NemotronHForCausalLM": _HfExamplesInfo( "NemotronHForCausalLM": _HfExamplesInfo(
"nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True

View File

@ -38,7 +38,8 @@ TOKENIZERS = [
"EleutherAI/gpt-j-6b", "EleutherAI/gpt-j-6b",
"EleutherAI/pythia-70m", "EleutherAI/pythia-70m",
"bigscience/bloom-560m", "bigscience/bloom-560m",
"mosaicml/mpt-7b", # FIXME: mosaicml/mpt-7b has been deleted
# "mosaicml/mpt-7b",
"tiiuae/falcon-7b", "tiiuae/falcon-7b",
"meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
"codellama/CodeLlama-7b-hf", "codellama/CodeLlama-7b-hf",

View File

@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import MLAAttentionSpec
BACKENDS_TO_TEST = [ BACKENDS_TO_TEST = [
AttentionBackendEnum.CUTLASS_MLA, AttentionBackendEnum.CUTLASS_MLA,
@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def run_attention_backend( def run_attention_backend(
backend: AttentionBackendEnum, backend: AttentionBackendEnum,
kv_cache_spec: FullAttentionSpec, kv_cache_spec: MLAAttentionSpec,
layer_names: list[str], layer_names: list[str],
vllm_config, vllm_config,
device: torch.device, device: torch.device,
@ -740,7 +740,7 @@ def test_backend_correctness(
kv_cache = kv_cache_per_block_size[block_size] kv_cache = kv_cache_per_block_size[block_size]
# Create kv_cache_spec with the correct block_size for this backend # Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec = FullAttentionSpec( backend_kv_cache_spec = MLAAttentionSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads( num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config vllm_config.parallel_config
@ -748,6 +748,7 @@ def test_backend_correctness(
head_size=vllm_config.model_config.get_head_size(), head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype, dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(), sliding_window=vllm_config.model_config.get_sliding_window(),
cache_dtype_str=vllm_config.cache_config.cache_dtype,
) )
backend_output = run_attention_backend( backend_output = run_attention_backend(

View File

@ -5,6 +5,7 @@ import pytest
import torch.cuda import torch.cuda
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore from vllm.v1.engine.core import EngineCore
@ -14,6 +15,11 @@ MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
def test_preprocess_error_handling(monkeypatch: pytest.MonkeyPatch): def test_preprocess_error_handling(monkeypatch: pytest.MonkeyPatch):
"""Test that preprocessing errors are handled gracefully.""" """Test that preprocessing errors are handled gracefully."""
if current_platform.is_rocm():
pytest.skip(
"Skipped on ROCm: this test only works with 'fork', but ROCm uses 'spawn'."
)
assert not torch.cuda.is_initialized(), ( assert not torch.cuda.is_initialized(), (
"fork needs to be used for the engine " "fork needs to be used for the engine "
"core process and this isn't possible if cuda is already initialized" "core process and this isn't possible if cuda is already initialized"

View File

@ -306,10 +306,16 @@ def test_prepare_inputs_padded():
proposer = _create_proposer("eagle", num_speculative_tokens) proposer = _create_proposer("eagle", num_speculative_tokens)
output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded( output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = (
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count proposer.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
)
) )
# Verify num_rejected_tokens_gpu is calculated correctly
expected_num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32, device=device)
assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected)
assert output_metadata.max_query_len == 3 assert output_metadata.max_query_len == 3
assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc) assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)

View File

@ -2132,6 +2132,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
torch.float8_e4m3fn, torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
] ]
E, num_tokens, N, K, top_k_num = self.moe_problem_size( E, num_tokens, N, K, top_k_num = self.moe_problem_size(
@ -2156,7 +2157,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
compute_type = tl.float16 compute_type = tl.float16
elif hidden_states.dtype == torch.float32: elif hidden_states.dtype == torch.float32:
compute_type = tl.float32 compute_type = tl.float32
elif hidden_states.dtype == torch.float8_e4m3fn: elif (
hidden_states.dtype == torch.float8_e4m3fn
or hidden_states.dtype == torch.float8_e4m3fnuz
):
compute_type = tl.bfloat16 compute_type = tl.bfloat16
else: else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

View File

@ -13,6 +13,10 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def __init__(self, defer_input_quant: bool = False) -> None:
super().__init__()
self.defer_input_quant = defer_input_quant
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
@ -48,6 +52,11 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
# Note: do not use inplace for shared experts overlap # Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
# which use a single kernel call for quant + experts.
if self.defer_input_quant:
return a1, None, None, None, None
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
quant_config.a1_scale, quant_config.a1_scale,

View File

@ -5,11 +5,15 @@ from functools import lru_cache
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
class QuantMethod(IntEnum): class QuantMethod(IntEnum):
@ -263,3 +267,78 @@ def rocm_aiter_fused_experts(
a2_scale=quant_config.a2_scale, a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input, doweight_stage1=apply_router_weight_on_input,
) )
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config):
super().__init__(quant_config)
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
def supports_expert_map(self):
return True
def supports_chunking(self):
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Workspaces are managed internally by AITER.
workspace1 = (0,)
workspace2 = (0,)
output = (M, K)
return (workspace1, workspace2, output)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert a1q_scale is None
assert a2_scale is None
assert expert_tokens_meta is None
result = rocm_aiter_fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.quant_config,
)
assert result.shape == output.shape
output.copy_(result)

View File

@ -117,6 +117,7 @@ class Fp8MoeBackend(Enum):
DEEPGEMM = 3 DEEPGEMM = 3
MARLIN = 4 MARLIN = 4
TRITON = 5 TRITON = 5
AITER = 6
def get_fp8_moe_backend( def get_fp8_moe_backend(
@ -189,6 +190,10 @@ def get_fp8_moe_backend(
logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local") logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
return Fp8MoeBackend.DEEPGEMM return Fp8MoeBackend.DEEPGEMM
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
return Fp8MoeBackend.AITER
# default to Triton # default to Triton
logger.info_once("Using Triton backend for FP8 MoE") logger.info_once("Using Triton backend for FP8 MoE")
return Fp8MoeBackend.TRITON return Fp8MoeBackend.TRITON
@ -888,16 +893,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
self.rocm_aiter_moe_enabled = False
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False): if getattr(layer, "_already_called_process_weights_after_loading", False):
return return
# Lazy import to avoid importing triton too early.
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# TODO (rob): refactor block quant into separate class. # TODO (rob): refactor block quant into separate class.
if self.block_quant: if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic" assert self.quant_config.activation_scheme == "dynamic"
@ -932,7 +931,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv) replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
replace_parameter(layer, "w2_weight", w2_weight) replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv) replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
if self.rocm_aiter_moe_enabled: if self.fp8_backend == Fp8MoeBackend.AITER:
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data layer.w13_weight.data, layer.w2_weight.data
@ -1026,7 +1025,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
start += shard_size start += shard_size
if self.rocm_aiter_moe_enabled: if self.fp8_backend == Fp8MoeBackend.AITER:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight layer.w13_weight, layer.w2_weight
) )
@ -1072,6 +1071,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = config self.moe_quant_config = config
self.kernel = mk.FusedMoEModularKernel( self.kernel = mk.FusedMoEModularKernel(
# TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP
# with the changes to defer input quantization
FlashInferAllGatherMoEPrepareAndFinalize( FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=(self.moe.dp_size > 1), use_dp=(self.moe.dp_size > 1),
use_deepseek_fp8_block_scale=self.block_quant, use_deepseek_fp8_block_scale=self.block_quant,
@ -1093,6 +1094,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.DEEPGEMM,
Fp8MoeBackend.TRITON, Fp8MoeBackend.TRITON,
Fp8MoeBackend.MARLIN, Fp8MoeBackend.MARLIN,
Fp8MoeBackend.AITER,
]: ]:
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
@ -1103,24 +1105,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoEP,
) )
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
config = self.get_fused_moe_quant_config(layer) config = self.get_fused_moe_quant_config(layer)
assert config is not None assert config is not None
self.moe_quant_config = config self.moe_quant_config = config
use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
moe_kernel = (
MarlinExperts(quant_config=self.moe_quant_config)
if use_marlin
else TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=allow_deep_gemm,
)
)
self.kernel = mk.FusedMoEModularKernel( if self.fp8_backend == Fp8MoeBackend.AITER:
MoEPrepareAndFinalizeNoEP(), moe_kernel self.kernel = mk.FusedMoEModularKernel(
) # TODO: make defer_input_quant an attr of the AiterExperts
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
AiterExperts(quant_config=self.moe_quant_config),
)
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config=self.moe_quant_config),
)
else:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
),
)
self.use_inplace = True self.use_inplace = True
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
@ -1128,7 +1139,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if ( if (
self.rocm_aiter_moe_enabled self.fp8_backend == Fp8MoeBackend.AITER
or self.fp8_backend == Fp8MoeBackend.MARLIN or self.fp8_backend == Fp8MoeBackend.MARLIN
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
): ):
@ -1161,11 +1172,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
) )
assert ( if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
self.fp8_backend != Fp8MoeBackend.MARLIN raise NotImplementedError(
) and not self.rocm_aiter_moe_enabled, ( "Marlin and ROCm AITER are not supported with all2all yet."
"Marlin and ROCm AITER are not supported with all2all yet." )
)
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
@ -1313,37 +1323,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
result = self.kernel(
if self.rocm_aiter_moe_enabled: x,
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 layer.w13_weight,
rocm_aiter_fused_experts, layer.w2_weight,
) topk_weights,
topk_ids,
# TODO(rob): convert this to MK. inplace=self.use_inplace,
result = rocm_aiter_fused_experts( activation=layer.activation,
x, global_num_experts=layer.global_num_experts,
layer.w13_weight, expert_map=layer.expert_map,
layer.w2_weight, apply_router_weight_on_input=layer.apply_router_weight_on_input,
topk_weights=topk_weights, )
topk_ids=topk_ids,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
else:
result = self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
return result return result
@ -1456,15 +1447,10 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
self.rocm_aiter_moe_enabled = False
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False): if getattr(layer, "_already_called_process_weights_after_loading", False):
return return
# Lazy import to avoid importing triton too early.
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# If checkpoint is fp16, quantize in place. # If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype() fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
@ -1481,7 +1467,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
replace_parameter(layer, "w2_weight", w2_weight) replace_parameter(layer, "w2_weight", w2_weight)
# Reshuffle weights for AITER if needed. # Reshuffle weights for AITER if needed.
if self.rocm_aiter_moe_enabled: if self.fp8_backend == Fp8MoeBackend.AITER:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight layer.w13_weight, layer.w2_weight
) )
@ -1489,7 +1475,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
replace_parameter(layer, "w2_weight", shuffled_w2) replace_parameter(layer, "w2_weight", shuffled_w2)
# Rushuffle weights for MARLIN if needed. # Rushuffle weights for MARLIN if needed.
if self.fp8_backend == Fp8MoeBackend.MARLIN: elif self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin( prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype layer, False, input_dtype=self.marlin_input_dtype
) )

View File

@ -143,7 +143,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
query_start_loc = m.query_start_loc query_start_loc = m.query_start_loc
context_lens = m.num_computed_tokens_cpu context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device) context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True)
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if ( if (

View File

@ -355,6 +355,8 @@ class MLACommonPrefillMetadata:
max_query_len: int max_query_len: int
chunked_context: ChunkedContextMetadata | None = None chunked_context: ChunkedContextMetadata | None = None
query_seq_lens: torch.Tensor | None = None query_seq_lens: torch.Tensor | None = None
workspace_buffer: torch.Tensor | None = None
q_data_type: torch.dtype | None = None
@dataclass @dataclass
@ -539,6 +541,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
metadata_cls if metadata_cls is not None else MLACommonMetadata metadata_cls if metadata_cls is not None else MLACommonMetadata
) )
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
self.q_data_type = (
current_platform.fp8_dtype()
if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
else vllm_config.model_config.dtype
)
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
@ -558,6 +565,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.dcp_rank = 0 self.dcp_rank = 0
self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
self.cp_kv_cache_interleave_size = parallel_config.cp_kv_cache_interleave_size
# Don't try to access the runner on AMD # Don't try to access the runner on AMD
if self.aot_schedule: if self.aot_schedule:
@ -681,7 +689,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# For main run, qo_indptr == kv_indptr # For main run, qo_indptr == kv_indptr
kv_indptr = qo_indptr.clone() kv_indptr = qo_indptr.clone()
# Prepare main prefill # Prepare main prefill
self._fi_prefill_main.plan( self._fi_prefill_main.plan(
qo_indptr=qo_indptr, qo_indptr=qo_indptr,
@ -694,7 +701,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale, sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left, window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap, logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype, q_data_type=self.q_data_type,
) )
# Prepare context prefills # Prepare context prefills
@ -713,7 +720,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale, sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left, window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap, logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype, q_data_type=self.q_data_type,
) )
prefill.prefill_main = self._fi_prefill_main prefill.prefill_main = self._fi_prefill_main
@ -722,8 +729,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def _build_decode( def _build_decode(
self, self,
block_table_tensor: torch.Tensor, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor, query_start_loc_device: torch.Tensor,
num_decode_tokens: int, num_decode_tokens: int,
@ -773,13 +780,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc = common_attn_metadata.query_start_loc query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills( split_decodes_and_prefills(
@ -794,6 +795,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata = None prefill_metadata = None
if num_prefills > 0: if num_prefills > 0:
num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
reqs_start = num_decodes # prefill_start reqs_start = num_decodes # prefill_start
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
@ -970,6 +973,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc=prefill_query_start_loc, query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len, max_query_len=max_query_len,
chunked_context=chunked_context_metadata, chunked_context=chunked_context_metadata,
q_data_type=self.q_data_type,
) )
if self._use_cudnn_prefill: if self._use_cudnn_prefill:
@ -983,19 +987,29 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata.query_seq_lens = ( prefill_metadata.query_seq_lens = (
prefill_query_start_loc[1:] - prefill_query_start_loc[:-1] prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
) )
prefill_metadata.workspace_buffer = self._workspace_buffer
decode_metadata = None decode_metadata = None
if num_decodes > 0: if num_decodes > 0:
dcp_tot_seq_lens_device = None dcp_tot_seq_lens_device = None
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
dcp_tot_seq_lens_device = seq_lens[:num_decodes] dcp_tot_seq_lens_device = seq_lens[:num_decodes]
seq_lens_cpu = dcp_local_seq_lens_cpu
seq_lens = dcp_local_seq_lens seq_lens = dcp_local_seq_lens
# After DCP distribution, the maximum number of tokens for any rank is
# ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
# and I is cp_kv_cache_interleave_size.
# This eliminates GPU->CPU sync while minimizing workspace
# over-allocation.
num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
max_seq_len = (
(max_seq_len + num_partitions - 1) // num_partitions
) * self.cp_kv_cache_interleave_size
decode_metadata = self._build_decode( decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...], block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes], seq_lens_device=seq_lens[:num_decodes],
max_seq_len=max_seq_len,
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
query_start_loc_device=query_start_loc[: num_decodes + 1], query_start_loc_device=query_start_loc[: num_decodes + 1],
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
@ -1370,8 +1384,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out return attn_out
def _run_prefill_new_tokens_fa( def _run_prefill_new_tokens_fa(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
): ):
logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
return self._flash_attn_varlen_diff_headdims( return self._flash_attn_varlen_diff_headdims(
q=q, q=q,
k=k, k=k,
@ -1386,11 +1407,23 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) )
def _run_prefill_new_tokens_fi( def _run_prefill_new_tokens_fi(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
): ):
logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata) assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None assert prefill.prefill_main is not None
if fp8_attention:
logger.debug_once("Running Flashinfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
ret = prefill.prefill_main.run( ret = prefill.prefill_main.run(
q=q, q=q,
k=k, k=k,
@ -1403,10 +1436,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret return ret
def _run_prefill_new_tokens_cudnn( def _run_prefill_new_tokens_cudnn(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
): ):
logger.debug_once("Running Cudnn prefill new tokens", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata) assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
output, lse = cudnn_batch_prefill_with_kv_cache( output, lse = cudnn_batch_prefill_with_kv_cache(
q=q, q=q,
k_cache=k, k_cache=k,
@ -1428,9 +1469,19 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return output return output
def _run_prefill_context_chunk_fa( def _run_prefill_context_chunk_fa(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
): ):
logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
assert prefill.chunked_context is not None assert prefill.chunked_context is not None
assert fp8_attention is False, (
"FlashAttention prefill does not support fp8 attention"
)
return self._flash_attn_varlen_diff_headdims( return self._flash_attn_varlen_diff_headdims(
q=q, q=q,
k=k, k=k,
@ -1445,10 +1496,22 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) )
def _run_prefill_context_chunk_fi( def _run_prefill_context_chunk_fi(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
): ):
logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata) assert isinstance(prefill, FlashInferPrefillMetadata)
if fp8_attention:
logger.debug_once("Running FlashInfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
attn_out, lse = prefill.prefill_chunks[chunk_idx].run( attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
q=q, q=q,
k=k, k=k,
@ -1460,12 +1523,20 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out, lse.transpose(0, 1).contiguous() return attn_out, lse.transpose(0, 1).contiguous()
def _run_prefill_context_chunk_cudnn( def _run_prefill_context_chunk_cudnn(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
): ):
logger.debug_once("Running Cudnn prefill context chunk", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata) assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.chunked_context is not None assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
return cudnn_batch_prefill_with_kv_cache( return cudnn_batch_prefill_with_kv_cache(
q=q, q=q,
k_cache=k, k_cache=k,
@ -1485,18 +1556,33 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) )
def _run_prefill_new_tokens_trtllm_ragged( def _run_prefill_new_tokens_trtllm_ragged(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
): ):
logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
"""TRT-LLM ragged attention for new tokens (causal).""" """TRT-LLM ragged attention for new tokens (causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
assert prefill.workspace_buffer is not None
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
ret = trtllm_ragged_attention_deepseek( ret = trtllm_ragged_attention_deepseek(
query=q, query=q,
key=k, key=k,
value=v, value=v,
workspace_buffer=self._workspace_buffer, workspace_buffer=prefill.workspace_buffer,
seq_lens=prefill.query_seq_lens, seq_lens=prefill.query_seq_lens,
max_q_len=prefill.max_query_len, max_q_len=prefill.max_query_len,
max_kv_len=prefill.max_query_len, max_kv_len=prefill.max_query_len,
@ -1518,13 +1604,21 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret return ret
def _run_prefill_context_chunk_trtllm_ragged( def _run_prefill_context_chunk_trtllm_ragged(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
): ):
logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
"""TRT-LLM ragged attention for context chunks (non-causal).""" """TRT-LLM ragged attention for context chunks (non-causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.chunked_context is not None assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.workspace_buffer is not None
out = torch.zeros( out = torch.zeros(
q.shape[0], q.shape[0],
@ -1533,13 +1627,20 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
device=q.device, device=q.device,
dtype=q.dtype, dtype=q.dtype,
) )
self._workspace_buffer.fill_(0) prefill.workspace_buffer.fill_(0)
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
attn_out, lse = trtllm_ragged_attention_deepseek( attn_out, lse = trtllm_ragged_attention_deepseek(
query=q, query=q,
key=k, key=k,
value=v, value=v,
workspace_buffer=self._workspace_buffer, workspace_buffer=prefill.workspace_buffer,
seq_lens=prefill.chunked_context.seq_lens[chunk_idx], seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
max_q_len=prefill.max_query_len, max_q_len=prefill.max_query_len,
max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx], max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
@ -1687,6 +1788,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor, k_scale: torch.Tensor,
fp8_attention: bool,
): ):
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
@ -1725,6 +1827,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q, q=q,
k=k, k=k,
v=v, v=v,
fp8_attention=fp8_attention,
) )
if output is None: if output is None:
@ -1753,6 +1856,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor, k_scale: torch.Tensor,
dcp_world_size: int, dcp_world_size: int,
fp8_attention: bool,
): ):
assert k_scale is None, "DCP not support scaled kvcache now." assert k_scale is None, "DCP not support scaled kvcache now."
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
@ -1829,6 +1933,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q, q=q,
k=k, k=k,
v=v, v=v,
fp8_attention=fp8_attention,
) )
if output is None: if output is None:
@ -1859,6 +1964,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor, k_scale: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
fp8_attention: bool = False,
) -> None: ) -> None:
# TODO (zyongye): Prefill function here # TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
@ -1878,6 +1984,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k=k, k=k,
v=v, v=v,
return_softmax_lse=has_context, return_softmax_lse=has_context,
fp8_attention=fp8_attention,
) )
if has_context: if has_context:
@ -1890,11 +1997,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata, attn_metadata,
k_scale=None, k_scale=None,
dcp_world_size=self.dcp_world_size, dcp_world_size=self.dcp_world_size,
fp8_attention=fp8_attention,
) )
) )
else: else:
context_output, context_lse = self._compute_prefill_context( context_output, context_lse = self._compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata, k_scale q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention
) )
# unpad if necessary # unpad if necessary
@ -2015,6 +2123,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata, attn_metadata,
layer._k_scale, layer._k_scale,
output=output[num_decode_tokens:], output=output[num_decode_tokens:],
fp8_attention=fp8_attention,
) )
if has_decode: if has_decode:

View File

@ -169,8 +169,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
def _build_decode( def _build_decode(
self, self,
block_table_tensor: torch.Tensor, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor, query_start_loc_device: torch.Tensor,
num_decode_tokens: int, num_decode_tokens: int,
@ -178,7 +178,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
) -> FlashAttnMLADecodeMetadata: ) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item() max_query_len = query_lens_cpu.max().item()
max_seq_len = seq_lens_cpu.max().item()
# For Flash Attention MLA + full cudagraph # For Flash Attention MLA + full cudagraph
max_num_splits = 0 max_num_splits = 0
@ -193,7 +192,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
max_num_splits = 1 max_num_splits = 1
scheduler_metadata = self._schedule_decode( scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(), num_reqs=seq_lens_device.shape[0],
cu_query_lens=query_start_loc_device, cu_query_lens=query_start_loc_device,
max_query_len=max_query_len, max_query_len=max_query_len,
seqlens=seq_lens_device, seqlens=seq_lens_device,

View File

@ -143,8 +143,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def _build_decode( def _build_decode(
self, self,
block_table_tensor: torch.Tensor, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor, query_start_loc_device: torch.Tensor,
num_decode_tokens: int, num_decode_tokens: int,

View File

@ -106,8 +106,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
def _build_decode( def _build_decode(
self, self,
block_table_tensor: torch.Tensor, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor, query_start_loc_device: torch.Tensor,
num_decode_tokens: int, num_decode_tokens: int,

View File

@ -236,6 +236,7 @@ class EagleProposer:
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = target_token_ids.shape[0] num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0] batch_size = next_token_ids.shape[0]
@ -414,6 +415,17 @@ class EagleProposer:
common_attn_metadata.query_start_loc_cpu = torch.from_numpy( common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
self.token_arange_np[: batch_size + 1] self.token_arange_np[: batch_size + 1]
).clone() ).clone()
# In padded drafter batch, we need to adjust the sequence lengths
# to remove the "padding" (i.e. rejected tokens).
# Only apply this adjustment when we have rejected tokens
# (i.e., not the first proposal).
if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
# Invalidate the CPU-side shadows to avoid H<>D sync.
common_attn_metadata._seq_lens_cpu = None
common_attn_metadata._num_computed_tokens_cpu = None
for token_index in range(self.num_speculative_tokens - 1): for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs. # Update the inputs.
# cast to int32 is crucial when eagle model is compiled. # cast to int32 is crucial when eagle model is compiled.
@ -628,13 +640,14 @@ class EagleProposer:
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
spec_decode_metadata: SpecDecodeMetadata, spec_decode_metadata: SpecDecodeMetadata,
valid_sampled_tokens_count: torch.Tensor, valid_sampled_tokens_count: torch.Tensor,
) -> tuple[CommonAttentionMetadata, torch.Tensor]: ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
""" """
This function is used to prepare the inputs for speculative decoding This function is used to prepare the inputs for speculative decoding
It updates the common_attn_metadata for speculative decoding, It updates the common_attn_metadata for speculative decoding,
but does not consider the rejected tokens. Instead, all tokens but does not consider the rejected tokens. Instead, all tokens
are included as inputs to the speculator, with the rejected tokens are included as inputs to the speculator, with the rejected tokens
used as padding and filtered out later by `token_indices_to_sample`. used as padding and filtered out later by `token_indices_to_sample`.
No blocking CPU operations should be introduced in this function.
""" """
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
device = valid_sampled_tokens_count.device device = valid_sampled_tokens_count.device
@ -642,14 +655,17 @@ class EagleProposer:
token_indices_to_sample = torch.empty( token_indices_to_sample = torch.empty(
(num_reqs,), dtype=torch.int32, device=device (num_reqs,), dtype=torch.int32, device=device
) )
num_rejected_tokens_gpu = torch.empty(
(num_reqs,), dtype=torch.int32, device=device
)
# Kernel grid: one program per request (row)
grid = (num_reqs,) grid = (num_reqs,)
eagle_prepare_inputs_padded_kernel[grid]( eagle_prepare_inputs_padded_kernel[grid](
spec_decode_metadata.cu_num_draft_tokens, spec_decode_metadata.cu_num_draft_tokens,
valid_sampled_tokens_count, valid_sampled_tokens_count,
common_attn_metadata.query_start_loc, common_attn_metadata.query_start_loc,
token_indices_to_sample, token_indices_to_sample,
num_rejected_tokens_gpu,
num_reqs, num_reqs,
) )
@ -674,7 +690,11 @@ class EagleProposer:
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
) )
return spec_common_attn_metadata, token_indices_to_sample return (
spec_common_attn_metadata,
token_indices_to_sample,
num_rejected_tokens_gpu,
)
def propose_tree( def propose_tree(
self, self,

View File

@ -23,6 +23,7 @@ def eagle_prepare_inputs_padded_kernel(
valid_sampled_tokens_count_ptr, # [num_reqs] valid_sampled_tokens_count_ptr, # [num_reqs]
query_start_loc_gpu_ptr, # [num_reqs + 1] query_start_loc_gpu_ptr, # [num_reqs + 1]
token_indices_to_sample_ptr, # [num_reqs] (output) token_indices_to_sample_ptr, # [num_reqs] (output)
num_rejected_tokens_gpu_ptr, # [num_reqs] (output)
num_reqs, # tl.int32 num_reqs, # tl.int32
): ):
""" """
@ -56,6 +57,7 @@ def eagle_prepare_inputs_padded_kernel(
index_to_sample = q_last_tok_idx - num_rejected_tokens index_to_sample = q_last_tok_idx - num_rejected_tokens
tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample) tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
tl.store(num_rejected_tokens_gpu_ptr + req_idx, num_rejected_tokens)
@triton.jit @triton.jit

View File

@ -3534,6 +3534,7 @@ class GPUModelRunner(
next_token_ids, valid_sampled_tokens_count next_token_ids, valid_sampled_tokens_count
) )
num_rejected_tokens_gpu = None
if spec_decode_metadata is None: if spec_decode_metadata is None:
token_indices_to_sample = None token_indices_to_sample = None
# input_ids can be None for multimodal models. # input_ids can be None for multimodal models.
@ -3564,12 +3565,14 @@ class GPUModelRunner(
else: else:
target_hidden_states = hidden_states[token_indices] target_hidden_states = hidden_states[token_indices]
else: else:
common_attn_metadata, token_indices_to_sample = ( (
self.drafter.prepare_inputs_padded( common_attn_metadata,
common_attn_metadata, token_indices_to_sample,
spec_decode_metadata, num_rejected_tokens_gpu,
valid_sampled_tokens_count, ) = self.drafter.prepare_inputs_padded(
) common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count,
) )
total_num_tokens = common_attn_metadata.num_actual_tokens total_num_tokens = common_attn_metadata.num_actual_tokens
# When padding the batch, token_indices is just a range # When padding the batch, token_indices is just a range
@ -3600,6 +3603,7 @@ class GPUModelRunner(
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
mm_embed_inputs=mm_embed_inputs, mm_embed_inputs=mm_embed_inputs,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
) )
return draft_token_ids return draft_token_ids