diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index 40e124a03eb08..f2fbb1200eecc 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -29,8 +29,9 @@ SILU_MUL_OP = torch.ops._C.silu_and_mul.default
 FUSED_OPS: dict[QuantKey, OpOverload] = {
     kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default,  # noqa: E501
 }
-if current_platform.is_cuda() and hasattr(torch.ops._C,
-                                          "silu_and_mul_nvfp4_quant"):
+silu_and_mul_nvfp4_quant_supported = (current_platform.is_cuda() and hasattr(
+    torch.ops._C, "silu_and_mul_nvfp4_quant"))
+if silu_and_mul_nvfp4_quant_supported:
     FUSED_OPS[
         kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default  # noqa: E501
 
@@ -171,8 +172,9 @@ class ActivationQuantFusionPass(VllmInductorPass):
         pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
         pattern_silu_mul_fp8.register(self.patterns)
 
-        pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
-        pattern_silu_mul_nvfp4.register(self.patterns)
+        if silu_and_mul_nvfp4_quant_supported:
+            pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
+            pattern_silu_mul_nvfp4.register(self.patterns)
 
     def __call__(self, graph: torch.fx.Graph):
         self.begin()
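
The fix above hoists the platform/op probe into a module-level flag (`silu_and_mul_nvfp4_quant_supported`) so that the same condition guards both the `FUSED_OPS` entry and the pattern registration; previously `SiluMulNvfp4QuantPattern` was registered unconditionally even on builds where the `silu_and_mul_nvfp4_quant` kernel does not exist. For context, here is a minimal runnable sketch of that capability-gating pattern in isolation. All names in it (`_ops`, `fancy_quant`, `FusionPass`) are hypothetical stand-ins for `torch.ops._C`, the NVFP4 kernel, and the inductor pass; only the shape of the check is taken from the patch.

```python
from types import SimpleNamespace

# Stand-in for torch.ops._C: on some builds the optional kernel is absent.
_ops = SimpleNamespace()  # a full build would also set: _ops.fancy_quant = ...

# Probe once at import time and remember the result in a module-level flag
# (mirrors silu_and_mul_nvfp4_quant_supported in the patch).
fancy_quant_supported = hasattr(_ops, "fancy_quant")

# The op table only gains the optional entry when the kernel exists.
FUSED_OPS: dict[str, object] = {}
if fancy_quant_supported:
    FUSED_OPS["fancy"] = _ops.fancy_quant


class FusionPass:
    def __init__(self) -> None:
        self.patterns: list[str] = []
        # Register the rewrite pattern only when the kernel it rewrites to
        # exists; registering it unconditionally would let the matcher emit
        # calls to a missing op (the bug the patch fixes).
        if fancy_quant_supported:
            self.patterns.append("fancy-quant-pattern")


print(FusionPass().patterns)  # [] on builds without the optional kernel
```

Naming the probe result and reusing it, rather than repeating the `hasattr` check, keeps the two gated sites from drifting apart if the support condition ever changes.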