[Quantization] Add compressed-tensors NVFP4 support (#18312)

parent b9a1791e2c
commit c123bc33f9

@@ -14,9 +14,10 @@ from compressed_tensors.quantization import QuantizationType
 from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensors24, CompressedTensorsLinearMethod,
-    CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
-    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
+    CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
+    CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     sparse_cutlass_supported)
 from vllm.platforms import current_platform

@@ -651,9 +652,13 @@ def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
         assert output


-def test_compressed_tensors_nvfp4a16(vllm_runner):
-    # run weight only example
-    model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16"
+@pytest.mark.parametrize(
+    "args",
+    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16",
+      CompressedTensorsW4A16Fp4),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4)])
+def test_compressed_tensors_nvfp4(vllm_runner, args):
+    model, scheme = args
     with vllm_runner(model, enforce_eager=True) as llm:

         def check_model(model):

@@ -662,7 +667,7 @@ def test_compressed_tensors_nvfp4a16(vllm_runner):
             qkv_proj = layer.self_attn.qkv_proj
             assert isinstance(qkv_proj.quant_method,
                               CompressedTensorsLinearMethod)
-            assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
+            assert isinstance(qkv_proj.scheme, scheme)
             assert qkv_proj.scheme.group_size == 16

         llm.apply_model(check_model)

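Note (illustrative, not from the commit): the parametrized test above covers both the weight-only NVFP4A16 checkpoint and the weight-and-activation NVFP4 checkpoint. A minimal sketch of loading one of these models through vLLM's offline API, assuming the checkpoint is reachable on the Hugging Face Hub:

from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", enforce_eager=True)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
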
@@ -24,10 +24,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
     CompressedTensorsMoEMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
-    CompressedTensorsScheme, CompressedTensorsW4A16Fp4,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
-    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16)
+    CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
+    CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, is_activation_quantization_format,
     should_ignore_layer)

@@ -218,6 +218,26 @@ class CompressedTensorsConfig(QuantizationConfig):
         else:
             return False

+    def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
+
+        if weight_quant is None or input_quant is None:
+            return False
+
+        is_tensor_group_quant = (weight_quant.strategy
+                                 == QuantizationStrategy.TENSOR_GROUP.value
+                                 and input_quant.strategy
+                                 == QuantizationStrategy.TENSOR_GROUP.value)
+        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+
+        is_group_size_16 = (weight_quant.group_size == 16
+                            and input_quant.group_size == 16)
+        is_float_type = (weight_quant.type == QuantizationType.FLOAT
+                         and input_quant.type == QuantizationType.FLOAT.value)
+        is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
+
+        return (is_tensor_group_quant and is_float_type and is_4_bits
+                and is_group_size_16 and is_symmetric)
+
     def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
                          input_quant: BaseModel):

@@ -353,6 +373,9 @@ class CompressedTensorsConfig(QuantizationConfig):
                 actorder=weight_quant.actorder)

         if is_activation_quantization_format(self.quant_format):
+            if self._is_fp4a4_nvfp4(weight_quant, input_quant):
+                return CompressedTensorsW4A4Fp4()
+
             if self._is_fp8_w8a8(weight_quant, input_quant):
                 is_fp8_w8a8_supported = self._check_scheme_supported(
                     CompressedTensorsW8A8Fp8.get_min_capability(), error=False)

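Note (illustrative, not from the commit): _is_fp4a4_nvfp4 only routes to CompressedTensorsW4A4Fp4 when both the weight and input-activation quantization args are 4-bit, float-typed, symmetric, tensor-group quantized with group_size 16. A hypothetical config fragment as a Python dict; field names mirror the checks above, but the serialized layout of a real compressed-tensors checkpoint may differ:

nvfp4_args = {
    "num_bits": 4,
    "type": "float",
    "strategy": "tensor_group",  # assumed string value of QuantizationStrategy.TENSOR_GROUP
    "symmetric": True,
    "group_size": 16,
}
quantization_config_fragment = {
    "format": "nvfp4-pack-quantized",  # assumed value of CompressionFormat.nvfp4_pack_quantized
    "config_groups": {
        "group_0": {
            "weights": nvfp4_args,
            "input_activations": nvfp4_args,
        },
    },
}
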
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from .compressed_tensors_scheme import CompressedTensorsScheme
+from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
 from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
                                           CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4

@@ -18,5 +19,6 @@ __all__ = [
     "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
     "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
     "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
-    "CompressedTensors24", "CompressedTensorsW4A16Fp4"
+    "CompressedTensors24", "CompressedTensorsW4A16Fp4",
+    "CompressedTensorsW4A4Fp4"
 ]

@@ -0,0 +1,178 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Callable, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm._custom_ops import (cutlass_scaled_fp4_mm,
+                              cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
+    dequantize_to_dtype, ref_nvfp4_quant)
+from vllm.model_executor.parameter import (GroupQuantScaleParameter,
+                                           ModelWeightParameter,
+                                           PerTensorScaleParameter)
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+__all__ = ["CompressedTensorsW4A4Fp4"]
+
+
+def cutlass_fp4_supported() -> bool:
+    if not current_platform.is_cuda():
+        return False
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()
+    return cutlass_scaled_mm_supports_fp4(capability)
+
+
+class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
+
+    def __init__(self):
+        self.group_size = 16
+        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
+        if not self.cutlass_nvfp4_supported:
+            logger.warning("Current platform does not support cutlass NVFP4."
+                           " Running emulations.")
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # dont restrict as emulations
+        return 80
+
+    def run_nvfp4_emulations(self, x: torch.Tensor, layer):
+        x_m, x_k = x.shape
+        output_dtype = x.dtype
+
+        # quantize input to (FP4 and interleaved block scale)
+        x_fp4, x_blockscale = ref_nvfp4_quant(x, layer.input_global_scale,
+                                              self.group_size)
+
+        # dequantize input
+        x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
+        x_blockscale = x_blockscale.unsqueeze(-1) / layer.input_global_scale
+        x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
+        del x_fp4, x_blockscale
+
+        # dequantize weight
+        w_fp4 = layer.weight.data.view(torch.uint8)
+        w_blockscale = layer.weight_scale_swizzled.data
+        w_global_scale = layer.weight_global_scale
+        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
+                                   output_dtype, x.device, self.group_size)
+
+        # matmul
+        out = torch.matmul(x_dq, w_dq.t())
+        del w_dq, x_dq
+        return out
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: list[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+
+        # Weight
+        weight = ModelWeightParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition // 2,
+            dtype=torch.uint8),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
+        layer.register_parameter("weight_packed", weight)
+
+        # Global Weight Scale
+        weight_global_scale = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader)
+        layer.register_parameter("weight_global_scale", weight_global_scale)
+
+        # Per Group Weight Scale
+        weight_scale = GroupQuantScaleParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition // self.group_size,
+            dtype=torch.float8_e4m3fn,
+        ),
+                                                input_dim=1,
+                                                output_dim=0,
+                                                weight_loader=weight_loader)
+
+        layer.register_parameter("weight_scale", weight_scale)
+
+        input_global_scale = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader)
+        layer.register_parameter("input_global_scale", input_global_scale)
+
+    def swizzle_blockscale(self, scale: torch.tensor):
+        assert (scale.dtype == torch.float8_e4m3fn)
+        # Pad and blockwise interleave weight_scale
+        scale_ndim = scale.ndim
+        if scale.ndim == 2:
+            scale = scale.unsqueeze(0)
+        assert scale.ndim == 3
+        B, M, K = scale.shape
+        round_up_multiple = lambda x, m: (x + m - 1) // m * m
+        M_padded = round_up_multiple(M, 128)
+        K_padded = round_up_multiple(K, 4)
+        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
+        padded_scale[:B, :M, :K] = scale
+        batches, rows, cols = padded_scale.shape
+        assert rows % 128 == 0
+        assert cols % 4 == 0
+        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
+                                            cols // 4, 4)
+        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
+        swizzled_scale = swizzled_scale.contiguous().cuda()
+        return (swizzled_scale.reshape(M, K)
+                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
+
+    def process_weights_after_loading(self, layer) -> None:
+
+        global_input_scale = layer.input_global_scale.max().to(torch.float32)
+        layer.input_global_scale = Parameter(global_input_scale,
+                                             requires_grad=False)
+
+        layer.weight_global_scale = Parameter(
+            layer.weight_global_scale.max().to(torch.float32),
+            requires_grad=False)
+
+        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
+        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
+                                                requires_grad=False)
+
+        # required by cutlass kernel; need Parameter, not ModelWeightParameter
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+
+        if self.cutlass_nvfp4_supported:
+            layer.alpha = Parameter(layer.input_global_scale *
+                                    layer.weight_global_scale,
+                                    requires_grad=False)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        if self.cutlass_nvfp4_supported:
+            output_dtype = x.dtype
+            output_shape = [x.shape[0], layer.weight.shape[0]]
+
+            # quantize BF16 or FP16 to (FP4 and interleaved block scale)
+            x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
+
+            out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
+                                        layer.weight_scale_swizzled,
+                                        1 / layer.alpha, output_dtype)
+            if bias is not None:
+                out = out + bias
+            return out.view(*output_shape)
+        return self.run_nvfp4_emulations(x, layer)

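Note (illustrative, not from the commit): swizzle_blockscale pads the per-group FP8 scales to 128x4 tiles and interleaves them into the layout the CUTLASS FP4 kernel expects. A CPU-only toy of the same reshape/permute pattern, using float32 stand-ins for the FP8 scales and shapes that need no padding:

import torch

M, K = 256, 8                      # scales for a 256x128 weight with group_size 16
scale = torch.randn(1, M, K)       # the real code uses torch.float8_e4m3fn
tiled = scale.reshape(1, M // 128, 4, 32, K // 4, 4)
swizzled = tiled.permute(0, 1, 4, 3, 2, 5).contiguous()
print(swizzled.reshape(M, K).shape)  # torch.Size([256, 8])
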
@@ -15,6 +15,7 @@ def is_activation_quantization_format(format: str) -> bool:
         CompressionFormat.naive_quantized.value,
         CompressionFormat.int_quantized.value,
         CompressionFormat.float_quantized.value,
+        CompressionFormat.nvfp4_pack_quantized.value
     ]
     return format in _ACTIVATION_QUANTIZATION_FORMATS

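Note (illustrative, not from the commit): adding nvfp4_pack_quantized here means checkpoints serialized in the NVFP4 packed format now take the activation-quantization path during scheme selection. A quick check, assuming compressed-tensors exposes CompressionFormat at its top level:

from compressed_tensors import CompressionFormat

from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    is_activation_quantization_format)

assert is_activation_quantization_format(
    CompressionFormat.nvfp4_pack_quantized.value)
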
@@ -2,10 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import torch

-__all__ = [
-    "break_fp4_bytes",
-    "dequantize_to_dtype",
-]
+from vllm.scalar_type import scalar_types
+
+__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"]
+
+FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()

 kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
                             dtype=torch.float32)

@@ -60,3 +61,44 @@ def dequantize_to_dtype(tensor_fp4,
     # scale the tensor
     out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
     return out.to(dtype)
+
+
+def get_reciprocal(x):
+    if isinstance(x, torch.Tensor):
+        return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
+    elif isinstance(x, (float, int)):
+        return 0.0 if x == 0 else 1.0 / x
+    else:
+        raise TypeError("Input must be a float, int, or a torch.Tensor.")
+
+
+def cast_to_fp4(x):
+    sign = torch.sign(x)
+    x = torch.abs(x)
+    x[(x >= 0.0) & (x <= 0.25)] = 0.0
+    x[(x > 0.25) & (x < 0.75)] = 0.5
+    x[(x >= 0.75) & (x <= 1.25)] = 1.0
+    x[(x > 1.25) & (x < 1.75)] = 1.5
+    x[(x >= 1.75) & (x <= 2.5)] = 2.0
+    x[(x > 2.5) & (x < 3.5)] = 3.0
+    x[(x >= 3.5) & (x <= 5.0)] = 4.0
+    x[x > 5.0] = 6.0
+    return x * sign
+
+
+def ref_nvfp4_quant(x, global_scale, block_size):
+    assert global_scale.dtype == torch.float32
+    assert x.ndim == 2
+    m, n = x.shape
+    x = torch.reshape(x, (m, n // block_size, block_size))
+    vec_max = torch.max(torch.abs(x), dim=-1,
+                        keepdim=True)[0].to(torch.float32)
+    scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
+    scale = torch.clamp(scale, max=448, min=-448)
+    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
+    output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
+
+    scaled_x = x.to(torch.float32) * output_scale
+    clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
+    # both outputs are float32
+    return cast_to_fp4(clipped_x), scale.squeeze(-1)

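Note (illustrative, not from the commit): a small shape sanity check for the reference quantizer added above, importing it from the module this hunk extends:

import torch

from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
    ref_nvfp4_quant)

x = torch.randn(4, 32, dtype=torch.float32)
global_scale = torch.tensor(1.0, dtype=torch.float32)
x_fp4, blockscale = ref_nvfp4_quant(x, global_scale, block_size=16)
print(x_fp4.shape)        # torch.Size([4, 32]), values on the E2M1 grid
print(blockscale.shape)   # torch.Size([4, 2]), one FP8-rounded scale per 16-value block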