[CI] Add E2E Blackwell Quantized MoE Test (#25723)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent 5157781987
commit b6f16d37b0
@@ -522,7 +522,7 @@ steps:
   # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
   # we can only upgrade after this is resolved
   - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
-  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
+  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/
 
 - label: LM Eval Small Models # 53min
   timeout_in_minutes: 75
@@ -830,6 +830,23 @@ steps:
   - uv pip install --system 'gpt-oss[eval]==0.0.5'
   - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 --server-args '--tensor-parallel-size 2'
 
+- label: Blackwell Quantized MoE Test
+  timeout_in_minutes: 60
+  working_dir: "/vllm-workspace/"
+  gpu: b200
+  source_file_dependencies:
+  - tests/quantization/test_blackwell_moe.py
+  - vllm/model_executor/models/deepseek_v2.py
+  - vllm/model_executor/models/gpt_oss.py
+  - vllm/model_executor/models/llama4.py
+  - vllm/model_executor/layers/fused_moe
+  - vllm/model_executor/layers/quantization/compressed_tensors
+  - vllm/model_executor/layers/quantization/modelopt.py
+  - vllm/model_executor/layers/quantization/mxfp4.py
+  - vllm/v1/attention/backends/flashinfer.py
+  commands:
+  - pytest -s -v tests/quantization/test_blackwell_moe.py
+
 ##### 1 GPU test #####
 ##### multi gpus test #####
 
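The new CI step wraps a single pytest invocation, gated on a B200 runner and the listed source files. A minimal sketch for reproducing it locally, assuming a Blackwell (SM100) GPU and a vLLM development checkout with the test dependencies installed:

# Local reproduction sketch (not part of the diff): run the new suite
# programmatically, mirroring the CI command
# `pytest -s -v tests/quantization/test_blackwell_moe.py` from the step above.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["-s", "-v", "tests/quantization/test_blackwell_moe.py"]))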
tests/quantization/test_blackwell_moe.py (new file, 132 lines)
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os

import pytest

from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform

if not current_platform.is_device_capability(100):
    pytest.skip("This test only runs on Blackwell GPUs (SM100).",
                allow_module_level=True)

os.environ["FLASHINFER_NVCC_THREADS"] = "16"

# dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4,
#                       "text_config": {"num_layers": 4, "num_hidden_layers": 4}}
dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4}


def can_initialize(model: str, extra_args: list[str]):

    # Server arguments
    server_args = [
        "--max-model-len",
        "2048",
        "--max-num-batched-tokens",
        "256",
        "--load-format",
        "dummy",
        "--trust-remote-code",
        "--limit-mm-per-prompt",
        json.dumps({"image": 0}),
        *extra_args,
    ]

    # Launch server and make a simple request
    with RemoteOpenAIServer(
            model,
            server_args,
            max_wait_seconds=1000,  # Due to FlashInfer compile
            override_hf_configs=dummy_hf_overrides) as server:
        client = server.get_client()
        # Make a simple request to verify the server works
        completion = client.completions.create(
            model=model,
            prompt=["Hello, World!"],
            temperature=0,
            max_tokens=2,
        )
        print(completion)
        assert completion.choices[0].text is not None


## Llama4 ##


@pytest.mark.skip(reason=(
    "RuntimeError: run_moe() Expected a value of type "
    "'Optional[List[Tensor]]' for argument '_9' but instead found type "
    "'list'."))
def test_llama4_fp8_tensor_moe_flashinfer_cutlass(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])


@pytest.mark.skip(reason="Works, but takes too long to run")
def test_llama4_fp8_tensor_moe_flashinfer_trtllm(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])


@pytest.mark.skip(reason="Works, but takes too long to run")
def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])


@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])


## DeepSeekV3 ##


def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
    can_initialize("deepseek-ai/DeepSeek-V3.1", [])


def test_deepseek_nvfp4_moe_flashinfer_cutlass(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])


@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])


## GPT-OSS ##


def test_gptoss_mxfp4bf16_moe_flashinfer(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
    can_initialize("openai/gpt-oss-20b", [])


def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "1")
    can_initialize("openai/gpt-oss-20b", [])


def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
    can_initialize("openai/gpt-oss-20b", [])
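Every test in the new file follows the same recipe: select a MoE backend through environment variables with monkeypatch, then smoke-test engine initialization via can_initialize. A purely hypothetical sketch of how a further combination would be added; the environment variable and model id below are placeholders for illustration, not real vLLM settings:

# Hypothetical sketch only -- "VLLM_USE_SOME_NEW_MOE_BACKEND" and the model
# id are placeholders; neither exists in vLLM.
def test_hypothetical_new_moe_backend(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_SOME_NEW_MOE_BACKEND", "1")
    can_initialize("org/some-quantized-moe-model", [])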
@@ -91,8 +91,10 @@ class RemoteOpenAIServer:
         env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
         if env_dict is not None:
             env.update(env_dict)
+        serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
+        print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
         self.proc: subprocess.Popen = subprocess.Popen(
-            ["vllm", "serve", model, *vllm_serve_args],
+            serve_cmd,
             env=env,
             stdout=sys.stdout,
             stderr=sys.stderr,
@@ -40,6 +40,8 @@ def flashinfer_fused_moe_blockscale_fp8(
     assert global_num_experts % 4 == 0
     assert top_k < (topk_group * global_num_experts / num_expert_group)
     assert block_shape == [128, 128]
+    # Routing kernel expects #experts <= #threads 256
+    assert global_num_experts <= 256
 
     a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
     # NOTE: scales of hidden states have to be transposed!
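The two added lines guard the routing launch. Assuming, as the added comment states, that the routing kernel uses a 256-thread block and handles at most one expert per thread, a standalone sketch of the constraint being enforced:

# Illustrative sketch only; the constant and helper below are not vLLM
# symbols. The assumption comes from the added comment: the expert count
# must not exceed the routing kernel's 256 threads.
ROUTING_KERNEL_THREADS = 256


def check_routable(global_num_experts: int) -> None:
    if global_num_experts > ROUTING_KERNEL_THREADS:
        raise ValueError(
            f"global_num_experts={global_num_experts} exceeds the "
            f"{ROUTING_KERNEL_THREADS}-thread routing kernel limit")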