mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 16:08:41 +08:00
[Kernel] Support Fp8 Checkpoints (Dynamic + Static) (#4332)
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: mgoin <michael@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
parent
b31a1fb63c
commit
111815d482
90
tests/models/test_fp8.py
Normal file
90
tests/models/test_fp8.py
Normal file
@ -0,0 +1,90 @@
|
||||
# flake8: noqa
|
||||
"""Tests fp8 models against ground truth generation
|
||||
Note: these tests will only pass on L4 GPU.
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
|
||||
MODELS = [
|
||||
"nm-testing/Meta-Llama-3-8B-Instruct-FP8",
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
]
|
||||
|
||||
EXPECTED_STRS_MAP = {
|
||||
"nm-testing/Meta-Llama-3-8B-Instruct-FP8": [
|
||||
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
|
||||
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
|
||||
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
|
||||
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
|
||||
'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
|
||||
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
|
||||
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
|
||||
'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**'
|
||||
],
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct": [
|
||||
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
|
||||
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
|
||||
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
|
||||
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
|
||||
'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of',
|
||||
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
|
||||
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
|
||||
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
|
||||
],
|
||||
}
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
fp8_not_supported = (capability <
|
||||
QUANTIZATION_METHODS["fp8"].get_min_capability())
|
||||
|
||||
|
||||
@pytest.mark.skipif(fp8_not_supported,
|
||||
reason="fp8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model_name", MODELS)
|
||||
def test_models(
|
||||
example_prompts,
|
||||
model_name,
|
||||
) -> None:
|
||||
model = LLM(model=model_name,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
enforce_eager=True,
|
||||
quantization="fp8")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
formatted_prompts = [
|
||||
tokenizer.apply_chat_template([{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
for prompt in example_prompts
|
||||
]
|
||||
|
||||
params = SamplingParams(max_tokens=20, temperature=0)
|
||||
generations = []
|
||||
# Note: these need to be run 1 at a time due to numerical precision,
|
||||
# since the expected strs were generated this way.
|
||||
for prompt in formatted_prompts:
|
||||
outputs = model.generate(prompt, params)
|
||||
generations.append(outputs[0].outputs[0].text)
|
||||
del model
|
||||
|
||||
print(generations)
|
||||
expected_strs = EXPECTED_STRS_MAP[model_name]
|
||||
for i in range(len(example_prompts)):
|
||||
generated_str = generations[i]
|
||||
expected_str = expected_strs[i]
|
||||
assert expected_str == generated_str, (
|
||||
f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")
|
||||
@ -246,6 +246,10 @@ class ColumnParallelLinear(LinearBase):
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
# Special case for Fp8 scales.
|
||||
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||
None)
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
param_data = param.data
|
||||
@ -254,6 +258,12 @@ class ColumnParallelLinear(LinearBase):
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
# Special case for Fp8 scales.
|
||||
elif fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
|
||||
loaded_weight,
|
||||
shard_id=0)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
@ -317,7 +327,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
# Special case for AQLM codebooks.
|
||||
is_metadata = getattr(param, "is_metadata", False)
|
||||
# Special case for Fp8 scales.
|
||||
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||
None)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already packed.
|
||||
if output_dim is None:
|
||||
@ -331,14 +346,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
current_shard_offset += output_size
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantization.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# If marlin, we need to adjust the offset and size to
|
||||
# account for the tiling.
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
@ -353,15 +367,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
if output_dim is not None:
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
||||
# Special case for quantization.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# If marlin, we need to adjust the offset and size to
|
||||
# account for the tiling.
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
@ -370,11 +383,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
# Special case for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
# metadata indicates fixed size concatenated along dim 0
|
||||
shard_size = loaded_weight.shape[0]
|
||||
shard_offset = loaded_shard_id * shard_size
|
||||
param_data = param_data.narrow(0, shard_offset, shard_size)
|
||||
# Special case for Fp8 scales.
|
||||
elif fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
||||
param_data, loaded_weight, loaded_shard_id)
|
||||
|
||||
else:
|
||||
ignore_warning = getattr(param, "ignore_warning", False)
|
||||
if not ignore_warning:
|
||||
@ -455,7 +474,11 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_shard_id: Optional[str] = None):
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
# Special case for AQLM codebooks.
|
||||
is_metadata = getattr(param, "is_metadata", False)
|
||||
# Special case for Fp8 scales.
|
||||
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||
None)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already packed.
|
||||
@ -473,14 +496,14 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
]
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantized Weights.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# If marlin, we need to adjust the offset and size to
|
||||
# account for the tiling.
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
@ -502,6 +525,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_offset = (self.num_heads +
|
||||
self.num_kv_heads) * self.head_size
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
# Special case for Quantized Weights.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
@ -509,8 +533,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# If marlin, we need to adjust the offset and size to
|
||||
# account for the tiling.
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
@ -523,12 +546,17 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
start_idx = shard_id * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
# Special case for for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
# metadata indicates fixed size concatenated along dim 0
|
||||
shard_size = loaded_weight.shape[0]
|
||||
shard_index = ["q", "k", "v"].index(loaded_shard_id)
|
||||
param_data = param_data.narrow(0, shard_index * shard_size,
|
||||
shard_size)
|
||||
# Special case for Fp8 scales.
|
||||
elif fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
||||
param_data, loaded_weight, loaded_shard_id)
|
||||
else:
|
||||
ignore_warning = getattr(param, "ignore_warning", False)
|
||||
if not ignore_warning:
|
||||
@ -611,6 +639,10 @@ class RowParallelLinear(LinearBase):
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
# Special case for Fp8 scales.
|
||||
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||
None)
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
param_data = param.data
|
||||
@ -619,6 +651,12 @@ class RowParallelLinear(LinearBase):
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
|
||||
shard_size)
|
||||
# Special case for Fp8 scales.
|
||||
elif fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
|
||||
loaded_weight,
|
||||
shard_id=0)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
@ -1,23 +1,36 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Fp8Config(QuantizationConfig):
|
||||
"""Config class for FP8."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_checkpoint_fp8_serialized: bool = False,
|
||||
activation_scheme: str = "dynamic",
|
||||
) -> None:
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
if is_checkpoint_fp8_serialized:
|
||||
logger.warning("Detected fp8 checkpoint. Please note that the "
|
||||
"format is experimental and subject to change.")
|
||||
if activation_scheme not in ACTIVATION_SCHEMES:
|
||||
raise ValueError(
|
||||
f"Unsupported activation scheme {activation_scheme}")
|
||||
self.activation_scheme = activation_scheme
|
||||
|
||||
@classmethod
|
||||
@ -30,10 +43,7 @@ class Fp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# TODO: PyTorch 2.3.0+ is required to run FP8 on
|
||||
# SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
|
||||
# be included: https://github.com/pytorch/pytorch/pull/118881
|
||||
return 90
|
||||
return 89
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
@ -41,11 +51,14 @@ class Fp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
return cls(activation_scheme)
|
||||
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||
activation_scheme=activation_scheme)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
|
||||
self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return Fp8LinearMethod(self)
|
||||
return None
|
||||
@ -56,8 +69,12 @@ class Fp8Config(QuantizationConfig):
|
||||
|
||||
class Fp8LinearMethod(LinearMethodBase):
|
||||
"""Linear method for FP8.
|
||||
We now support common FP16/BF16 model checkpoints ONLY. The weight
|
||||
scaling factor will be initialized after the model weights are loaded.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
|
||||
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||
activation scaling. The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
|
||||
Limitations:
|
||||
1. Only support per-tensor quantization due to torch._scaled_mm support.
|
||||
@ -71,6 +88,24 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def _create_scale_param(
|
||||
self,
|
||||
scale_name: str,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
scale = Parameter(torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter(scale_name, scale)
|
||||
set_weight_attrs(
|
||||
scale, {
|
||||
**extra_weight_attrs,
|
||||
"fp8_scales_shard_indexer":
|
||||
self.scales_shard_indexer,
|
||||
})
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -81,46 +116,150 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del input_size, output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
layer.process_after_load = True
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight_dtype = (torch.float8_e4m3fn
|
||||
if self.quant_config.is_checkpoint_fp8_serialized else
|
||||
params_dtype)
|
||||
weight = Parameter(torch.empty(output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
dtype=weight_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
set_weight_attrs(weight, {
|
||||
**extra_weight_attrs,
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
})
|
||||
|
||||
w_scale = Parameter(
|
||||
torch.empty(1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("weight_scaling_factor", w_scale)
|
||||
# If checkpoint is serialized fp8, load them.
|
||||
# Otherwise, wait until process_weights_after_loading.
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# WEIGHT SCALE
|
||||
self._create_scale_param(
|
||||
scale_name="weight_scale",
|
||||
layer=layer,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
**extra_weight_attrs)
|
||||
|
||||
# ACTIVATION SCALE
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
self._create_scale_param(
|
||||
scale_name="act_scale",
|
||||
layer=layer,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
**extra_weight_attrs)
|
||||
|
||||
def scales_shard_indexer(
|
||||
self, param: torch.Tensor, loaded_weight: torch.Tensor,
|
||||
shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
qkv_idxs = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
if isinstance(shard_id, int):
|
||||
pass
|
||||
elif isinstance(shard_id, str):
|
||||
if shard_id not in qkv_idxs:
|
||||
raise ValueError(f"Unknown shard_id: {shard_id}")
|
||||
shard_id = qkv_idxs[shard_id]
|
||||
else:
|
||||
ValueError(f"Shard id must be int or str but got {type(shard_id)}")
|
||||
|
||||
return param[shard_id], loaded_weight
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# Although the quant_method is propagated to all layers,
|
||||
# only linear layers invoke "create_weights". So we check
|
||||
# whether "weight_scaling_facor" is registered to determine
|
||||
# whether the layer is a linear layer that requires quantization.
|
||||
if not hasattr(layer, "weight_scaling_factor"):
|
||||
if (not hasattr(layer, "process_after_load")
|
||||
or not layer.process_after_load):
|
||||
return
|
||||
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
|
||||
# torch._scaled_mm requires column-major in the second
|
||||
# input (weight), so we transpose the quantized weight.
|
||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||
layer.weight_scaling_factor.data.copy_(weight_scale)
|
||||
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
|
||||
scale=None)
|
||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.logical_widths = None
|
||||
layer.act_scale = None
|
||||
return
|
||||
|
||||
# If checkpoint is fp8, requantize the separately quantized logical
|
||||
# weights into a single fp8 weight with a single weight scale.
|
||||
else:
|
||||
# WEIGHT_SCALE / WEIGHT
|
||||
# Loop over logical weights, requantizing with single scale.
|
||||
max_w_scale = layer.weight_scale.max()
|
||||
start = 0
|
||||
for idx, logical_width in enumerate(layer.logical_widths):
|
||||
end = start + logical_width
|
||||
weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
|
||||
layer.weight_scale[idx])
|
||||
|
||||
layer.weight[start:end, :] = per_tensor_quantize(
|
||||
weight_dq, layer.weight_scale.max())
|
||||
start = end
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
|
||||
# WEIGHT
|
||||
# Transpose weight for passing to torch._scaled_mm
|
||||
weight = layer.weight
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
|
||||
# ACT_SCALE
|
||||
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
|
||||
# Static: set to max of the act_scales (since they are equal).
|
||||
if self.quant_config.activation_scheme == "dynamic":
|
||||
layer.act_scale = None
|
||||
elif self.quant_config.activation_scheme == "static":
|
||||
if not all_close_1d(layer.act_scale):
|
||||
raise ValueError(
|
||||
"All the act_scales for the logical weights of a layer "
|
||||
f"must be equal. But got {layer.act_scale}")
|
||||
layer.act_scale = Parameter(layer.act_scale.max(),
|
||||
requires_grad=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown scheme {self.quant_config.activation_scheme}")
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x)
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.act_scale is None and x_scale computed from x.
|
||||
# If static, layer.act_scale is scalar and x_scale set to act_scale.
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
|
||||
|
||||
# Fused GEMM_DQ
|
||||
output, _ = torch._scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scaling_factor,
|
||||
scale_b=layer.weight_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def all_close_1d(x: torch.Tensor) -> bool:
|
||||
assert len(x.shape) == 1
|
||||
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
|
||||
|
||||
|
||||
def per_tensor_quantize(tensor: torch.Tensor,
|
||||
inv_scale: float) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return qweight.to(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def per_tensor_dequantize(tensor: torch.Tensor,
|
||||
inv_scale: float) -> torch.Tensor:
|
||||
fake_qweight = tensor.to(torch.float16)
|
||||
dq_weight = fake_qweight * inv_scale
|
||||
return dq_weight
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user