mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 11:37:53 +08:00
[ Misc ] Refactor w8a8 to use process_weights_after_load (Simplify Weight Loading) (#5940)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
This commit is contained in:
parent
7836fdcc11
commit
af9ad46fca
@ -11,14 +11,18 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
|||||||
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
|
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
|
||||||
CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
|
CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
|
||||||
CompressedTensorsWNA16)
|
CompressedTensorsWNA16)
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||||
|
QuantizationType)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_args", [
|
@pytest.mark.parametrize("model_args", [
|
||||||
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor"),
|
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
|
||||||
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel"),
|
QuantizationType.INT, 2560),
|
||||||
|
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
|
||||||
|
QuantizationType.INT, 2560),
|
||||||
])
|
])
|
||||||
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
|
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
|
||||||
model_path, strategy = model_args
|
model_path, strategy, quant_type, shape_0 = model_args
|
||||||
with vllm_runner(model_path, enforce_eager=True) as llm:
|
with vllm_runner(model_path, enforce_eager=True) as llm:
|
||||||
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
||||||
layer = model.model.layers[0]
|
layer = model.model.layers[0]
|
||||||
@ -34,17 +38,23 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
|
|||||||
CompressedTensorsLinearMethod)
|
CompressedTensorsLinearMethod)
|
||||||
assert isinstance(down_proj.quant_method,
|
assert isinstance(down_proj.quant_method,
|
||||||
CompressedTensorsLinearMethod)
|
CompressedTensorsLinearMethod)
|
||||||
|
|
||||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)
|
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)
|
||||||
|
|
||||||
assert qkv_proj.scheme.strategy == strategy
|
assert qkv_proj.scheme.strategy == strategy
|
||||||
assert qkv_proj.weight.dtype is torch.int8
|
expected_type = (torch.int8 if quant_type == QuantizationType.INT else
|
||||||
assert o_proj.weight.dtype is torch.int8
|
torch.float8_e4m3fn)
|
||||||
assert gate_up_proj.weight.dtype is torch.int8
|
|
||||||
|
assert qkv_proj.weight.dtype is expected_type
|
||||||
|
assert o_proj.weight.dtype is expected_type
|
||||||
|
assert gate_up_proj.weight.dtype is expected_type
|
||||||
|
|
||||||
if qkv_proj.scheme.strategy == "tensor":
|
if qkv_proj.scheme.strategy == "tensor":
|
||||||
assert qkv_proj.weight_scale.shard_splitter is not None
|
# Make sure it is a channelwise buffer
|
||||||
assert qkv_proj.weight_scale.logical_widths is not None
|
# After running process_weights_after_loading
|
||||||
|
assert len(qkv_proj.weight_scale.shape) == 2
|
||||||
|
assert qkv_proj.weight_scale.shape[0] == shape_0
|
||||||
|
assert qkv_proj.weight_scale.shape[1] == 1
|
||||||
|
assert qkv_proj.weight_scale.dtype is torch.float32
|
||||||
assert qkv_proj.input_scale.dtype is torch.float32
|
assert qkv_proj.input_scale.dtype is torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,23 @@ from tests.quantization.utils import is_quant_method_supported
|
|||||||
from vllm._custom_ops import scaled_fp8_quant
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
|
||||||
|
"nm-testing/Phi-3-mini-128k-instruct-FP8",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
def test_model_load_and_run(vllm_runner, model: str):
    """Smoke-test that each model in MODELS loads and generates.

    This does not test accuracy, only that a short greedy generation
    runs through; see the lm-eval tests for accuracy.
    """
    prompt_batch = ["Hello my name is"]
    with vllm_runner(model) as llm:
        generations = llm.generate_greedy(prompts=prompt_batch,
                                          max_tokens=10)
        first_generation = generations[0]
        print(first_generation[1])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||||
