minor fixes

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-07 07:48:30 +00:00
parent 9ff9b44e0d
commit cfb476fe53
4 changed files with 11 additions and 9 deletions

View File

@ -70,10 +70,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
self.weight_block_size = self.weight_quant.block_structure
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
if self.weight_block_size is not None:
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
assert not self.is_static_input_scheme
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(

View File

@ -28,6 +28,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
kFp8DynamicTokenSym,
kFp8StaticTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
maybe_create_device_identity,
@ -96,7 +97,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
self.out_dtype = torch.get_default_dtype()
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8StaticTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)

View File

@ -22,6 +22,7 @@ class ScaledMMLinearLayerConfig:
@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
# TODO: Chnage to QuantKey like FP8ScaledMMLinearLayerConfig
is_static_input_scheme: bool
is_channelwise: bool
input_symmetric: bool

View File

@ -78,12 +78,12 @@ def rocm_per_tensor_float_w8a8_scaled_mm(
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
# if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
)
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):