[CI] Add E2E Blackwell Quantized MoE Test (#25723)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

This commit is contained in:
parent 5157781987
commit b6f16d37b0
@@ -522,7 +522,7 @@ steps:
   # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
   # we can only upgrade after this is resolved
   - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
-  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
+  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/

 - label: LM Eval Small Models # 53min
   timeout_in_minutes: 75
@@ -830,6 +830,23 @@ steps:
   - uv pip install --system 'gpt-oss[eval]==0.0.5'
   - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 --server-args '--tensor-parallel-size 2'

+- label: Blackwell Quantized MoE Test
+  timeout_in_minutes: 60
+  working_dir: "/vllm-workspace/"
+  gpu: b200
+  source_file_dependencies:
+  - tests/quantization/test_blackwell_moe.py
+  - vllm/model_executor/models/deepseek_v2.py
+  - vllm/model_executor/models/gpt_oss.py
+  - vllm/model_executor/models/llama4.py
+  - vllm/model_executor/layers/fused_moe
+  - vllm/model_executor/layers/quantization/compressed_tensors
+  - vllm/model_executor/layers/quantization/modelopt.py
+  - vllm/model_executor/layers/quantization/mxfp4.py
+  - vllm/v1/attention/backends/flashinfer.py
+  commands:
+  - pytest -s -v tests/quantization/test_blackwell_moe.py
+
 ##### 1 GPU test #####
 ##### multi gpus test #####
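Note: the new step is pinned to gpu: b200 and is gated on the listed source_file_dependencies, so it only runs when those files change. The test module added below additionally guards itself at import time; a minimal sketch of that same check, assuming a local vLLM install, is:

# Minimal sketch (not part of the commit): the same hardware gate the new
# test module applies at import time, usable to check whether this machine
# would run or skip the Blackwell suite.
from vllm.platforms import current_platform

# Device capability 100 == SM100, i.e. Blackwell (B200).
if current_platform.is_device_capability(100):
    print("SM100 detected: tests/quantization/test_blackwell_moe.py would run.")
else:
    print("Not SM100: the module skips itself via pytest.skip(..., allow_module_level=True).")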
tests/quantization/test_blackwell_moe.py (new file, 132 lines)
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os

import pytest

from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform

if not current_platform.is_device_capability(100):
    pytest.skip("This test only runs on Blackwell GPUs (SM100).",
                allow_module_level=True)

os.environ["FLASHINFER_NVCC_THREADS"] = "16"

# dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4,
#                       "text_config": {"num_layers": 4, "num_hidden_layers": 4}}
dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4}


def can_initialize(model: str, extra_args: list[str]):

    # Server arguments
    server_args = [
        "--max-model-len",
        "2048",
        "--max-num-batched-tokens",
        "256",
        "--load-format",
        "dummy",
        "--trust-remote-code",
        "--limit-mm-per-prompt",
        json.dumps({"image": 0}),
        *extra_args,
    ]

    # Launch server and make a simple request
    with RemoteOpenAIServer(
            model,
            server_args,
            max_wait_seconds=1000,  # Due to FlashInfer compile
            override_hf_configs=dummy_hf_overrides) as server:
        client = server.get_client()
        # Make a simple request to verify the server works
        completion = client.completions.create(
            model=model,
            prompt=["Hello, World!"],
            temperature=0,
            max_tokens=2,
        )
        print(completion)
        assert completion.choices[0].text is not None


## Llama4 ##


@pytest.mark.skip(reason=(
    "RuntimeError: run_moe() Expected a value of type "
    "'Optional[List[Tensor]]' for argument '_9' but instead found type "
    "'list'."))
def test_llama4_fp8_tensor_moe_flashinfer_cutlass(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])


@pytest.mark.skip(reason="Works, but takes too long to run")
def test_llama4_fp8_tensor_moe_flashinfer_trtllm(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])


@pytest.mark.skip(reason="Works, but takes too long to run")
def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])


@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])


## DeepSeekV3 ##


def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
    can_initialize("deepseek-ai/DeepSeek-V3.1", [])


def test_deepseek_nvfp4_moe_flashinfer_cutlass(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])


@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])


## GPT-OSS ##


def test_gptoss_mxfp4bf16_moe_flashinfer(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
    can_initialize("openai/gpt-oss-20b", [])


def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "1")
    can_initialize("openai/gpt-oss-20b", [])


def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
    can_initialize("openai/gpt-oss-20b", [])
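All of the tests above follow one pattern: select a MoE backend through environment variables, then call can_initialize() to boot a dummy-weight server and issue a two-token completion. A hypothetical additional case (illustrative only, not part of this commit) would look like:

# Hypothetical example: wiring up another backend combination reuses the same
# monkeypatch + can_initialize() pattern. The env var names and model ID below
# are the ones already used in the file above.
def test_deepseek_fp8_block_moe_flashinfer_trtllm_example(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
    can_initialize("deepseek-ai/DeepSeek-V3.1", [])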
@@ -91,8 +91,10 @@ class RemoteOpenAIServer:
         env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
         if env_dict is not None:
             env.update(env_dict)
+        serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
+        print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
         self.proc: subprocess.Popen = subprocess.Popen(
-            ["vllm", "serve", model, *vllm_serve_args],
+            serve_cmd,
             env=env,
             stdout=sys.stdout,
             stderr=sys.stderr,
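This hunk only extracts the launch command into serve_cmd so the exact command line is logged before the subprocess starts. A minimal sketch of what the new print produces, using a model and a subset of the server args built by can_initialize() above:

# Illustrative only: reproducing the new log line outside the class.
model = "openai/gpt-oss-20b"
vllm_serve_args = ["--max-model-len", "2048", "--load-format", "dummy"]
serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
# Output:
# Launching RemoteOpenAIServer with: vllm serve openai/gpt-oss-20b --max-model-len 2048 --load-format dummy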
@@ -40,6 +40,8 @@ def flashinfer_fused_moe_blockscale_fp8(
     assert global_num_experts % 4 == 0
     assert top_k < (topk_group * global_num_experts / num_expert_group)
     assert block_shape == [128, 128]
+    # Routing kernel expects #experts <= #threads 256
+    assert global_num_experts <= 256

     a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
     # NOTE: scales of hidden states have to be transposed!
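The added assertion records that the FlashInfer block-scale routing kernel handles at most 256 experts (one per thread of a 256-thread block). A worked check of all four guards with DeepSeek-V3-style routing values (illustrative numbers, not read from a real config):

# Worked example (illustrative values, roughly DeepSeek-V3-style routing):
global_num_experts, top_k = 256, 8
num_expert_group, topk_group = 8, 4
block_shape = [128, 128]

assert global_num_experts % 4 == 0
assert top_k < (topk_group * global_num_experts / num_expert_group)  # 8 < 128
assert block_shape == [128, 128]
# New guard: routing kernel expects #experts <= #threads (256).
assert global_num_experts <= 256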