reason="FP8 is not supported on this GPU type.")
|
reason="FP8 is not supported on this GPU type.")
|
||||||
|
|||||||
@ -41,6 +41,29 @@ def adjust_bitsandbytes_shard(param: Parameter,
|
|||||||
return quantized_size, quantized_offset
|
return quantized_size, quantized_offset
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused per-tensor scale array for one shard.

    For fused modules (QKV and MLP) the parameter is an array of length
    N that holds one scale for each "logical" matrix, while
    ``loaded_weight`` corresponds to one of the shards on disk. This
    helper indexes ``param`` with ``shard_id`` so the caller can load the
    scale into the matching slot.

    Args:
        param: Fused scale array, indexed along its first dimension.
        loaded_weight: Scale read from the checkpoint; either shapeless
            (0-dim) or with a leading dimension of size 1.
        shard_id: Integer index of the logical matrix, or one of
            ``"q"``, ``"k"``, ``"v"``.

    Returns:
        The tuple ``(param[shard_id], loaded_weight)``; if
        ``loaded_weight`` had a shape, its first (size-1) dimension has
        been indexed away.

    Raises:
        ValueError: If ``shard_id`` is neither an int nor a known QKV
            shard name.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        # Report unknown names with the same error as unsupported types
        # instead of leaking a bare KeyError from the dict lookup.
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown Shard Id {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight
|
||||||
|
|
||||||
|
|
||||||
class LinearMethodBase(QuantizeMethodBase):
|
class LinearMethodBase(QuantizeMethodBase):
|
||||||
"""Base class for different (maybe quantized) linear methods."""
|
"""Base class for different (maybe quantized) linear methods."""
|
||||||
|
|
||||||
@ -358,37 +381,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
output_dim = getattr(param, "output_dim", None)
|
output_dim = getattr(param, "output_dim", None)
|
||||||
# Special case for AQLM codebooks.
|
# Special case for AQLM codebooks.
|
||||||
is_metadata = getattr(param, "is_metadata", False)
|
is_metadata = getattr(param, "is_metadata", False)
|
||||||
|
# Special case for per-tensor scale to load scalar into fused array.
|
||||||
param_shard_splitter = getattr(param, "shard_splitter", None)
|
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||||
|
|
||||||
if output_dim is not None and param_shard_splitter is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"We do not currently support output_dim != None and "
|
|
||||||
"shard_splitter != None for a parameter. Please open an issue."
|
|
||||||
)
|
|
||||||
# If a parameter has defined a shard_splitter to be used for
|
|
||||||
# the weight, it should be applied before the weight is
|
|
||||||
# loaded/copied to the parameter. The shard_splitter applies
|
|
||||||
# logic by using the loaded_shard_id to ensure that the loaded
|
|
||||||
# param is loaded to the correct location
|
|
||||||
# within the parameter defined by the linear method.
|
|
||||||
if loaded_shard_id is None and param_shard_splitter is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"We do not currently support loaded_shard_id == None and "
|
|
||||||
"shard_splitter != None for a parameter. Please open an issue."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Special case for Fp8 scales.
|
|
||||||
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
|
||||||
None)
|
|
||||||
|
|
||||||
if loaded_shard_id is None:
|
if loaded_shard_id is None:
|
||||||
# Loaded weight is already fused on disk (qkv/mlp).
|
# Loaded weight is already fused on disk (qkv/mlp).
|
||||||
if output_dim is None:
|
if output_dim is None:
|
||||||
# If fp8 + scale, need to send to each shard.
|
if needs_scalar_to_array is not None:
|
||||||
if fp8_scales_shard_indexer is not None:
|
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
param_data, loaded_weight, 0)
|
||||||
param_data, loaded_weight, loaded_shard_id)
|
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert param_data.shape == loaded_weight.shape
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
@ -450,15 +451,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
shard_offset = loaded_shard_id * shard_size
|
shard_offset = loaded_shard_id * shard_size
|
||||||
param_data = param_data.narrow(0, shard_offset, shard_size)
|
param_data = param_data.narrow(0, shard_offset, shard_size)
|
||||||
|
|
||||||
# If a param_shard_splitter is defined by the LinearMethod, use it.
|
# Special case for per-tensor scales in fused case.
|
||||||
elif param_shard_splitter is not None:
|
elif needs_scalar_to_array:
|
||||||
logical_widths = getattr(param, "logical_widths", None)
|
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||||
param_data, loaded_weight = param_shard_splitter(
|
|
||||||
param_data, loaded_weight, loaded_shard_id, logical_widths)
|
|
||||||
|
|
||||||
# Special case for Fp8 scales.
|
|
||||||
elif fp8_scales_shard_indexer is not None:
|
|
||||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
|
||||||
param_data, loaded_weight, loaded_shard_id)
|
param_data, loaded_weight, loaded_shard_id)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -548,36 +543,15 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
# Special case for AQLM codebooks.
|
# Special case for AQLM codebooks.
|
||||||
is_metadata = getattr(param, "is_metadata", False)
|
is_metadata = getattr(param, "is_metadata", False)
|
||||||
|
|
||||||
param_shard_splitter = getattr(param, "shard_splitter", None)
|
# Special case for per-tensor scales in fused case.
|
||||||
|
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||||
if output_dim is not None and param_shard_splitter is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"We do not currently support output_dim != None and "
|
|
||||||
"shard_splitter != None for a parameter. Please open an issue."
|
|
||||||
)
|
|
||||||
# If a parameter has defined a shard_splitter to be used for
|
|
||||||
# the weight, it should be applied before the weight is
|
|
||||||
# loaded/copied to the parameter. The shard_splitter applies
|
|
||||||
# logic by using the loaded_shard_id to ensure that the loaded
|
|
||||||
# param is loaded to the correct location
|
|
||||||
# within the parameter defined by the linear method.
|
|
||||||
if loaded_shard_id is None and param_shard_splitter is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"We do not currently support loaded_shard_id == None and "
|
|
||||||
"shard_splitter != None for a parameter. Please open an issue."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Special case for Fp8 scales.
|
|
||||||
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
|
||||||
None)
|
|
||||||
|
|
||||||
if loaded_shard_id is None:
|
if loaded_shard_id is None:
|
||||||
# Loaded weight is already fused on disk (qkv/mlp).
|
# Loaded weight is already fused on disk (qkv/mlp).
|
||||||
if output_dim is None:
|
if output_dim is None:
|
||||||
# If fp8 + scale, need to send to each shard.
|
if needs_scalar_to_array is not None:
|
||||||
if fp8_scales_shard_indexer is not None:
|
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
param_data, loaded_weight, 0)
|
||||||
param_data, loaded_weight, loaded_shard_id)
|
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert param_data.shape == loaded_weight.shape
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
@ -667,15 +641,9 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
shard_index = ["q", "k", "v"].index(loaded_shard_id)
|
shard_index = ["q", "k", "v"].index(loaded_shard_id)
|
||||||
param_data = param_data.narrow(0, shard_index * shard_size,
|
param_data = param_data.narrow(0, shard_index * shard_size,
|
||||||
shard_size)
|
shard_size)
|
||||||
# If a param_shard_splitter is defined by the LinearMethod, use it.
|
# Special case for per-tensor scales in fused case.
|
||||||
elif param_shard_splitter is not None:
|
elif needs_scalar_to_array:
|
||||||
logical_widths = getattr(param, "logical_widths", None)
|
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||||
param_data, loaded_weight = param_shard_splitter(
|
|
||||||
param_data, loaded_weight, loaded_shard_id, logical_widths)
|
|
||||||
|
|
||||||
# Special case for Fp8 scales.
|
|
||||||
elif fp8_scales_shard_indexer is not None:
|
|
||||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
|
||||||
param_data, loaded_weight, loaded_shard_id)
|
param_data, loaded_weight, loaded_shard_id)
|
||||||
else:
|
else:
|
||||||
ignore_warning = getattr(param, "ignore_warning", False)
|
ignore_warning = getattr(param, "ignore_warning", False)
|
||||||
|
|||||||
@ -186,6 +186,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
|
|||||||
def __init__(self, quantization_config: CompressedTensorsConfig):
|
def __init__(self, quantization_config: CompressedTensorsConfig):
|
||||||
self.quantization_config = quantization_config
|
self.quantization_config = quantization_config
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
return layer.scheme.process_weights_after_loading(layer)
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module,
|
def create_weights(self, layer: torch.nn.Module,
|
||||||
input_size_per_partition: int,
|
input_size_per_partition: int,
|
||||||
output_partition_sizes: List[int], input_size: int,
|
output_partition_sizes: List[int], input_size: int,
|
||||||
|
|||||||
@ -31,3 +31,11 @@ class CompressedTensorsScheme(ABC):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Called after weight loading is complete for any cleanup that
|
||||||
|
needs to occur.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@ -18,6 +18,9 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
|
|||||||
in a linear transformation.
|
in a linear transformation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module,
|
def create_weights(self, layer: torch.nn.Module,
|
||||||
output_partition_sizes: List[int],
|
output_partition_sizes: List[int],
|
||||||
input_size_per_partition: int,
|
input_size_per_partition: int,
|
||||||
|
|||||||
@ -29,6 +29,9 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"group_size must be given when using strategy group")
|
"group_size must be given when using strategy group")
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
||||||
output_partition_sizes: List[int],
|
output_partition_sizes: List[int],
|
||||||
input_size_per_partition: int,
|
input_size_per_partition: int,
|
||||||
|
|||||||
@ -15,70 +15,63 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
|
|||||||
def __init__(self, strategy: str):
|
def __init__(self, strategy: str):
|
||||||
self.strategy = strategy
|
self.strategy = strategy
|
||||||
|
|
||||||
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
|
# Cutlass kernels support only per-tensor and per-channel cases.
|
||||||
if isinstance(shard_id, int):
|
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||||
return shard_id
|
# scales being passed to the kernel), we convert to the per-channel case.
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
if (self.strategy == QuantizationStrategy.TENSOR
|
||||||
|
and len(self.logical_widths) > 1):
|
||||||
|
|
||||||
assert isinstance(shard_id, str)
|
# Load the N per-tensor scales into the channelwise buffer.
|
||||||
qkv_idxs = {"q": 0, "k": 1, "v": 2}
|
weight_scale_channel = torch.empty(
|
||||||
assert shard_id in qkv_idxs
|
(sum(self.logical_widths), 1),
|
||||||
return qkv_idxs[shard_id]
|
dtype=torch.float32,
|
||||||
|
device=layer.weight_scale.device)
|
||||||
|
start = 0
|
||||||
|
for idx, logical_width in enumerate(self.logical_widths):
|
||||||
|
end = start + logical_width
|
||||||
|
weight_scale_channel[start:end, :] = layer.weight_scale[idx]
|
||||||
|
start = end
|
||||||
|
|
||||||
def scales_shard_splitter(
|
layer.weight_scale = Parameter(weight_scale_channel,
|
||||||
self, param: torch.Tensor, loaded_weight: torch.Tensor,
|
requires_grad=False)
|
||||||
shard_id: Union[str, int],
|
|
||||||
logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
shard_id = self._shard_id_as_int(shard_id)
|
|
||||||
offset = sum(logical_widths[:shard_id])
|
|
||||||
size = logical_widths[shard_id]
|
|
||||||
# update loaded weight with copies for broadcast.
|
|
||||||
loaded_weight = loaded_weight.repeat(size)
|
|
||||||
return param[offset:offset + size], loaded_weight
|
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module,
|
def create_weights(self, layer: torch.nn.Module,
|
||||||
output_partition_sizes: List[int],
|
output_partition_sizes: List[int],
|
||||||
input_size_per_partition: int,
|
input_size_per_partition: int,
|
||||||
params_dtype: torch.dtype, weight_loader: Callable,
|
params_dtype: torch.dtype, weight_loader: Callable,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
self.logical_widths = output_partition_sizes
|
||||||
|
|
||||||
is_tensor_partitioned = len(output_partition_sizes) != 1
|
# WEIGHT SCALE
|
||||||
weight_scale_dim = sum(output_partition_sizes) if (
|
shape: Union[Tuple[int], Tuple[int, int]]
|
||||||
is_tensor_partitioned
|
|
||||||
or self.strategy == QuantizationStrategy.CHANNEL) else 1
|
|
||||||
|
|
||||||
shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
|
|
||||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||||
shape = (weight_scale_dim, 1)
|
shape = (sum(self.logical_widths), 1)
|
||||||
|
else:
|
||||||
|
shape = (len(self.logical_widths), )
|
||||||
|
|
||||||
weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
|
weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
layer.register_parameter("weight_scale", weight_scale)
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
|
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||||
|
set_weight_attrs(weight_scale, {
|
||||||
|
"weight_loader": weight_loader,
|
||||||
|
"output_dim": 0,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
set_weight_attrs(weight_scale, {
|
||||||
|
"weight_loader": weight_loader,
|
||||||
|
"needs_scalar_to_array": True,
|
||||||
|
})
|
||||||
|
|
||||||
|
# WEIGHT
|
||||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||||
input_size_per_partition,
|
input_size_per_partition,
|
||||||
dtype=torch.int8),
|
dtype=torch.int8),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
layer.register_parameter("weight", weight)
|
layer.register_parameter("weight", weight)
|
||||||
set_weight_attrs(
|
set_weight_attrs(weight, {
|
||||||
weight, {
|
"input_dim": 1,
|
||||||
"input_dim": 1,
|
"output_dim": 0,
|
||||||
"output_dim": 0,
|
"weight_loader": weight_loader,
|
||||||
"weight_loader": weight_loader,
|
})
|
||||||
"logical_widths": output_partition_sizes
|
|
||||||
})
|
|
||||||
|
|
||||||
# Don't need a shard_splitter for channel-wise quantization
|
|
||||||
# Use the default loading method
|
|
||||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
|
||||||
set_weight_attrs(weight_scale, {
|
|
||||||
"output_dim": 0,
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
set_weight_attrs(
|
|
||||||
weight_scale, {
|
|
||||||
"logical_widths": output_partition_sizes,
|
|
||||||
"shard_splitter": self.scales_shard_splitter,
|
|
||||||
})
|
|
||||||
|
|||||||
@ -29,6 +29,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"group_size must be given when using strategy group")
|
"group_size must be given when using strategy group")
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
||||||
output_partition_sizes: List[int],
|
output_partition_sizes: List[int],
|
||||||
input_size_per_partition: int,
|
input_size_per_partition: int,
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
@ -98,7 +98,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: Fp8Config):
|
def __init__(self, quant_config: Fp8Config):
|
||||||
self.fused_module_in_checkpoint = False
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||||
|
|
||||||
@ -114,12 +113,10 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
scale[:] = torch.finfo(torch.float8_e4m3fn).min
|
scale[:] = torch.finfo(torch.float8_e4m3fn).min
|
||||||
layer.register_parameter(scale_name, scale)
|
layer.register_parameter(scale_name, scale)
|
||||||
set_weight_attrs(
|
set_weight_attrs(scale, {
|
||||||
scale, {
|
**extra_weight_attrs,
|
||||||
**extra_weight_attrs,
|
"needs_scalar_to_array": True,
|
||||||
"fp8_scales_shard_indexer":
|
})
|
||||||
self.scales_shard_indexer,
|
|
||||||
})
|
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@ -170,26 +167,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
output_partition_sizes=output_partition_sizes,
|
output_partition_sizes=output_partition_sizes,
|
||||||
**extra_weight_attrs)
|
**extra_weight_attrs)
|
||||||
|
|
||||||
def scales_shard_indexer(
        self, param: torch.Tensor, loaded_weight: torch.Tensor,
        shard_id: Optional[Union[str, int]]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the entry of a scale array that corresponds to ``shard_id``.

    Args:
        param: Scale array, indexed along its first dimension.
        loaded_weight: Scale read from the checkpoint; returned unchanged.
        shard_id: ``None`` (treated as index 0, and recorded by setting
            ``self.fused_module_in_checkpoint`` to True), an integer
            index, or one of ``"q"``, ``"k"``, ``"v"``.

    Returns:
        The tuple ``(param[shard_id], loaded_weight)``.

    Raises:
        ValueError: If ``shard_id`` is an unknown string, or is not
            ``None``, an int or a str.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if shard_id is None:
        shard_id = 0
        self.fused_module_in_checkpoint = True
    elif isinstance(shard_id, int):
        pass
    elif isinstance(shard_id, str):
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown shard_id: {shard_id}")
        shard_id = qkv_idxs[shard_id]
    else:
        # The exception used to be constructed but never raised, so an
        # unsupported type fell through to the indexing below.
        raise ValueError(
            f"Shard id must be int or str but got {type(shard_id)}")

    return param[shard_id], loaded_weight
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
if (not hasattr(layer, "process_after_load")
|
if (not hasattr(layer, "process_after_load")
|
||||||
or not layer.process_after_load):
|
or not layer.process_after_load):
|
||||||
@ -212,7 +189,17 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
# Loop over logical weights, requantizing with single scale.
|
# Loop over logical weights, requantizing with single scale.
|
||||||
max_w_scale = layer.weight_scale.max()
|
max_w_scale = layer.weight_scale.max()
|
||||||
|
|
||||||
if not self.fused_module_in_checkpoint:
|
# QKV / MLP is fused in the on disk checkpoint if any of the
|
||||||
|
# weight scales are still set to the default since we initialize
|
||||||
|
# N weight scales for N shards but we only load 1 weight scale
|
||||||
|
# from disk in this case. As a result, we skip dequant -> requant
|
||||||
|
# since we already have quantized QKV together.
|
||||||
|
# Sample Model with fused checkpoint:
|
||||||
|
# * nm-testing/Phi-3-mini-128k-instruct-FP8
|
||||||
|
unfused_module_in_checkpoint = (
|
||||||
|
layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min)
|
||||||
|
|
||||||
|
if unfused_module_in_checkpoint:
|
||||||
start = 0
|
start = 0
|
||||||
for idx, logical_width in enumerate(layer.logical_widths):
|
for idx, logical_width in enumerate(layer.logical_widths):
|
||||||
end = start + logical_width
|
end = start + logical_width
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user