mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-08 02:05:18 +08:00
Use NVFP4 Marlin for CompressedTensorsW4A16Fp4 (#18000)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Dipika <dipikasikka1@gmail.com> Co-authored-by: Dipika <dipikasikka1@gmail.com>
This commit is contained in:
parent
9d7ea9dbbf
commit
307939f299
@ -2,13 +2,12 @@
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
dequantize_to_dtype)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
@ -31,6 +30,10 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
|
||||
# Weight
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
@ -60,48 +63,30 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def swizzle_blockscale(self, scale: torch.tensor):
|
||||
assert (scale.dtype == torch.float8_e4m3fn)
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scale_ndim = scale.ndim
|
||||
if scale.ndim == 2:
|
||||
scale = scale.unsqueeze(0)
|
||||
assert scale.ndim == 3
|
||||
B, M, K = scale.shape
|
||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||
M_padded = round_up_multiple(M, 128)
|
||||
K_padded = round_up_multiple(K, 4)
|
||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||
padded_scale[:B, :M, :K] = scale
|
||||
batches, rows, cols = padded_scale.shape
|
||||
assert rows % 128 == 0
|
||||
assert cols % 4 == 0
|
||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
||||
cols // 4, 4)
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (swizzled_scale.reshape(M, K)
|
||||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
layer.weight_global_scale = Parameter(
|
||||
layer.weight_global_scale.max().to(torch.float32),
|
||||
# Process parameters for marlin repacking
|
||||
|
||||
# Rename weight_packed to weight that marlin expects
|
||||
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
del layer.weight_packed
|
||||
# Rename weight_global_scale to weight_scale_2 that marlin expects
|
||||
# Note: ct stores the inverse of what is expected by the marlin kernel
|
||||
layer.weight_scale_2 = Parameter(
|
||||
1 / layer.weight_global_scale.max().to(torch.float32),
|
||||
requires_grad=False)
|
||||
# Note: a post weight loading step but not required for the emulation
|
||||
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
|
||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
del layer.weight_global_scale
|
||||
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
w_fp4 = layer.weight_packed.data
|
||||
w_global_scale = layer.weight_global_scale
|
||||
w_blockscale = layer.weight_scale_swizzled.data
|
||||
w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
|
||||
x.dtype, x.device, self.group_size)
|
||||
out = F.linear(x, w_dq)
|
||||
del w_dq, w_fp4, w_global_scale, w_blockscale
|
||||
return out
|
||||
return apply_fp4_marlin_linear(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale_2=layer.weight_scale_2,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user