mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 17:49:11 +08:00
Merge branch 'main' into elvischenv/update-flashinfer
This commit is contained in:
commit
f5228915a4
@ -28,3 +28,4 @@ The backends below live **outside** the main `vllm` repository and follow the
|
|||||||
| Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> |
|
| Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> |
|
||||||
| Baidu Kunlun XPU | N/A, install from source | <https://github.com/baidu/vLLM-Kunlun> |
|
| Baidu Kunlun XPU | N/A, install from source | <https://github.com/baidu/vLLM-Kunlun> |
|
||||||
| Sophgo TPU | N/A, install from source | <https://github.com/sophgo/vllm-tpu> |
|
| Sophgo TPU | N/A, install from source | <https://github.com/sophgo/vllm-tpu> |
|
||||||
|
| Apple Silicon (Metal) | N/A, install from source | <https://github.com/vllm-project/vllm-metal> |
|
||||||
|
|||||||
@ -4,6 +4,9 @@ vLLM has experimental support for macOS with Apple Silicon. For now, users must
|
|||||||
|
|
||||||
Currently the CPU implementation for macOS supports FP32 and FP16 datatypes.
|
Currently the CPU implementation for macOS supports FP32 and FP16 datatypes.
|
||||||
|
|
||||||
|
!!! tip "GPU-Accelerated Inference with vLLM-Metal"
|
||||||
|
For GPU-accelerated inference on Apple Silicon using Metal, check out [vllm-metal](https://github.com/vllm-project/vllm-metal), a community-maintained hardware plugin that uses MLX as the compute backend.
|
||||||
|
|
||||||
# --8<-- [end:installation]
|
# --8<-- [end:installation]
|
||||||
# --8<-- [start:requirements]
|
# --8<-- [start:requirements]
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,7 @@ To run inference on a single or multiple GPUs, use `VLLM` class from `langchain`
|
|||||||
from langchain_community.llms import VLLM
|
from langchain_community.llms import VLLM
|
||||||
|
|
||||||
llm = VLLM(
|
llm = VLLM(
|
||||||
model="mosaicml/mpt-7b",
|
model="Qwen/Qwen3-4B",
|
||||||
trust_remote_code=True, # mandatory for hf models
|
trust_remote_code=True, # mandatory for hf models
|
||||||
max_new_tokens=128,
|
max_new_tokens=128,
|
||||||
top_k=10,
|
top_k=10,
|
||||||
|
|||||||
@ -215,7 +215,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
),
|
),
|
||||||
"CwmForCausalLM": _HfExamplesInfo("facebook/cwm", min_transformers_version="4.58"),
|
"CwmForCausalLM": _HfExamplesInfo("facebook/cwm", min_transformers_version="4.58"),
|
||||||
"DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
|
# FIXME: databricks/dbrx-instruct has been deleted
|
||||||
|
"DbrxForCausalLM": _HfExamplesInfo(
|
||||||
|
"databricks/dbrx-instruct", is_available_online=False
|
||||||
|
),
|
||||||
"DeciLMForCausalLM": _HfExamplesInfo(
|
"DeciLMForCausalLM": _HfExamplesInfo(
|
||||||
"nvidia/Llama-3_3-Nemotron-Super-49B-v1",
|
"nvidia/Llama-3_3-Nemotron-Super-49B-v1",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
@ -366,7 +369,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
{"tiny": "TitanML/tiny-mixtral"},
|
{"tiny": "TitanML/tiny-mixtral"},
|
||||||
),
|
),
|
||||||
"MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
|
"MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
|
||||||
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
|
# FIXME: mosaicml/mpt-7b has been deleted
|
||||||
|
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b", is_available_online=False),
|
||||||
"NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
|
"NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
|
||||||
"NemotronHForCausalLM": _HfExamplesInfo(
|
"NemotronHForCausalLM": _HfExamplesInfo(
|
||||||
"nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
|
"nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
|
||||||
|
|||||||
@ -38,7 +38,8 @@ TOKENIZERS = [
|
|||||||
"EleutherAI/gpt-j-6b",
|
"EleutherAI/gpt-j-6b",
|
||||||
"EleutherAI/pythia-70m",
|
"EleutherAI/pythia-70m",
|
||||||
"bigscience/bloom-560m",
|
"bigscience/bloom-560m",
|
||||||
"mosaicml/mpt-7b",
|
# FIXME: mosaicml/mpt-7b has been deleted
|
||||||
|
# "mosaicml/mpt-7b",
|
||||||
"tiiuae/falcon-7b",
|
"tiiuae/falcon-7b",
|
||||||
"meta-llama/Llama-3.2-1B-Instruct",
|
"meta-llama/Llama-3.2-1B-Instruct",
|
||||||
"codellama/CodeLlama-7b-hf",
|
"codellama/CodeLlama-7b-hf",
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
|
|||||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
from vllm.v1.attention.backends.mla.common import QueryLenSupport
|
from vllm.v1.attention.backends.mla.common import QueryLenSupport
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
||||||
|
|
||||||
BACKENDS_TO_TEST = [
|
BACKENDS_TO_TEST = [
|
||||||
AttentionBackendEnum.CUTLASS_MLA,
|
AttentionBackendEnum.CUTLASS_MLA,
|
||||||
@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
|
|||||||
|
|
||||||
def run_attention_backend(
|
def run_attention_backend(
|
||||||
backend: AttentionBackendEnum,
|
backend: AttentionBackendEnum,
|
||||||
kv_cache_spec: FullAttentionSpec,
|
kv_cache_spec: MLAAttentionSpec,
|
||||||
layer_names: list[str],
|
layer_names: list[str],
|
||||||
vllm_config,
|
vllm_config,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
@ -740,7 +740,7 @@ def test_backend_correctness(
|
|||||||
kv_cache = kv_cache_per_block_size[block_size]
|
kv_cache = kv_cache_per_block_size[block_size]
|
||||||
|
|
||||||
# Create kv_cache_spec with the correct block_size for this backend
|
# Create kv_cache_spec with the correct block_size for this backend
|
||||||
backend_kv_cache_spec = FullAttentionSpec(
|
backend_kv_cache_spec = MLAAttentionSpec(
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
|
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
|
||||||
vllm_config.parallel_config
|
vllm_config.parallel_config
|
||||||
@ -748,6 +748,7 @@ def test_backend_correctness(
|
|||||||
head_size=vllm_config.model_config.get_head_size(),
|
head_size=vllm_config.model_config.get_head_size(),
|
||||||
dtype=vllm_config.model_config.dtype,
|
dtype=vllm_config.model_config.dtype,
|
||||||
sliding_window=vllm_config.model_config.get_sliding_window(),
|
sliding_window=vllm_config.model_config.get_sliding_window(),
|
||||||
|
cache_dtype_str=vllm_config.cache_config.cache_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
backend_output = run_attention_backend(
|
backend_output = run_attention_backend(
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import pytest
|
|||||||
import torch.cuda
|
import torch.cuda
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.engine.core import EngineCore
|
from vllm.v1.engine.core import EngineCore
|
||||||
|
|
||||||
@ -14,6 +15,11 @@ MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
|||||||
def test_preprocess_error_handling(monkeypatch: pytest.MonkeyPatch):
|
def test_preprocess_error_handling(monkeypatch: pytest.MonkeyPatch):
|
||||||
"""Test that preprocessing errors are handled gracefully."""
|
"""Test that preprocessing errors are handled gracefully."""
|
||||||
|
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
pytest.skip(
|
||||||
|
"Skipped on ROCm: this test only works with 'fork', but ROCm uses 'spawn'."
|
||||||
|
)
|
||||||
|
|
||||||
assert not torch.cuda.is_initialized(), (
|
assert not torch.cuda.is_initialized(), (
|
||||||
"fork needs to be used for the engine "
|
"fork needs to be used for the engine "
|
||||||
"core process and this isn't possible if cuda is already initialized"
|
"core process and this isn't possible if cuda is already initialized"
|
||||||
|
|||||||
@ -306,10 +306,16 @@ def test_prepare_inputs_padded():
|
|||||||
|
|
||||||
proposer = _create_proposer("eagle", num_speculative_tokens)
|
proposer = _create_proposer("eagle", num_speculative_tokens)
|
||||||
|
|
||||||
output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
|
output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = (
|
||||||
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
|
proposer.prepare_inputs_padded(
|
||||||
|
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Verify num_rejected_tokens_gpu is calculated correctly
|
||||||
|
expected_num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32, device=device)
|
||||||
|
assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected)
|
||||||
|
|
||||||
assert output_metadata.max_query_len == 3
|
assert output_metadata.max_query_len == 3
|
||||||
assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
|
assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
|
||||||
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
|
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
|
||||||
|
|||||||
@ -2132,6 +2132,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
torch.bfloat16,
|
torch.bfloat16,
|
||||||
torch.float8_e4m3fn,
|
torch.float8_e4m3fn,
|
||||||
|
torch.float8_e4m3fnuz,
|
||||||
]
|
]
|
||||||
|
|
||||||
E, num_tokens, N, K, top_k_num = self.moe_problem_size(
|
E, num_tokens, N, K, top_k_num = self.moe_problem_size(
|
||||||
@ -2156,7 +2157,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
compute_type = tl.float16
|
compute_type = tl.float16
|
||||||
elif hidden_states.dtype == torch.float32:
|
elif hidden_states.dtype == torch.float32:
|
||||||
compute_type = tl.float32
|
compute_type = tl.float32
|
||||||
elif hidden_states.dtype == torch.float8_e4m3fn:
|
elif (
|
||||||
|
hidden_states.dtype == torch.float8_e4m3fn
|
||||||
|
or hidden_states.dtype == torch.float8_e4m3fnuz
|
||||||
|
):
|
||||||
compute_type = tl.bfloat16
|
compute_type = tl.bfloat16
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
|
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
|
||||||
|
|||||||
@ -13,6 +13,10 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
|||||||
|
|
||||||
|
|
||||||
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||||
|
def __init__(self, defer_input_quant: bool = False) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.defer_input_quant = defer_input_quant
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||||
return mk.FusedMoEActivationFormat.Standard
|
return mk.FusedMoEActivationFormat.Standard
|
||||||
@ -48,6 +52,11 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
|||||||
# Note: do not use inplace for shared experts overlap
|
# Note: do not use inplace for shared experts overlap
|
||||||
a1 = a1 * topk_weights.to(a1.dtype)
|
a1 = a1 * topk_weights.to(a1.dtype)
|
||||||
|
|
||||||
|
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
|
||||||
|
# which use a single kernel call for quant + experts.
|
||||||
|
if self.defer_input_quant:
|
||||||
|
return a1, None, None, None, None
|
||||||
|
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
a1,
|
a1,
|
||||||
quant_config.a1_scale,
|
quant_config.a1_scale,
|
||||||
|
|||||||
@ -5,11 +5,15 @@ from functools import lru_cache
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||||
FusedMoEQuantConfig,
|
FusedMoEQuantConfig,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
|
TopKWeightAndReduceNoOP,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QuantMethod(IntEnum):
|
class QuantMethod(IntEnum):
|
||||||
@ -263,3 +267,78 @@ def rocm_aiter_fused_experts(
|
|||||||
a2_scale=quant_config.a2_scale,
|
a2_scale=quant_config.a2_scale,
|
||||||
doweight_stage1=apply_router_weight_on_input,
|
doweight_stage1=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
def __init__(self, quant_config):
|
||||||
|
super().__init__(quant_config)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activation_formats(
|
||||||
|
self,
|
||||||
|
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||||
|
return (
|
||||||
|
mk.FusedMoEActivationFormat.Standard,
|
||||||
|
mk.FusedMoEActivationFormat.Standard,
|
||||||
|
)
|
||||||
|
|
||||||
|
def supports_expert_map(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def supports_chunking(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||||
|
return TopKWeightAndReduceNoOP()
|
||||||
|
|
||||||
|
def workspace_shapes(
|
||||||
|
self,
|
||||||
|
M: int,
|
||||||
|
N: int,
|
||||||
|
K: int,
|
||||||
|
topk: int,
|
||||||
|
global_num_experts: int,
|
||||||
|
local_num_experts: int,
|
||||||
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
|
# Workspaces are managed internally by AITER.
|
||||||
|
workspace1 = (0,)
|
||||||
|
workspace2 = (0,)
|
||||||
|
output = (M, K)
|
||||||
|
return (workspace1, workspace2, output)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
output: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
activation: str,
|
||||||
|
global_num_experts: int,
|
||||||
|
expert_map: torch.Tensor | None,
|
||||||
|
a1q_scale: torch.Tensor | None,
|
||||||
|
a2_scale: torch.Tensor | None,
|
||||||
|
workspace13: torch.Tensor,
|
||||||
|
workspace2: torch.Tensor,
|
||||||
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
apply_router_weight_on_input: bool,
|
||||||
|
):
|
||||||
|
assert a1q_scale is None
|
||||||
|
assert a2_scale is None
|
||||||
|
assert expert_tokens_meta is None
|
||||||
|
|
||||||
|
result = rocm_aiter_fused_experts(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
w1=w1,
|
||||||
|
w2=w2,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
expert_map=expert_map,
|
||||||
|
quant_config=self.quant_config,
|
||||||
|
)
|
||||||
|
assert result.shape == output.shape
|
||||||
|
output.copy_(result)
|
||||||
|
|||||||
@ -117,6 +117,7 @@ class Fp8MoeBackend(Enum):
|
|||||||
DEEPGEMM = 3
|
DEEPGEMM = 3
|
||||||
MARLIN = 4
|
MARLIN = 4
|
||||||
TRITON = 5
|
TRITON = 5
|
||||||
|
AITER = 6
|
||||||
|
|
||||||
|
|
||||||
def get_fp8_moe_backend(
|
def get_fp8_moe_backend(
|
||||||
@ -189,6 +190,10 @@ def get_fp8_moe_backend(
|
|||||||
logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
|
logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
|
||||||
return Fp8MoeBackend.DEEPGEMM
|
return Fp8MoeBackend.DEEPGEMM
|
||||||
|
|
||||||
|
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
|
||||||
|
logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
|
||||||
|
return Fp8MoeBackend.AITER
|
||||||
|
|
||||||
# default to Triton
|
# default to Triton
|
||||||
logger.info_once("Using Triton backend for FP8 MoE")
|
logger.info_once("Using Triton backend for FP8 MoE")
|
||||||
return Fp8MoeBackend.TRITON
|
return Fp8MoeBackend.TRITON
|
||||||
@ -888,16 +893,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w13_input_scale = None
|
layer.w13_input_scale = None
|
||||||
layer.w2_input_scale = None
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
self.rocm_aiter_moe_enabled = False
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||||
return
|
return
|
||||||
|
|
||||||
# Lazy import to avoid importing triton too early.
|
|
||||||
|
|
||||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
|
||||||
|
|
||||||
# TODO (rob): refactor block quant into separate class.
|
# TODO (rob): refactor block quant into separate class.
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
assert self.quant_config.activation_scheme == "dynamic"
|
assert self.quant_config.activation_scheme == "dynamic"
|
||||||
@ -932,7 +931,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
|
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
|
||||||
replace_parameter(layer, "w2_weight", w2_weight)
|
replace_parameter(layer, "w2_weight", w2_weight)
|
||||||
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
|
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
|
||||||
if self.rocm_aiter_moe_enabled:
|
if self.fp8_backend == Fp8MoeBackend.AITER:
|
||||||
# reshaping weights is required for aiter moe kernel.
|
# reshaping weights is required for aiter moe kernel.
|
||||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||||
layer.w13_weight.data, layer.w2_weight.data
|
layer.w13_weight.data, layer.w2_weight.data
|
||||||
@ -1026,7 +1025,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
start += shard_size
|
start += shard_size
|
||||||
|
|
||||||
if self.rocm_aiter_moe_enabled:
|
if self.fp8_backend == Fp8MoeBackend.AITER:
|
||||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||||
layer.w13_weight, layer.w2_weight
|
layer.w13_weight, layer.w2_weight
|
||||||
)
|
)
|
||||||
@ -1072,6 +1071,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self.moe_quant_config = config
|
self.moe_quant_config = config
|
||||||
|
|
||||||
self.kernel = mk.FusedMoEModularKernel(
|
self.kernel = mk.FusedMoEModularKernel(
|
||||||
|
# TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP
|
||||||
|
# with the changes to defer input quantization
|
||||||
FlashInferAllGatherMoEPrepareAndFinalize(
|
FlashInferAllGatherMoEPrepareAndFinalize(
|
||||||
use_dp=(self.moe.dp_size > 1),
|
use_dp=(self.moe.dp_size > 1),
|
||||||
use_deepseek_fp8_block_scale=self.block_quant,
|
use_deepseek_fp8_block_scale=self.block_quant,
|
||||||
@ -1093,6 +1094,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
Fp8MoeBackend.DEEPGEMM,
|
Fp8MoeBackend.DEEPGEMM,
|
||||||
Fp8MoeBackend.TRITON,
|
Fp8MoeBackend.TRITON,
|
||||||
Fp8MoeBackend.MARLIN,
|
Fp8MoeBackend.MARLIN,
|
||||||
|
Fp8MoeBackend.AITER,
|
||||||
]:
|
]:
|
||||||
from vllm.model_executor.layers.fused_moe import (
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
TritonOrDeepGemmExperts,
|
TritonOrDeepGemmExperts,
|
||||||
@ -1103,24 +1105,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||||
MoEPrepareAndFinalizeNoEP,
|
MoEPrepareAndFinalizeNoEP,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
|
AiterExperts,
|
||||||
|
)
|
||||||
|
|
||||||
config = self.get_fused_moe_quant_config(layer)
|
config = self.get_fused_moe_quant_config(layer)
|
||||||
assert config is not None
|
assert config is not None
|
||||||
self.moe_quant_config = config
|
self.moe_quant_config = config
|
||||||
use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
|
|
||||||
allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
|
|
||||||
moe_kernel = (
|
|
||||||
MarlinExperts(quant_config=self.moe_quant_config)
|
|
||||||
if use_marlin
|
|
||||||
else TritonOrDeepGemmExperts(
|
|
||||||
quant_config=self.moe_quant_config,
|
|
||||||
allow_deep_gemm=allow_deep_gemm,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.kernel = mk.FusedMoEModularKernel(
|
if self.fp8_backend == Fp8MoeBackend.AITER:
|
||||||
MoEPrepareAndFinalizeNoEP(), moe_kernel
|
self.kernel = mk.FusedMoEModularKernel(
|
||||||
)
|
# TODO: make defer_input_quant an attr of the AiterExperts
|
||||||
|
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
|
||||||
|
AiterExperts(quant_config=self.moe_quant_config),
|
||||||
|
)
|
||||||
|
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||||
|
self.kernel = mk.FusedMoEModularKernel(
|
||||||
|
MoEPrepareAndFinalizeNoEP(),
|
||||||
|
MarlinExperts(quant_config=self.moe_quant_config),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.kernel = mk.FusedMoEModularKernel(
|
||||||
|
MoEPrepareAndFinalizeNoEP(),
|
||||||
|
TritonOrDeepGemmExperts(
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
|
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
|
||||||
|
),
|
||||||
|
)
|
||||||
self.use_inplace = True
|
self.use_inplace = True
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(
|
def maybe_make_prepare_finalize(
|
||||||
@ -1128,7 +1139,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||||
if (
|
if (
|
||||||
self.rocm_aiter_moe_enabled
|
self.fp8_backend == Fp8MoeBackend.AITER
|
||||||
or self.fp8_backend == Fp8MoeBackend.MARLIN
|
or self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||||
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||||
):
|
):
|
||||||
@ -1161,11 +1172,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
TritonOrDeepGemmExperts,
|
TritonOrDeepGemmExperts,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
|
||||||
self.fp8_backend != Fp8MoeBackend.MARLIN
|
raise NotImplementedError(
|
||||||
) and not self.rocm_aiter_moe_enabled, (
|
"Marlin and ROCm AITER are not supported with all2all yet."
|
||||||
"Marlin and ROCm AITER are not supported with all2all yet."
|
)
|
||||||
)
|
|
||||||
|
|
||||||
assert self.moe_quant_config is not None
|
assert self.moe_quant_config is not None
|
||||||
|
|
||||||
@ -1313,37 +1323,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
)
|
)
|
||||||
|
result = self.kernel(
|
||||||
if self.rocm_aiter_moe_enabled:
|
x,
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
layer.w13_weight,
|
||||||
rocm_aiter_fused_experts,
|
layer.w2_weight,
|
||||||
)
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
# TODO(rob): convert this to MK.
|
inplace=self.use_inplace,
|
||||||
result = rocm_aiter_fused_experts(
|
activation=layer.activation,
|
||||||
x,
|
global_num_experts=layer.global_num_experts,
|
||||||
layer.w13_weight,
|
expert_map=layer.expert_map,
|
||||||
layer.w2_weight,
|
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||||
topk_weights=topk_weights,
|
)
|
||||||
topk_ids=topk_ids,
|
|
||||||
activation=layer.activation,
|
|
||||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
|
||||||
expert_map=layer.expert_map,
|
|
||||||
quant_config=self.moe_quant_config,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
result = self.kernel(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
inplace=self.use_inplace,
|
|
||||||
activation=layer.activation,
|
|
||||||
global_num_experts=layer.global_num_experts,
|
|
||||||
expert_map=layer.expert_map,
|
|
||||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -1456,15 +1447,10 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
|||||||
layer.w13_input_scale = None
|
layer.w13_input_scale = None
|
||||||
layer.w2_input_scale = None
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
self.rocm_aiter_moe_enabled = False
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||||
return
|
return
|
||||||
|
|
||||||
# Lazy import to avoid importing triton too early.
|
|
||||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
|
||||||
|
|
||||||
# If checkpoint is fp16, quantize in place.
|
# If checkpoint is fp16, quantize in place.
|
||||||
fp8_dtype = current_platform.fp8_dtype()
|
fp8_dtype = current_platform.fp8_dtype()
|
||||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||||
@ -1481,7 +1467,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
|||||||
replace_parameter(layer, "w2_weight", w2_weight)
|
replace_parameter(layer, "w2_weight", w2_weight)
|
||||||
|
|
||||||
# Reshuffle weights for AITER if needed.
|
# Reshuffle weights for AITER if needed.
|
||||||
if self.rocm_aiter_moe_enabled:
|
if self.fp8_backend == Fp8MoeBackend.AITER:
|
||||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||||
layer.w13_weight, layer.w2_weight
|
layer.w13_weight, layer.w2_weight
|
||||||
)
|
)
|
||||||
@ -1489,7 +1475,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
|||||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||||
|
|
||||||
# Rushuffle weights for MARLIN if needed.
|
# Rushuffle weights for MARLIN if needed.
|
||||||
if self.fp8_backend == Fp8MoeBackend.MARLIN:
|
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||||
prepare_moe_fp8_layer_for_marlin(
|
prepare_moe_fp8_layer_for_marlin(
|
||||||
layer, False, input_dtype=self.marlin_input_dtype
|
layer, False, input_dtype=self.marlin_input_dtype
|
||||||
)
|
)
|
||||||
|
|||||||
@ -143,7 +143,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
|||||||
|
|
||||||
query_start_loc = m.query_start_loc
|
query_start_loc = m.query_start_loc
|
||||||
context_lens = m.num_computed_tokens_cpu
|
context_lens = m.num_computed_tokens_cpu
|
||||||
context_lens_tensor = context_lens.to(query_start_loc.device)
|
context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True)
|
||||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|||||||
@ -355,6 +355,8 @@ class MLACommonPrefillMetadata:
|
|||||||
max_query_len: int
|
max_query_len: int
|
||||||
chunked_context: ChunkedContextMetadata | None = None
|
chunked_context: ChunkedContextMetadata | None = None
|
||||||
query_seq_lens: torch.Tensor | None = None
|
query_seq_lens: torch.Tensor | None = None
|
||||||
|
workspace_buffer: torch.Tensor | None = None
|
||||||
|
q_data_type: torch.dtype | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -539,6 +541,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
metadata_cls if metadata_cls is not None else MLACommonMetadata
|
metadata_cls if metadata_cls is not None else MLACommonMetadata
|
||||||
)
|
)
|
||||||
self.kv_cache_spec = kv_cache_spec
|
self.kv_cache_spec = kv_cache_spec
|
||||||
|
self.q_data_type = (
|
||||||
|
current_platform.fp8_dtype()
|
||||||
|
if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
|
||||||
|
else vllm_config.model_config.dtype
|
||||||
|
)
|
||||||
scheduler_config = vllm_config.scheduler_config
|
scheduler_config = vllm_config.scheduler_config
|
||||||
self.model_config = vllm_config.model_config
|
self.model_config = vllm_config.model_config
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
@ -558,6 +565,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
self.dcp_rank = 0
|
self.dcp_rank = 0
|
||||||
self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
|
self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
|
||||||
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
|
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
|
||||||
|
self.cp_kv_cache_interleave_size = parallel_config.cp_kv_cache_interleave_size
|
||||||
|
|
||||||
# Don't try to access the runner on AMD
|
# Don't try to access the runner on AMD
|
||||||
if self.aot_schedule:
|
if self.aot_schedule:
|
||||||
@ -681,7 +689,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
|
|
||||||
# For main run, qo_indptr == kv_indptr
|
# For main run, qo_indptr == kv_indptr
|
||||||
kv_indptr = qo_indptr.clone()
|
kv_indptr = qo_indptr.clone()
|
||||||
|
|
||||||
# Prepare main prefill
|
# Prepare main prefill
|
||||||
self._fi_prefill_main.plan(
|
self._fi_prefill_main.plan(
|
||||||
qo_indptr=qo_indptr,
|
qo_indptr=qo_indptr,
|
||||||
@ -694,7 +701,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
sm_scale=self._global_hyperparameters.sm_scale,
|
sm_scale=self._global_hyperparameters.sm_scale,
|
||||||
window_left=self._global_hyperparameters.window_left,
|
window_left=self._global_hyperparameters.window_left,
|
||||||
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
||||||
q_data_type=self.model_config.dtype,
|
q_data_type=self.q_data_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare context prefills
|
# Prepare context prefills
|
||||||
@ -713,7 +720,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
sm_scale=self._global_hyperparameters.sm_scale,
|
sm_scale=self._global_hyperparameters.sm_scale,
|
||||||
window_left=self._global_hyperparameters.window_left,
|
window_left=self._global_hyperparameters.window_left,
|
||||||
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
||||||
q_data_type=self.model_config.dtype,
|
q_data_type=self.q_data_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefill.prefill_main = self._fi_prefill_main
|
prefill.prefill_main = self._fi_prefill_main
|
||||||
@ -722,8 +729,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
def _build_decode(
|
def _build_decode(
|
||||||
self,
|
self,
|
||||||
block_table_tensor: torch.Tensor,
|
block_table_tensor: torch.Tensor,
|
||||||
seq_lens_cpu: torch.Tensor,
|
|
||||||
seq_lens_device: torch.Tensor,
|
seq_lens_device: torch.Tensor,
|
||||||
|
max_seq_len: int,
|
||||||
query_start_loc_cpu: torch.Tensor,
|
query_start_loc_cpu: torch.Tensor,
|
||||||
query_start_loc_device: torch.Tensor,
|
query_start_loc_device: torch.Tensor,
|
||||||
num_decode_tokens: int,
|
num_decode_tokens: int,
|
||||||
@ -773,13 +780,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
query_start_loc = common_attn_metadata.query_start_loc
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||||
seq_lens = common_attn_metadata.seq_lens
|
seq_lens = common_attn_metadata.seq_lens
|
||||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
|
||||||
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
|
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
|
||||||
dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
|
|
||||||
|
|
||||||
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
|
||||||
|
|
||||||
num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu
|
|
||||||
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
split_decodes_and_prefills(
|
split_decodes_and_prefills(
|
||||||
@ -794,6 +795,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
|
|
||||||
prefill_metadata = None
|
prefill_metadata = None
|
||||||
if num_prefills > 0:
|
if num_prefills > 0:
|
||||||
|
num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
|
||||||
|
|
||||||
reqs_start = num_decodes # prefill_start
|
reqs_start = num_decodes # prefill_start
|
||||||
|
|
||||||
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
||||||
@ -970,6 +973,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
query_start_loc=prefill_query_start_loc,
|
query_start_loc=prefill_query_start_loc,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
chunked_context=chunked_context_metadata,
|
chunked_context=chunked_context_metadata,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._use_cudnn_prefill:
|
if self._use_cudnn_prefill:
|
||||||
@ -983,19 +987,29 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
prefill_metadata.query_seq_lens = (
|
prefill_metadata.query_seq_lens = (
|
||||||
prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
|
prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
|
||||||
)
|
)
|
||||||
|
prefill_metadata.workspace_buffer = self._workspace_buffer
|
||||||
|
|
||||||
decode_metadata = None
|
decode_metadata = None
|
||||||
if num_decodes > 0:
|
if num_decodes > 0:
|
||||||
dcp_tot_seq_lens_device = None
|
dcp_tot_seq_lens_device = None
|
||||||
if self.dcp_world_size > 1:
|
if self.dcp_world_size > 1:
|
||||||
dcp_tot_seq_lens_device = seq_lens[:num_decodes]
|
dcp_tot_seq_lens_device = seq_lens[:num_decodes]
|
||||||
seq_lens_cpu = dcp_local_seq_lens_cpu
|
|
||||||
seq_lens = dcp_local_seq_lens
|
seq_lens = dcp_local_seq_lens
|
||||||
|
|
||||||
|
# After DCP distribution, the maximum number of tokens for any rank is
|
||||||
|
# ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
|
||||||
|
# and I is cp_kv_cache_interleave_size.
|
||||||
|
# This eliminates GPU->CPU sync while minimizing workspace
|
||||||
|
# over-allocation.
|
||||||
|
num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
|
||||||
|
max_seq_len = (
|
||||||
|
(max_seq_len + num_partitions - 1) // num_partitions
|
||||||
|
) * self.cp_kv_cache_interleave_size
|
||||||
|
|
||||||
decode_metadata = self._build_decode(
|
decode_metadata = self._build_decode(
|
||||||
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
||||||
seq_lens_cpu=seq_lens_cpu[:num_decodes],
|
|
||||||
seq_lens_device=seq_lens[:num_decodes],
|
seq_lens_device=seq_lens[:num_decodes],
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
|
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
|
||||||
query_start_loc_device=query_start_loc[: num_decodes + 1],
|
query_start_loc_device=query_start_loc[: num_decodes + 1],
|
||||||
num_decode_tokens=num_decode_tokens,
|
num_decode_tokens=num_decode_tokens,
|
||||||
@ -1370,8 +1384,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
return attn_out
|
return attn_out
|
||||||
|
|
||||||
def _run_prefill_new_tokens_fa(
|
def _run_prefill_new_tokens_fa(
|
||||||
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
|
self,
|
||||||
|
prefill: MLACommonPrefillMetadata,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
return_softmax_lse,
|
||||||
|
fp8_attention: bool,
|
||||||
):
|
):
|
||||||
|
logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
|
||||||
return self._flash_attn_varlen_diff_headdims(
|
return self._flash_attn_varlen_diff_headdims(
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
@ -1386,11 +1407,23 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _run_prefill_new_tokens_fi(
|
def _run_prefill_new_tokens_fi(
|
||||||
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
|
self,
|
||||||
|
prefill: MLACommonPrefillMetadata,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
return_softmax_lse,
|
||||||
|
fp8_attention: bool,
|
||||||
):
|
):
|
||||||
|
logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
|
||||||
assert isinstance(prefill, FlashInferPrefillMetadata)
|
assert isinstance(prefill, FlashInferPrefillMetadata)
|
||||||
assert prefill.prefill_main is not None
|
assert prefill.prefill_main is not None
|
||||||
|
if fp8_attention:
|
||||||
|
logger.debug_once("Running Flashinfer prefill in FP8")
|
||||||
|
fp8_dtype = current_platform.fp8_dtype()
|
||||||
|
q = q.to(fp8_dtype)
|
||||||
|
k = k.to(fp8_dtype)
|
||||||
|
v = v.to(fp8_dtype)
|
||||||
ret = prefill.prefill_main.run(
|
ret = prefill.prefill_main.run(
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
@ -1403,10 +1436,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
def _run_prefill_new_tokens_cudnn(
|
def _run_prefill_new_tokens_cudnn(
|
||||||
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
|
self,
|
||||||
|
prefill: MLACommonPrefillMetadata,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
return_softmax_lse,
|
||||||
|
fp8_attention: bool,
|
||||||
):
|
):
|
||||||
|
logger.debug_once("Running Cudnn prefill new tokens", scope="local")
|
||||||
assert isinstance(prefill, CudnnPrefillMetadata)
|
assert isinstance(prefill, CudnnPrefillMetadata)
|
||||||
assert prefill.query_seq_lens is not None
|
assert prefill.query_seq_lens is not None
|
||||||
|
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
|
||||||
output, lse = cudnn_batch_prefill_with_kv_cache(
|
output, lse = cudnn_batch_prefill_with_kv_cache(
|
||||||
q=q,
|
q=q,
|
||||||
k_cache=k,
|
k_cache=k,
|
||||||
@ -1428,9 +1469,19 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def _run_prefill_context_chunk_fa(
|
def _run_prefill_context_chunk_fa(
|
||||||
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
|
self,
|
||||||
|
prefill: MLACommonPrefillMetadata,
|
||||||
|
chunk_idx: int,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
fp8_attention: bool,
|
||||||
):
|
):
|
||||||
|
logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
|
||||||
assert prefill.chunked_context is not None
|
assert prefill.chunked_context is not None
|
||||||
|
assert fp8_attention is False, (
|
||||||
|
"FlashAttention prefill does not support fp8 attention"
|
||||||
|
)
|
||||||
return self._flash_attn_varlen_diff_headdims(
|
return self._flash_attn_varlen_diff_headdims(
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
@ -1445,10 +1496,22 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _run_prefill_context_chunk_fi(
|
def _run_prefill_context_chunk_fi(
|
||||||
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
|
self,
|
||||||
|
prefill: MLACommonPrefillMetadata,
|
||||||
|
chunk_idx: int,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
fp8_attention: bool,
|
||||||
):
|
):
|
||||||
|
logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
|
||||||
assert isinstance(prefill, FlashInferPrefillMetadata)
|
assert isinstance(prefill, FlashInferPrefillMetadata)
|
||||||
|
if fp8_attention:
|
||||||
|
logger.debug_once("Running FlashInfer prefill in FP8")
|
||||||
|
fp8_dtype = current_platform.fp8_dtype()
|
||||||
|
q = q.to(fp8_dtype)
|
||||||
|
k = k.to(fp8_dtype)
|
||||||
|
v = v.to(fp8_dtype)
|
||||||
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
|
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
@ -1460,12 +1523,20 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
return attn_out, lse.transpose(0, 1).contiguous()
|
return attn_out, lse.transpose(0, 1).contiguous()
|
||||||
|
|
||||||
def _run_prefill_context_chunk_cudnn(
|
def _run_prefill_context_chunk_cudnn(
|
||||||
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
|
self,
|
||||||
|
prefill: MLACommonPrefillMetadata,
|
||||||
|
chunk_idx: int,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
fp8_attention: bool,
|
||||||
):
|
):
|
||||||
|
logger.debug_once("Running Cudnn prefill context chunk", scope="local")
|
||||||
assert isinstance(prefill, CudnnPrefillMetadata)
|
assert isinstance(prefill, CudnnPrefillMetadata)
|
||||||
assert prefill.chunked_context is not None
|
assert prefill.chunked_context is not None
|
||||||
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
|
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
|
||||||
assert prefill.query_seq_lens is not None
|
assert prefill.query_seq_lens is not None
|
||||||
|
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
|
||||||
return cudnn_batch_prefill_with_kv_cache(
|
return cudnn_batch_prefill_with_kv_cache(
|
||||||
q=q,
|
q=q,
|
||||||
k_cache=k,
|
k_cache=k,
|
||||||
@ -1485,18 +1556,33 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _run_prefill_new_tokens_trtllm_ragged(
|
def _run_prefill_new_tokens_trtllm_ragged(
|
||||||
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
|
self,
|
||||||
|
prefill: MLACommonPrefillMetadata,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
return_softmax_lse,
|
||||||
|
fp8_attention: bool,
|
||||||
):
|
):
|
||||||
|
logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
|
||||||
"""TRT-LLM ragged attention for new tokens (causal)."""
|
"""TRT-LLM ragged attention for new tokens (causal)."""
|
||||||
from flashinfer.prefill import trtllm_ragged_attention_deepseek
|
from flashinfer.prefill import trtllm_ragged_attention_deepseek
|
||||||
|
|
||||||
assert prefill.query_seq_lens is not None
|
assert prefill.query_seq_lens is not None
|
||||||
|
assert prefill.workspace_buffer is not None
|
||||||
|
|
||||||
|
if fp8_attention:
|
||||||
|
logger.debug_once("Running TRT-LLM ragged prefill in FP8")
|
||||||
|
fp8_dtype = current_platform.fp8_dtype()
|
||||||
|
q = q.to(fp8_dtype)
|
||||||
|
k = k.to(fp8_dtype)
|
||||||
|
v = v.to(fp8_dtype)
|
||||||
|
|
||||||
ret = trtllm_ragged_attention_deepseek(
|
ret = trtllm_ragged_attention_deepseek(
|
||||||
query=q,
|
query=q,
|
||||||
key=k,
|
key=k,
|
||||||
value=v,
|
value=v,
|
||||||
workspace_buffer=self._workspace_buffer,
|
workspace_buffer=prefill.workspace_buffer,
|
||||||
seq_lens=prefill.query_seq_lens,
|
seq_lens=prefill.query_seq_lens,
|
||||||
max_q_len=prefill.max_query_len,
|
max_q_len=prefill.max_query_len,
|
||||||
max_kv_len=prefill.max_query_len,
|
max_kv_len=prefill.max_query_len,
|
||||||
@ -1518,13 +1604,21 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
def _run_prefill_context_chunk_trtllm_ragged(
|
def _run_prefill_context_chunk_trtllm_ragged(
|
||||||
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
|
self,
|
||||||
|
prefill: MLACommonPrefillMetadata,
|
||||||
|
chunk_idx: int,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
fp8_attention: bool,
|
||||||
):
|
):
|
||||||
|
logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
|
||||||
"""TRT-LLM ragged attention for context chunks (non-causal)."""
|
"""TRT-LLM ragged attention for context chunks (non-causal)."""
|
||||||
from flashinfer.prefill import trtllm_ragged_attention_deepseek
|
from flashinfer.prefill import trtllm_ragged_attention_deepseek
|
||||||
|
|
||||||
assert prefill.chunked_context is not None
|
assert prefill.chunked_context is not None
|
||||||
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
|
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
|
||||||
|
assert prefill.workspace_buffer is not None
|
||||||
|
|
||||||
out = torch.zeros(
|
out = torch.zeros(
|
||||||
q.shape[0],
|
q.shape[0],
|
||||||
@ -1533,13 +1627,20 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
device=q.device,
|
device=q.device,
|
||||||
dtype=q.dtype,
|
dtype=q.dtype,
|
||||||
)
|
)
|
||||||
self._workspace_buffer.fill_(0)
|
prefill.workspace_buffer.fill_(0)
|
||||||
|
|
||||||
|
if fp8_attention:
|
||||||
|
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
|
||||||
|
fp8_dtype = current_platform.fp8_dtype()
|
||||||
|
q = q.to(fp8_dtype)
|
||||||
|
k = k.to(fp8_dtype)
|
||||||
|
v = v.to(fp8_dtype)
|
||||||
|
|
||||||
attn_out, lse = trtllm_ragged_attention_deepseek(
|
attn_out, lse = trtllm_ragged_attention_deepseek(
|
||||||
query=q,
|
query=q,
|
||||||
key=k,
|
key=k,
|
||||||
value=v,
|
value=v,
|
||||||
workspace_buffer=self._workspace_buffer,
|
workspace_buffer=prefill.workspace_buffer,
|
||||||
seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
|
seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
|
||||||
max_q_len=prefill.max_query_len,
|
max_q_len=prefill.max_query_len,
|
||||||
max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
|
max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
|
||||||
@ -1687,6 +1788,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
kv_c_and_k_pe_cache: torch.Tensor,
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
attn_metadata: MLACommonMetadata,
|
attn_metadata: MLACommonMetadata,
|
||||||
k_scale: torch.Tensor,
|
k_scale: torch.Tensor,
|
||||||
|
fp8_attention: bool,
|
||||||
):
|
):
|
||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
prefill_metadata = attn_metadata.prefill
|
prefill_metadata = attn_metadata.prefill
|
||||||
@ -1725,6 +1827,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
|
fp8_attention=fp8_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
if output is None:
|
if output is None:
|
||||||
@ -1753,6 +1856,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
attn_metadata: MLACommonMetadata,
|
attn_metadata: MLACommonMetadata,
|
||||||
k_scale: torch.Tensor,
|
k_scale: torch.Tensor,
|
||||||
dcp_world_size: int,
|
dcp_world_size: int,
|
||||||
|
fp8_attention: bool,
|
||||||
):
|
):
|
||||||
assert k_scale is None, "DCP not support scaled kvcache now."
|
assert k_scale is None, "DCP not support scaled kvcache now."
|
||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
@ -1829,6 +1933,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
|
fp8_attention=fp8_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
if output is None:
|
if output is None:
|
||||||
@ -1859,6 +1964,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
attn_metadata: MLACommonMetadata,
|
attn_metadata: MLACommonMetadata,
|
||||||
k_scale: torch.Tensor,
|
k_scale: torch.Tensor,
|
||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
|
fp8_attention: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO (zyongye): Prefill function here
|
# TODO (zyongye): Prefill function here
|
||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
@ -1878,6 +1984,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
return_softmax_lse=has_context,
|
return_softmax_lse=has_context,
|
||||||
|
fp8_attention=fp8_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_context:
|
if has_context:
|
||||||
@ -1890,11 +1997,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
attn_metadata,
|
attn_metadata,
|
||||||
k_scale=None,
|
k_scale=None,
|
||||||
dcp_world_size=self.dcp_world_size,
|
dcp_world_size=self.dcp_world_size,
|
||||||
|
fp8_attention=fp8_attention,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
context_output, context_lse = self._compute_prefill_context(
|
context_output, context_lse = self._compute_prefill_context(
|
||||||
q, kv_c_and_k_pe_cache, attn_metadata, k_scale
|
q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention
|
||||||
)
|
)
|
||||||
|
|
||||||
# unpad if necessary
|
# unpad if necessary
|
||||||
@ -2015,6 +2123,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
attn_metadata,
|
attn_metadata,
|
||||||
layer._k_scale,
|
layer._k_scale,
|
||||||
output=output[num_decode_tokens:],
|
output=output[num_decode_tokens:],
|
||||||
|
fp8_attention=fp8_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_decode:
|
if has_decode:
|
||||||
|
|||||||
@ -169,8 +169,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
|||||||
def _build_decode(
|
def _build_decode(
|
||||||
self,
|
self,
|
||||||
block_table_tensor: torch.Tensor,
|
block_table_tensor: torch.Tensor,
|
||||||
seq_lens_cpu: torch.Tensor,
|
|
||||||
seq_lens_device: torch.Tensor,
|
seq_lens_device: torch.Tensor,
|
||||||
|
max_seq_len: int,
|
||||||
query_start_loc_cpu: torch.Tensor,
|
query_start_loc_cpu: torch.Tensor,
|
||||||
query_start_loc_device: torch.Tensor,
|
query_start_loc_device: torch.Tensor,
|
||||||
num_decode_tokens: int,
|
num_decode_tokens: int,
|
||||||
@ -178,7 +178,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
|||||||
) -> FlashAttnMLADecodeMetadata:
|
) -> FlashAttnMLADecodeMetadata:
|
||||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||||
max_query_len = query_lens_cpu.max().item()
|
max_query_len = query_lens_cpu.max().item()
|
||||||
max_seq_len = seq_lens_cpu.max().item()
|
|
||||||
|
|
||||||
# For Flash Attention MLA + full cudagraph
|
# For Flash Attention MLA + full cudagraph
|
||||||
max_num_splits = 0
|
max_num_splits = 0
|
||||||
@ -193,7 +192,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
|||||||
max_num_splits = 1
|
max_num_splits = 1
|
||||||
|
|
||||||
scheduler_metadata = self._schedule_decode(
|
scheduler_metadata = self._schedule_decode(
|
||||||
num_reqs=seq_lens_cpu.numel(),
|
num_reqs=seq_lens_device.shape[0],
|
||||||
cu_query_lens=query_start_loc_device,
|
cu_query_lens=query_start_loc_device,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
seqlens=seq_lens_device,
|
seqlens=seq_lens_device,
|
||||||
|
|||||||
@ -143,8 +143,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
|||||||
def _build_decode(
|
def _build_decode(
|
||||||
self,
|
self,
|
||||||
block_table_tensor: torch.Tensor,
|
block_table_tensor: torch.Tensor,
|
||||||
seq_lens_cpu: torch.Tensor,
|
|
||||||
seq_lens_device: torch.Tensor,
|
seq_lens_device: torch.Tensor,
|
||||||
|
max_seq_len: int,
|
||||||
query_start_loc_cpu: torch.Tensor,
|
query_start_loc_cpu: torch.Tensor,
|
||||||
query_start_loc_device: torch.Tensor,
|
query_start_loc_device: torch.Tensor,
|
||||||
num_decode_tokens: int,
|
num_decode_tokens: int,
|
||||||
|
|||||||
@ -106,8 +106,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
def _build_decode(
|
def _build_decode(
|
||||||
self,
|
self,
|
||||||
block_table_tensor: torch.Tensor,
|
block_table_tensor: torch.Tensor,
|
||||||
seq_lens_cpu: torch.Tensor,
|
|
||||||
seq_lens_device: torch.Tensor,
|
seq_lens_device: torch.Tensor,
|
||||||
|
max_seq_len: int,
|
||||||
query_start_loc_cpu: torch.Tensor,
|
query_start_loc_cpu: torch.Tensor,
|
||||||
query_start_loc_device: torch.Tensor,
|
query_start_loc_device: torch.Tensor,
|
||||||
num_decode_tokens: int,
|
num_decode_tokens: int,
|
||||||
|
|||||||
@ -236,6 +236,7 @@ class EagleProposer:
|
|||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
|
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
|
||||||
|
num_rejected_tokens_gpu: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_tokens = target_token_ids.shape[0]
|
num_tokens = target_token_ids.shape[0]
|
||||||
batch_size = next_token_ids.shape[0]
|
batch_size = next_token_ids.shape[0]
|
||||||
@ -414,6 +415,17 @@ class EagleProposer:
|
|||||||
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
|
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
|
||||||
self.token_arange_np[: batch_size + 1]
|
self.token_arange_np[: batch_size + 1]
|
||||||
).clone()
|
).clone()
|
||||||
|
|
||||||
|
# In padded drafter batch, we need to adjust the sequence lengths
|
||||||
|
# to remove the "padding" (i.e. rejected tokens).
|
||||||
|
# Only apply this adjustment when we have rejected tokens
|
||||||
|
# (i.e., not the first proposal).
|
||||||
|
if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
|
||||||
|
common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
|
||||||
|
# Invalidate the CPU-side shadows to avoid H<>D sync.
|
||||||
|
common_attn_metadata._seq_lens_cpu = None
|
||||||
|
common_attn_metadata._num_computed_tokens_cpu = None
|
||||||
|
|
||||||
for token_index in range(self.num_speculative_tokens - 1):
|
for token_index in range(self.num_speculative_tokens - 1):
|
||||||
# Update the inputs.
|
# Update the inputs.
|
||||||
# cast to int32 is crucial when eagle model is compiled.
|
# cast to int32 is crucial when eagle model is compiled.
|
||||||
@ -628,13 +640,14 @@ class EagleProposer:
|
|||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
spec_decode_metadata: SpecDecodeMetadata,
|
spec_decode_metadata: SpecDecodeMetadata,
|
||||||
valid_sampled_tokens_count: torch.Tensor,
|
valid_sampled_tokens_count: torch.Tensor,
|
||||||
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
|
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
This function is used to prepare the inputs for speculative decoding
|
This function is used to prepare the inputs for speculative decoding
|
||||||
It updates the common_attn_metadata for speculative decoding,
|
It updates the common_attn_metadata for speculative decoding,
|
||||||
but does not consider the rejected tokens. Instead, all tokens
|
but does not consider the rejected tokens. Instead, all tokens
|
||||||
are included as inputs to the speculator, with the rejected tokens
|
are included as inputs to the speculator, with the rejected tokens
|
||||||
used as padding and filtered out later by `token_indices_to_sample`.
|
used as padding and filtered out later by `token_indices_to_sample`.
|
||||||
|
No blocking CPU operations should be introduced in this function.
|
||||||
"""
|
"""
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
device = valid_sampled_tokens_count.device
|
device = valid_sampled_tokens_count.device
|
||||||
@ -642,14 +655,17 @@ class EagleProposer:
|
|||||||
token_indices_to_sample = torch.empty(
|
token_indices_to_sample = torch.empty(
|
||||||
(num_reqs,), dtype=torch.int32, device=device
|
(num_reqs,), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
num_rejected_tokens_gpu = torch.empty(
|
||||||
|
(num_reqs,), dtype=torch.int32, device=device
|
||||||
|
)
|
||||||
|
|
||||||
# Kernel grid: one program per request (row)
|
|
||||||
grid = (num_reqs,)
|
grid = (num_reqs,)
|
||||||
eagle_prepare_inputs_padded_kernel[grid](
|
eagle_prepare_inputs_padded_kernel[grid](
|
||||||
spec_decode_metadata.cu_num_draft_tokens,
|
spec_decode_metadata.cu_num_draft_tokens,
|
||||||
valid_sampled_tokens_count,
|
valid_sampled_tokens_count,
|
||||||
common_attn_metadata.query_start_loc,
|
common_attn_metadata.query_start_loc,
|
||||||
token_indices_to_sample,
|
token_indices_to_sample,
|
||||||
|
num_rejected_tokens_gpu,
|
||||||
num_reqs,
|
num_reqs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -674,7 +690,11 @@ class EagleProposer:
|
|||||||
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
|
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
|
||||||
)
|
)
|
||||||
|
|
||||||
return spec_common_attn_metadata, token_indices_to_sample
|
return (
|
||||||
|
spec_common_attn_metadata,
|
||||||
|
token_indices_to_sample,
|
||||||
|
num_rejected_tokens_gpu,
|
||||||
|
)
|
||||||
|
|
||||||
def propose_tree(
|
def propose_tree(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -23,6 +23,7 @@ def eagle_prepare_inputs_padded_kernel(
|
|||||||
valid_sampled_tokens_count_ptr, # [num_reqs]
|
valid_sampled_tokens_count_ptr, # [num_reqs]
|
||||||
query_start_loc_gpu_ptr, # [num_reqs + 1]
|
query_start_loc_gpu_ptr, # [num_reqs + 1]
|
||||||
token_indices_to_sample_ptr, # [num_reqs] (output)
|
token_indices_to_sample_ptr, # [num_reqs] (output)
|
||||||
|
num_rejected_tokens_gpu_ptr, # [num_reqs] (output)
|
||||||
num_reqs, # tl.int32
|
num_reqs, # tl.int32
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -56,6 +57,7 @@ def eagle_prepare_inputs_padded_kernel(
|
|||||||
|
|
||||||
index_to_sample = q_last_tok_idx - num_rejected_tokens
|
index_to_sample = q_last_tok_idx - num_rejected_tokens
|
||||||
tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
|
tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
|
||||||
|
tl.store(num_rejected_tokens_gpu_ptr + req_idx, num_rejected_tokens)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
|||||||
@ -3534,6 +3534,7 @@ class GPUModelRunner(
|
|||||||
next_token_ids, valid_sampled_tokens_count
|
next_token_ids, valid_sampled_tokens_count
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_rejected_tokens_gpu = None
|
||||||
if spec_decode_metadata is None:
|
if spec_decode_metadata is None:
|
||||||
token_indices_to_sample = None
|
token_indices_to_sample = None
|
||||||
# input_ids can be None for multimodal models.
|
# input_ids can be None for multimodal models.
|
||||||
@ -3564,12 +3565,14 @@ class GPUModelRunner(
|
|||||||
else:
|
else:
|
||||||
target_hidden_states = hidden_states[token_indices]
|
target_hidden_states = hidden_states[token_indices]
|
||||||
else:
|
else:
|
||||||
common_attn_metadata, token_indices_to_sample = (
|
(
|
||||||
self.drafter.prepare_inputs_padded(
|
common_attn_metadata,
|
||||||
common_attn_metadata,
|
token_indices_to_sample,
|
||||||
spec_decode_metadata,
|
num_rejected_tokens_gpu,
|
||||||
valid_sampled_tokens_count,
|
) = self.drafter.prepare_inputs_padded(
|
||||||
)
|
common_attn_metadata,
|
||||||
|
spec_decode_metadata,
|
||||||
|
valid_sampled_tokens_count,
|
||||||
)
|
)
|
||||||
total_num_tokens = common_attn_metadata.num_actual_tokens
|
total_num_tokens = common_attn_metadata.num_actual_tokens
|
||||||
# When padding the batch, token_indices is just a range
|
# When padding the batch, token_indices is just a range
|
||||||
@ -3600,6 +3603,7 @@ class GPUModelRunner(
|
|||||||
sampling_metadata=sampling_metadata,
|
sampling_metadata=sampling_metadata,
|
||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata,
|
||||||
mm_embed_inputs=mm_embed_inputs,
|
mm_embed_inputs=mm_embed_inputs,
|
||||||
|
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
return draft_token_ids
|
return draft_token_ids
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user