mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-26 14:17:10 +08:00
Use NVFP4 Marlin for CompressedTensorsW4A16Fp4 (#18000)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Dipika <dipikasikka1@gmail.com> Co-authored-by: Dipika <dipikasikka1@gmail.com>
This commit is contained in:
parent
9d7ea9dbbf
commit
307939f299
@ -2,13 +2,12 @@
|
|||||||
from typing import Callable, List, Optional
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||||
CompressedTensorsScheme)
|
CompressedTensorsScheme)
|
||||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||||
dequantize_to_dtype)
|
apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
|
||||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
PerTensorScaleParameter)
|
PerTensorScaleParameter)
|
||||||
@ -31,6 +30,10 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
|||||||
input_size_per_partition: int,
|
input_size_per_partition: int,
|
||||||
params_dtype: torch.dtype, weight_loader: Callable,
|
params_dtype: torch.dtype, weight_loader: Callable,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
|
layer.logical_widths = output_partition_sizes
|
||||||
|
layer.input_size_per_partition = input_size_per_partition
|
||||||
|
layer.output_size_per_partition = output_size_per_partition
|
||||||
|
|
||||||
# Weight
|
# Weight
|
||||||
weight = ModelWeightParameter(data=torch.empty(
|
weight = ModelWeightParameter(data=torch.empty(
|
||||||
@ -60,48 +63,30 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
|||||||
|
|
||||||
layer.register_parameter("weight_scale", weight_scale)
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
|
|
||||||
def swizzle_blockscale(self, scale: torch.tensor):
|
|
||||||
assert (scale.dtype == torch.float8_e4m3fn)
|
|
||||||
# Pad and blockwise interleave weight_scale
|
|
||||||
scale_ndim = scale.ndim
|
|
||||||
if scale.ndim == 2:
|
|
||||||
scale = scale.unsqueeze(0)
|
|
||||||
assert scale.ndim == 3
|
|
||||||
B, M, K = scale.shape
|
|
||||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
|
||||||
M_padded = round_up_multiple(M, 128)
|
|
||||||
K_padded = round_up_multiple(K, 4)
|
|
||||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
|
||||||
padded_scale[:B, :M, :K] = scale
|
|
||||||
batches, rows, cols = padded_scale.shape
|
|
||||||
assert rows % 128 == 0
|
|
||||||
assert cols % 4 == 0
|
|
||||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
|
||||||
cols // 4, 4)
|
|
||||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
|
||||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
|
||||||
return (swizzled_scale.reshape(M, K)
|
|
||||||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer) -> None:
|
def process_weights_after_loading(self, layer) -> None:
|
||||||
layer.weight_global_scale = Parameter(
|
# Process parameters for marlin repacking
|
||||||
layer.weight_global_scale.max().to(torch.float32),
|
|
||||||
|
# Rename weight_packed to weight that marlin expects
|
||||||
|
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||||
|
del layer.weight_packed
|
||||||
|
# Rename weight_global_scale to weight_scale_2 that marlin expects
|
||||||
|
# Note: ct stores the inverse of what is expected by the marlin kernel
|
||||||
|
layer.weight_scale_2 = Parameter(
|
||||||
|
1 / layer.weight_global_scale.max().to(torch.float32),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
# Note: a post weight loading step but not required for the emulation
|
del layer.weight_global_scale
|
||||||
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
|
|
||||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
prepare_fp4_layer_for_marlin(layer)
|
||||||
requires_grad=False)
|
|
||||||
|
|
||||||
def apply_weights(self,
|
def apply_weights(self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
return apply_fp4_marlin_linear(input=x,
|
||||||
w_fp4 = layer.weight_packed.data
|
weight=layer.weight,
|
||||||
w_global_scale = layer.weight_global_scale
|
weight_scale=layer.weight_scale,
|
||||||
w_blockscale = layer.weight_scale_swizzled.data
|
weight_scale_2=layer.weight_scale_2,
|
||||||
w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
|
workspace=layer.workspace,
|
||||||
x.dtype, x.device, self.group_size)
|
size_n=layer.output_size_per_partition,
|
||||||
out = F.linear(x, w_dq)
|
size_k=layer.input_size_per_partition,
|
||||||
del w_dq, w_fp4, w_global_scale, w_blockscale
|
bias=bias)
|
||||||
return out
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user