mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 03:15:00 +08:00
Add test for batched triton fallback behavior
Co-authored-by: tlrmchlsmth <1236979+tlrmchlsmth@users.noreply.github.com>
This commit is contained in:
parent
c292032b44
commit
c72d44ba4a
111
tests/kernels/moe/test_batched_triton_fallback.py
Normal file
111
tests/kernels/moe/test_batched_triton_fallback.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Test for batched triton kernel fallback behavior when deepgemm is unavailable."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (
|
||||||
|
BatchedTritonOrDeepGemmExperts,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
|
|
||||||
|
|
||||||
|
def test_batched_triton_fallback_disabled_by_default():
    """Check that requesting DeepGemm fails when the Triton fallback is off.

    ``VLLM_ALLOW_BATCHED_TRITON_FALLBACK`` is pinned to ``"0"`` while the
    experts object is constructed with ``allow_deep_gemm=True``. Construction
    is expected to raise ``RuntimeError`` whose message matches
    "DeepGemm was requested but is not available".
    """
    # Function-scope import so this test does not depend on file-level imports
    # beyond the ones it already used.
    from unittest import mock

    # Create a quant config that doesn't support deepgemm (not fp8_w8a8)
    quant_config = FusedMoEQuantConfig(
        use_fp8_w8a8=False,
        use_int8_w8a16=True,
        use_int4_w4a16=False,
    )

    # Explicitly disable the fallback. patch.dict puts os.environ back to its
    # prior contents on exit, whether or not the body raises, so no manual
    # save/restore is needed.
    env_override = {"VLLM_ALLOW_BATCHED_TRITON_FALLBACK": "0"}
    with mock.patch.dict(os.environ, env_override):
        # This should raise RuntimeError because deepgemm was requested
        # but is unavailable, and fallback is disabled
        with pytest.raises(
            RuntimeError, match="DeepGemm was requested but is not available"
        ):
            BatchedTritonOrDeepGemmExperts(
                max_num_tokens=128,
                num_dispatchers=1,
                quant_config=quant_config,
                allow_deep_gemm=True,  # Request deepgemm
            )
|
||||||
|
|
||||||
|
|
||||||
|
def test_batched_triton_fallback_enabled_with_env_var():
    """Check that construction falls back to batched Triton when allowed.

    ``VLLM_ALLOW_BATCHED_TRITON_FALLBACK`` is set to ``"1"`` while the experts
    object is constructed with ``allow_deep_gemm=True``. Construction must not
    raise, ``batched_deep_gemm_experts`` must be ``None`` and
    ``batched_triton_experts`` must be populated.
    """
    # Function-scope import so this test does not depend on file-level imports
    # beyond the ones it already used.
    from unittest import mock

    # Create a quant config that doesn't support deepgemm (not fp8_w8a8)
    quant_config = FusedMoEQuantConfig(
        use_fp8_w8a8=False,
        use_int8_w8a16=True,
        use_int4_w4a16=False,
    )

    # Allow the fallback. patch.dict puts os.environ back to its prior
    # contents on exit, whether or not the body raises, so no manual
    # save/restore is needed.
    env_override = {"VLLM_ALLOW_BATCHED_TRITON_FALLBACK": "1"}
    with mock.patch.dict(os.environ, env_override):
        # This should NOT raise an error - it should fall back to batched triton
        experts = BatchedTritonOrDeepGemmExperts(
            max_num_tokens=128,
            num_dispatchers=1,
            quant_config=quant_config,
            allow_deep_gemm=True,  # Request deepgemm
        )

        # Verify that deepgemm is not used and batched triton is used instead
        assert experts.batched_deep_gemm_experts is None
        assert experts.batched_triton_experts is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_batched_triton_no_error_when_deepgemm_not_requested():
    """Check that construction succeeds when DeepGemm is not requested.

    ``VLLM_ALLOW_BATCHED_TRITON_FALLBACK`` is pinned to ``"0"`` while the
    experts object is constructed with ``allow_deep_gemm=False``. Construction
    must not raise, ``batched_deep_gemm_experts`` must be ``None`` and
    ``batched_triton_experts`` must be populated.
    """
    # Function-scope import so this test does not depend on file-level imports
    # beyond the ones it already used.
    from unittest import mock

    # Create a quant config
    quant_config = FusedMoEQuantConfig(
        use_fp8_w8a8=False,
        use_int8_w8a16=True,
        use_int4_w4a16=False,
    )

    # Explicitly disable the fallback. patch.dict puts os.environ back to its
    # prior contents on exit, whether or not the body raises, so no manual
    # save/restore is needed.
    env_override = {"VLLM_ALLOW_BATCHED_TRITON_FALLBACK": "0"}
    with mock.patch.dict(os.environ, env_override):
        # This should NOT raise an error because deepgemm was not requested
        experts = BatchedTritonOrDeepGemmExperts(
            max_num_tokens=128,
            num_dispatchers=1,
            quant_config=quant_config,
            allow_deep_gemm=False,  # Don't request deepgemm
        )

        # Verify that batched triton is used
        assert experts.batched_deep_gemm_experts is None
        assert experts.batched_triton_experts is not None
|
||||||
Loading…
x
Reference in New Issue
Block a user