mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:05:01 +08:00
[CI] Initial tests for SM100 Blackwell runner (#21877)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
881e1af43a
commit
88faa466d7
@ -647,13 +647,31 @@ steps:
|
||||
- label: Blackwell Test
|
||||
working_dir: "/vllm-workspace/"
|
||||
gpu: b200
|
||||
optional: true
|
||||
# optional: true
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/
|
||||
- csrc/quantization/fp4/
|
||||
- csrc/attention/mla/
|
||||
- csrc/quantization/cutlass_w8a8/moe/
|
||||
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
|
||||
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
|
||||
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
- vllm/compilation/fusion.py
|
||||
commands:
|
||||
- nvidia-smi
|
||||
- python3 examples/offline_inference/basic/chat.py
|
||||
# Attention
|
||||
# num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
|
||||
- pytest -v -s tests/kernels/test_cutlass_mla_decode.py
|
||||
# Quantization
|
||||
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
|
||||
# Fusion
|
||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
||||
@ -136,12 +136,15 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("test_model", [
|
||||
TestAllReduceRMSNormModel,
|
||||
TestAllReduceFusedAddRMSNormModel,
|
||||
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
|
||||
TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"test_model",
|
||||
[
|
||||
TestAllReduceRMSNormModel,
|
||||
TestAllReduceFusedAddRMSNormModel,
|
||||
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
|
||||
# TODO: Enable with torch==2.8.0
|
||||
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [8])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
|
||||
@ -559,8 +559,6 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
|
||||
m_a_scales = m_g if per_act_token else 1
|
||||
n_b_scales = n_g if per_out_ch else 1
|
||||
|
||||
print("shape:", m_g, n_g, k_g)
|
||||
|
||||
# Create group-specific A and B (FP8) and output (FP16/FP32)
|
||||
a_g = to_fp8(torch.randn((m_g, k_g), device=device))
|
||||
b_g = to_fp8(torch.randn((n_g, k_g), device=device).t())
|
||||
@ -639,7 +637,4 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
|
||||
for g in range(num_experts):
|
||||
baseline = baseline_tensors[g]
|
||||
c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]]
|
||||
print(baseline)
|
||||
print(c)
|
||||
print("*")
|
||||
torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user