commit 2c2b140ae8 (parent c7c80af084)

[quantization] use channel scales for w4a8 + misc fixes (#23570)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
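Review notes (added for context, not part of the commit): the CUTLASS W4A8 path (int4 weights, fp8 activations) now loads real per-output-channel weight scales from the checkpoint as a new weight_chan_scale parameter (fp32, shape (output_size_per_partition, 1)) instead of substituting all-ones scales at runtime; per-group weight scales are allocated as float8_e4m3fn; MPLinearLayerConfig gains an optional out_type field and the CUTLASS kernel rejects anything but bfloat16 output; apply_weights now applies bias with an in-place add instead of asserting it away; and an end-to-end test against czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e exercises the scheme on SM90 GPUs.

The snippet below is a plain-PyTorch sketch, not vLLM code, of how the two weight-scale granularities compose. Names and shapes (reference_dequant, w_q as an already-unpacked [N, K] int tensor) are assumptions chosen to mirror the layout the new test asserts (group_size=128, fp8-e4m3 group scales, fp32 channel scales); the real kernel fuses this into the GEMM epilogue.

import torch

def reference_dequant(w_q: torch.Tensor,           # unpacked int4 values in [-8, 7], shape [N, K]
                      group_scales: torch.Tensor,  # float8_e4m3fn, shape [N, K // group_size]
                      chan_scales: torch.Tensor,   # float32, shape [N, 1]
                      group_size: int = 128) -> torch.Tensor:
    # Each group of `group_size` input elements shares one fp8 scale...
    gs = group_scales.to(torch.float32).repeat_interleave(group_size, dim=1)
    # ...and every output channel additionally carries one fp32 scale.
    return w_q.to(torch.float32) * gs * chan_scales

n, k, g = 256, 512, 128
w_q = torch.randint(-8, 8, (n, k))
group_scales = torch.rand(n, k // g).to(torch.float8_e4m3fn)
chan_scales = torch.rand(n, 1)
print(reference_dequant(w_q, group_scales, chan_scales, g).shape)  # torch.Size([256, 512])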
@@ -14,10 +14,10 @@ from compressed_tensors.quantization import QuantizationType
 from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
     CompressedTensors24, CompressedTensorsLinearMethod,
-    CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
-    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16)
+    CompressedTensorsW4A4Fp4, CompressedTensorsW4A8Fp8,
+    CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     cutlass_fp4_supported)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -683,3 +683,39 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         print(output)
         assert output
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda()
+    or not current_platform.has_device_capability(90),
+    reason="W4A8 FP8 is not yet supported on this GPU type.",
+)
+@pytest.mark.parametrize("args", [
+    ("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8)
+])
+def test_compressed_tensors_w4a8_fp8(vllm_runner, args):
+    model, scheme = args
+    with vllm_runner(model, enforce_eager=True) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            o_proj = layer.self_attn.o_proj
+            gate_up_proj = layer.mlp.gate_up_proj
+            down_proj = layer.mlp.down_proj
+
+            for proj in (qkv_proj, o_proj, gate_up_proj, down_proj):
+                assert isinstance(proj.quant_method,
+                                  CompressedTensorsLinearMethod)
+                assert isinstance(proj.scheme, scheme)
+
+                assert proj.weight_packed.dtype is torch.int32
+                assert proj.weight_scale.dtype is torch.float8_e4m3fn
+                assert proj.weight_chan_scale.dtype is torch.float32
+                assert proj.scheme.group_size == 128
+
+        llm.apply_model(check_model)
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        print(output)
+        assert output
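Outside the test harness, the same checkpoint can be smoke-tested with vLLM's offline LLM API. The snippet below is not part of the diff and assumes a CUDA device with compute capability 9.0, matching the skipif condition above.

from vllm import LLM, SamplingParams

# Greedy decoding (temperature=0.0) mirrors llm.generate_greedy(...) in the test.
llm = LLM(model="czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", enforce_eager=True)
outputs = llm.generate(["Hello my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)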
@@ -79,7 +79,8 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
             act_type=torch.float8_e4m3fn, # always use fp8(e4m3)
             group_size=self.group_size,
             zero_points=not self.symmetric,
-            has_g_idx=self.has_g_idx
+            has_g_idx=self.has_g_idx,
+            out_type=params_dtype
         )

         kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
@@ -122,7 +123,7 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
             torch.empty(
                 output_size_per_partition,
                 scales_and_zp_size,
-                dtype=params_dtype,
+                dtype=torch.float8_e4m3fn,
             )
         }

@@ -140,9 +141,17 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
                                           dtype=torch.int64),
                               weight_loader=weight_loader)

+        # per-channel scales
+        weight_chan_scale = ChannelQuantScaleParameter(
+            data=torch.empty((output_size_per_partition, 1),
+                             dtype=torch.float32),
+            output_dim=0,
+            weight_loader=weight_loader)
+
         layer.register_parameter("weight_packed", weight)
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)
+        layer.register_parameter("weight_chan_scale", weight_chan_scale)

         self.kernel = kernel_type(mp_linear_kernel_config,
                                   w_q_param_name="weight_packed",
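Aside (not from the commit): output_dim=0 on ChannelQuantScaleParameter tells the weight loader which dimension of the channel scale to shard. The sketch below illustrates the idea for a column-parallel layer under tensor parallelism, where each rank keeps the row-slice of the full [out_features, 1] scale that matches its slice of output channels; names and sizes here are made up for illustration.

import torch

out_features, tp_size = 1024, 4
full_chan_scale = torch.rand(out_features, 1, dtype=torch.float32)  # as stored in the checkpoint

out_per_partition = out_features // tp_size
for rank in range(tp_size):
    # row-slice along dim 0 (the output/channel dimension)
    shard = full_chan_scale.narrow(0, rank * out_per_partition, out_per_partition)
    assert shard.shape == (out_per_partition, 1)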
@@ -20,6 +20,7 @@ class MPLinearLayerConfig:
     group_size: int
     zero_points: bool
     has_g_idx: bool
+    out_type: Optional[torch.dtype] = None


 class MPLinearKernel(ABC):
@@ -60,13 +60,17 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
         if in_features % 128 or out_features % 128:
             return False, "K and N must be divisible by 128, got "\
                 f"{c.partition_weight_shape}"
+
+        if c.out_type != torch.bfloat16:
+            return False, "Only bfloat16 output type currently supported"\
+                f"got {c.out_type=}"
+
         return True, None

     # note assumes that
     # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
     # `weight_scale` is: {input_dim = 0, output_dim = 1}
     def process_weights_after_loading(self, layer: torch.nn.Module):
-        c = self.config

         # TODO(czhu): optimize speed/mem usage
         def transform_w_q(x):
@@ -86,19 +90,15 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
         # Encode/reorder weights and pack scales
         self._transform_param(layer, self.w_q_name, transform_w_q)
         self._transform_param(layer, self.w_s_name, transform_w_s)
-        # TODO(czhu): support loading channel scales
-        self.w_ch_s = torch.ones((c.partition_weight_shape[1], ),
-                                 dtype=torch.float32,
-                                 device='cuda')
+        self._transform_param(layer, "weight_chan_scale", lambda x: x)

     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        assert bias is None, "bias not supported by CUTLASS W4A8"
         c = self.config
         w_q, w_s, _, _ = self._get_weight_params(layer)
+        w_ch_s = layer.weight_chan_scale

         x_2d = x.reshape(-1, x.shape[-1])
         out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
@@ -109,6 +109,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
             b_group_scales=w_s,
             b_group_size=c.group_size,
             a_token_scales=act_scales,
-            b_channel_scales=self.w_ch_s)
+            b_channel_scales=w_ch_s)

+        if bias is not None:
+            output.add_(bias) # In-place add
+
         return output.reshape(out_shape)
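To make the reshape/bias handling above easy to follow in isolation, here is a stand-alone sketch (not vLLM code) with the fused W4A8 GEMM replaced by a plain fp32 matmul; weight stands in for the already-dequantized [K, N] weight.

import torch
from typing import Optional

def apply_weights_reference(x: torch.Tensor,
                            weight: torch.Tensor,
                            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    x_2d = x.reshape(-1, x.shape[-1])              # flatten leading dims: [..., K] -> [M, K]
    out_shape = x.shape[:-1] + (weight.shape[1], )
    output = x_2d @ weight                         # stand-in for the quantized GEMM
    if bias is not None:
        output.add_(bias)                          # in-place add, as in the kernel above
    return output.reshape(out_shape)

x, w, b = torch.randn(2, 3, 8), torch.randn(8, 4), torch.randn(4)
print(apply_weights_reference(x, w, b).shape)      # torch.Size([2, 3, 4])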