mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 05:27:04 +08:00
minor fixes
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
9ff9b44e0d
commit
cfb476fe53
@ -70,10 +70,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.weight_block_size = self.weight_quant.block_structure
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
|
||||
|
||||
if self.weight_block_size is not None:
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
|
||||
assert not self.is_static_input_scheme
|
||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
|
||||
@ -28,6 +28,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTokenSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
maybe_create_device_identity,
|
||||
@ -96,7 +97,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8StaticTokenSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@ -22,6 +22,7 @@ class ScaledMMLinearLayerConfig:
|
||||
|
||||
@dataclass
|
||||
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
# TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
|
||||
is_static_input_scheme: bool
|
||||
is_channelwise: bool
|
||||
input_symmetric: bool
|
||||
|
||||
@ -78,12 +78,12 @@ def rocm_per_tensor_float_w8a8_scaled_mm(
|
||||
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
# if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
|
||||
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
|
||||
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
|
||||
)
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
|
||||
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
|
||||
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
|
||||
)
|
||||
|
||||
|
||||
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user