[Feature] Add env var VLLM_MOE_USE_DEEP_GEMM (#28422)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 39029d5192
commit de540c0354
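The diff below adds a second, MoE-specific switch on top of the existing VLLM_USE_DEEP_GEMM flag. As a quick orientation (this usage sketch is illustrative and not part of the commit), disabling DeepGEMM only for the MoE path looks like:

# Illustrative only: turn DeepGEMM off for fused-MoE kernels while leaving
# the global VLLM_USE_DEEP_GEMM flag at its default of "1".
# The variable must be set before vLLM reads its environment.
import os

os.environ["VLLM_MOE_USE_DEEP_GEMM"] = "0"  # MoE falls back to the Triton experts (per the hunks below)
# VLLM_USE_DEEP_GEMM is left unset, so it keeps its enabled default.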
@@ -147,6 +147,7 @@ if TYPE_CHECKING:
     VLLM_TPU_MOST_MODEL_LEN: int | None = None
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = True
+    VLLM_MOE_USE_DEEP_GEMM: bool = True
     VLLM_USE_DEEP_GEMM_E8M0: bool = True
     VLLM_DEEP_GEMM_WARMUP: Literal[
         "skip",
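For context, the declaration under TYPE_CHECKING only serves static type checkers; at runtime the env module resolves attributes lazily through a registry of getter lambdas. A simplified sketch of that pattern (an assumption about the surrounding module, not the real code) is:

# Simplified sketch of the lazy-lookup pattern (assumed, not the real module):
# TYPE_CHECKING declarations satisfy static analysis, while a module-level
# __getattr__ evaluates the registered getter on every attribute access.
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_MOE_USE_DEEP_GEMM": lambda: bool(
        int(os.getenv("VLLM_MOE_USE_DEEP_GEMM", "1"))
    ),
}

def __getattr__(name: str) -> Any:
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(name)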
@@ -1116,6 +1117,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     ),
     # Allow use of DeepGemm kernels for fused moe ops.
     "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "1"))),
+    # Allow use of DeepGemm specifically for MoE fused ops (overrides only MoE).
+    "VLLM_MOE_USE_DEEP_GEMM": lambda: bool(
+        int(os.getenv("VLLM_MOE_USE_DEEP_GEMM", "1"))
+    ),
     # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
     "VLLM_USE_DEEP_GEMM_E8M0": lambda: bool(
         int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))
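Note how the getter parses the value: bool(int(...)) treats "0" as False and any other integer string as True, while a non-integer value such as "true" would raise ValueError. A minimal standalone check:

# Minimal check of the parsing used above (standalone Python, not vLLM code).
import os

os.environ["VLLM_MOE_USE_DEEP_GEMM"] = "0"
assert bool(int(os.getenv("VLLM_MOE_USE_DEEP_GEMM", "1"))) is False

del os.environ["VLLM_MOE_USE_DEEP_GEMM"]  # unset -> falls back to default "1"
assert bool(int(os.getenv("VLLM_MOE_USE_DEEP_GEMM", "1"))) is True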
@@ -1569,6 +1574,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_SAMPLER",
         "VLLM_DISABLED_KERNELS",
         "VLLM_USE_DEEP_GEMM",
+        "VLLM_MOE_USE_DEEP_GEMM",
         "VLLM_USE_DEEP_GEMM_E8M0",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
         "VLLM_USE_FLASHINFER_MOE_FP16",
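Adding the variable to the compute_hash() factor list means a change in its value changes the resulting hash, so cached artifacts keyed on that hash are not reused across different settings. A toy illustration of the idea (the factor selection and hashing details here are hypothetical, not the implementation shown above):

# Toy illustration only: folding env-var values into a hash makes cached
# artifacts specific to the current setting.
import hashlib
import os

def toy_hash(factors: list[str]) -> str:
    payload = "|".join(f"{name}={os.getenv(name, '')}" for name in factors)
    return hashlib.sha256(payload.encode()).hexdigest()[:12]

print(toy_hash(["VLLM_USE_DEEP_GEMM", "VLLM_MOE_USE_DEEP_GEMM"]))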
@@ -966,10 +966,18 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 max_num_tokens=max_num_tokens_per_rank,
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
+                allow_deep_gemm=(
+                    envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
+                ),
             )
         else:
             logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
-            return TritonOrDeepGemmExperts(self.moe_quant_config, allow_deep_gemm=True)
+            return TritonOrDeepGemmExperts(
+                self.moe_quant_config,
+                allow_deep_gemm=(
+                    envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
+                ),
+            )
 
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
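The expert-selection change above, like the backend and warmup checks in the remaining hunks, gates DeepGEMM for MoE on both flags; since both default to True, behaviour is unchanged unless one of them is set to 0. A sketch of the resulting decision:

# Sketch of the combined gate (plain booleans standing in for the env reads).
use_deep_gemm = True        # envs.VLLM_USE_DEEP_GEMM, default enabled
moe_use_deep_gemm = False   # envs.VLLM_MOE_USE_DEEP_GEMM, set to 0 by the user
allow_deep_gemm = use_deep_gemm and moe_use_deep_gemm
print(allow_deep_gemm)      # False -> MoE uses the Triton expert path instead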
@@ -158,7 +158,7 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
         return Fp8MoeBackend.MARLIN
 
     # deepGEMM on supported platforms with block-quantized weights
-    if envs.VLLM_USE_DEEP_GEMM and block_quant:
+    if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
         if not has_deep_gemm():
             logger.warning_once("DeepGEMM backend requested but not available.")
         elif is_deep_gemm_supported():
@@ -148,6 +148,9 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
 
 
 def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
+    if not (envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM):
+        return False
+
     if not isinstance(module, FusedMoE):
         return False
 