From cfb476fe539d2d49a97bf1d858fd4bb34b92d6d9 Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Fri, 7 Nov 2025 07:48:30 +0000
Subject: [PATCH] minor fixes

Signed-off-by: vllmellm
---
 .../schemes/compressed_tensors_w8a8_fp8.py           |  4 ++--
 .../model_executor/layers/quantization/fbgemm_fp8.py |  3 ++-
 .../kernels/scaled_mm/ScaledMMLinearKernel.py        |  1 +
 .../layers/quantization/kernels/scaled_mm/rocm.py    | 12 ++++++------
 4 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index a1c60fadce6d6..2cd29e0905d06 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -70,10 +70,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         self.out_dtype = torch.get_default_dtype()
         self.is_static_input_scheme = is_static_input_scheme
         self.weight_block_size = self.weight_quant.block_structure
-        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
-        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
 
         if self.weight_block_size is not None:
+            self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
+            self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
             assert not self.is_static_input_scheme
             self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
             self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index c19dd708b2339..bcd02554008ca 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped,
     kFp8DynamicTokenSym,
+    kFp8StaticTokenSym,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     maybe_create_device_identity,
@@ -96,7 +97,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         self.out_dtype = torch.get_default_dtype()
         self.fp8_linear = init_fp8_linear_kernel(
             activation_quant_key=kFp8DynamicTokenSym,
-            weight_quant_key=kFp8DynamicTokenSym,
+            weight_quant_key=kFp8StaticTokenSym,
             out_dtype=torch.get_default_dtype(),
             module_name=self.__class__.__name__,
         )
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
index 5baa7f73077aa..e2b4f08f6db4c 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
@@ -22,6 +22,7 @@ class ScaledMMLinearLayerConfig:
 
 @dataclass
 class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
+    # TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
     is_static_input_scheme: bool
     is_channelwise: bool
     input_symmetric: bool
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
index 852e0088d0d97..bcc92ef209af5 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
@@ -78,12 +78,12 @@ def rocm_per_tensor_float_w8a8_scaled_mm(
     return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
 
 
-# if current_platform.is_rocm():
-direct_register_custom_op(
-    op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
-    op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
-    fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
-)
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
+        op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
+        fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
+    )
 
 
 class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):