diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 0bc69fe7f930..a4cfc7d6c15c 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -130,12 +130,14 @@ def _w8a8_triton_block_scaled_mm_fake(
                        device=qx.device)
 
 
-direct_register_custom_op(
-    "w8a8_triton_block_scaled_mm_func",
-    _w8a8_triton_block_scaled_mm_func,
-    fake_impl=_w8a8_triton_block_scaled_mm_fake,
-    dispatch_key="CUDA",
-)
+# Note: the check can be removed when CPU torch > 2.7
+if not current_platform.is_cpu():
+    direct_register_custom_op(
+        "w8a8_triton_block_scaled_mm_func",
+        _w8a8_triton_block_scaled_mm_func,
+        fake_impl=_w8a8_triton_block_scaled_mm_fake,
+        dispatch_key="CUDA",
+    )
 
 
 # TODO fix ROCm->Triton custom path:
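
For context, a minimal sketch of the register-or-fall-back pattern this guard enables. Everything here other than `direct_register_custom_op`, `current_platform`, and the `fake_impl`/`dispatch_key` arguments (all visible in the diff) is illustrative: the op name `my_mm_func`, the helpers, and the fallback call site are hypothetical, not taken from the patch.

```python
# Illustrative sketch, not the actual vLLM call site.
import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def _my_mm_func(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Real kernel work (e.g. a Triton launch) would happen here.
    return x @ w


def _my_mm_fake(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Shape-and-dtype-only "fake" impl, used during torch.compile tracing
    # so the graph can be built without running the real kernel.
    return torch.empty((x.size(0), w.size(1)), dtype=x.dtype, device=x.device)


# Skip registration on CPU builds; per the diff's note, this guard can be
# dropped once CPU torch > 2.7.
if not current_platform.is_cpu():
    direct_register_custom_op(
        "my_mm_func",            # hypothetical op name for this sketch
        _my_mm_func,
        fake_impl=_my_mm_fake,
        dispatch_key="CUDA",
    )


def my_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Go through the custom-op namespace when the op was registered;
    # fall back to the plain Python function on CPU, where registration
    # was skipped above. (Fallback logic is assumed, not from the diff.)
    if current_platform.is_cpu():
        return _my_mm_func(x, w)
    return torch.ops.vllm.my_mm_func(x, w)
```

The design point of the guard is that the module stays importable on CPU-only torch installs: only the registration is skipped, while the underlying Python function remains callable directly.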