mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 05:47:05 +08:00
[Kernel] Support Fp8 Checkpoints (Dynamic + Static) (#4332)
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: mgoin <michael@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
parent
b31a1fb63c
commit
111815d482
90
tests/models/test_fp8.py
Normal file
90
tests/models/test_fp8.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
"""Tests fp8 models against ground truth generation
|
||||||
|
Note: these tests will only pass on L4 GPU.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||||
|
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||||
|
|
||||||
|
MAX_MODEL_LEN = 1024
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
"nm-testing/Meta-Llama-3-8B-Instruct-FP8",
|
||||||
|
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||||
|
]
|
||||||
|
|
||||||
|
EXPECTED_STRS_MAP = {
|
||||||
|
"nm-testing/Meta-Llama-3-8B-Instruct-FP8": [
|
||||||
|
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
|
||||||
|
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
|
||||||
|
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
|
||||||
|
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
|
||||||
|
'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
|
||||||
|
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
|
||||||
|
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
|
||||||
|
'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**'
|
||||||
|
],
|
||||||
|
"meta-llama/Meta-Llama-3-8B-Instruct": [
|
||||||
|
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
|
||||||
|
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
|
||||||
|
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
|
||||||
|
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
|
||||||
|
'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of',
|
||||||
|
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
|
||||||
|
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
|
||||||
|
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
capability = torch.cuda.get_device_capability()
|
||||||
|
capability = capability[0] * 10 + capability[1]
|
||||||
|
fp8_not_supported = (capability <
|
||||||
|
QUANTIZATION_METHODS["fp8"].get_min_capability())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(fp8_not_supported,
|
||||||
|
reason="fp8 is not supported on this GPU type.")
|
||||||
|
@pytest.mark.parametrize("model_name", MODELS)
|
||||||
|
def test_models(
|
||||||
|
example_prompts,
|
||||||
|
model_name,
|
||||||
|
) -> None:
|
||||||
|
model = LLM(model=model_name,
|
||||||
|
max_model_len=MAX_MODEL_LEN,
|
||||||
|
enforce_eager=True,
|
||||||
|
quantization="fp8")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
formatted_prompts = [
|
||||||
|
tokenizer.apply_chat_template([{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt
|
||||||
|
}],
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True)
|
||||||
|
for prompt in example_prompts
|
||||||
|
]
|
||||||
|
|
||||||
|
params = SamplingParams(max_tokens=20, temperature=0)
|
||||||
|
generations = []
|
||||||
|
# Note: these need to be run 1 at a time due to numerical precision,
|
||||||
|
# since the expected strs were generated this way.
|
||||||
|
for prompt in formatted_prompts:
|
||||||
|
outputs = model.generate(prompt, params)
|
||||||
|
generations.append(outputs[0].outputs[0].text)
|
||||||
|
del model
|
||||||
|
|
||||||
|
print(generations)
|
||||||
|
expected_strs = EXPECTED_STRS_MAP[model_name]
|
||||||
|
for i in range(len(example_prompts)):
|
||||||
|
generated_str = generations[i]
|
||||||
|
expected_str = expected_strs[i]
|
||||||
|
assert expected_str == generated_str, (
|
||||||
|
f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")
|
||||||
@ -246,6 +246,10 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||||
|
# Special case for Fp8 scales.
|
||||||
|
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||||
|
None)
|
||||||
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
output_dim = getattr(param, "output_dim", None)
|
output_dim = getattr(param, "output_dim", None)
|
||||||
param_data = param.data
|
param_data = param.data
|
||||||
@ -254,6 +258,12 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
start_idx = tp_rank * shard_size
|
start_idx = tp_rank * shard_size
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||||
shard_size)
|
shard_size)
|
||||||
|
# Special case for Fp8 scales.
|
||||||
|
elif fp8_scales_shard_indexer is not None:
|
||||||
|
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
|
||||||
|
loaded_weight,
|
||||||
|
shard_id=0)
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert param_data.shape == loaded_weight.shape
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
@ -317,7 +327,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
|
|
||||||
param_data = param.data
|
param_data = param.data
|
||||||
output_dim = getattr(param, "output_dim", None)
|
output_dim = getattr(param, "output_dim", None)
|
||||||
|
# Special case for AQLM codebooks.
|
||||||
is_metadata = getattr(param, "is_metadata", False)
|
is_metadata = getattr(param, "is_metadata", False)
|
||||||
|
# Special case for Fp8 scales.
|
||||||
|
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||||
|
None)
|
||||||
|
|
||||||
if loaded_shard_id is None:
|
if loaded_shard_id is None:
|
||||||
# Loaded weight is already packed.
|
# Loaded weight is already packed.
|
||||||
if output_dim is None:
|
if output_dim is None:
|
||||||
@ -331,14 +346,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
current_shard_offset += output_size
|
current_shard_offset += output_size
|
||||||
packed_dim = getattr(param, "packed_dim", None)
|
packed_dim = getattr(param, "packed_dim", None)
|
||||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||||
|
# Special case for Quantization.
|
||||||
# If quantized, we need to adjust the offset and size to account
|
# If quantized, we need to adjust the offset and size to account
|
||||||
# for the packing.
|
# for the packing.
|
||||||
if packed_dim == output_dim:
|
if packed_dim == output_dim:
|
||||||
shard_size = shard_size // param.pack_factor
|
shard_size = shard_size // param.pack_factor
|
||||||
shard_offset = shard_offset // param.pack_factor
|
shard_offset = shard_offset // param.pack_factor
|
||||||
|
# Special case for Marlin.
|
||||||
# If marlin, we need to adjust the offset and size to
|
|
||||||
# account for the tiling.
|
|
||||||
shard_size, shard_offset = adjust_marlin_shard(
|
shard_size, shard_offset = adjust_marlin_shard(
|
||||||
param, shard_size, shard_offset)
|
param, shard_size, shard_offset)
|
||||||
|
|
||||||
@ -353,15 +367,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
if output_dim is not None:
|
if output_dim is not None:
|
||||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
||||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
||||||
|
# Special case for quantization.
|
||||||
# If quantized, we need to adjust the offset and size to account
|
# If quantized, we need to adjust the offset and size to account
|
||||||
# for the packing.
|
# for the packing.
|
||||||
packed_dim = getattr(param, "packed_dim", None)
|
packed_dim = getattr(param, "packed_dim", None)
|
||||||
if packed_dim == output_dim:
|
if packed_dim == output_dim:
|
||||||
shard_size = shard_size // param.pack_factor
|
shard_size = shard_size // param.pack_factor
|
||||||
shard_offset = shard_offset // param.pack_factor
|
shard_offset = shard_offset // param.pack_factor
|
||||||
|
# Special case for Marlin.
|
||||||
# If marlin, we need to adjust the offset and size to
|
|
||||||
# account for the tiling.
|
|
||||||
shard_size, shard_offset = adjust_marlin_shard(
|
shard_size, shard_offset = adjust_marlin_shard(
|
||||||
param, shard_size, shard_offset)
|
param, shard_size, shard_offset)
|
||||||
|
|
||||||
@ -370,11 +383,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
start_idx = tp_rank * shard_size
|
start_idx = tp_rank * shard_size
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||||
shard_size)
|
shard_size)
|
||||||
|
# Special case for AQLM codebooks.
|
||||||
elif is_metadata:
|
elif is_metadata:
|
||||||
# metadata indicates fixed size concatenated along dim 0
|
# metadata indicates fixed size concatenated along dim 0
|
||||||
shard_size = loaded_weight.shape[0]
|
shard_size = loaded_weight.shape[0]
|
||||||
shard_offset = loaded_shard_id * shard_size
|
shard_offset = loaded_shard_id * shard_size
|
||||||
param_data = param_data.narrow(0, shard_offset, shard_size)
|
param_data = param_data.narrow(0, shard_offset, shard_size)
|
||||||
|
# Special case for Fp8 scales.
|
||||||
|
elif fp8_scales_shard_indexer is not None:
|
||||||
|
param_data, loaded_weight = fp8_scales_shard_indexer(
|
||||||
|
param_data, loaded_weight, loaded_shard_id)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
ignore_warning = getattr(param, "ignore_warning", False)
|
ignore_warning = getattr(param, "ignore_warning", False)
|
||||||
if not ignore_warning:
|
if not ignore_warning:
|
||||||
@ -455,7 +474,11 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
loaded_shard_id: Optional[str] = None):
|
loaded_shard_id: Optional[str] = None):
|
||||||
param_data = param.data
|
param_data = param.data
|
||||||
output_dim = getattr(param, "output_dim", None)
|
output_dim = getattr(param, "output_dim", None)
|
||||||
|
# Special case for AQLM codebooks.
|
||||||
is_metadata = getattr(param, "is_metadata", False)
|
is_metadata = getattr(param, "is_metadata", False)
|
||||||
|
# Special case for Fp8 scales.
|
||||||
|
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||||
|
None)
|
||||||
|
|
||||||
if loaded_shard_id is None:
|
if loaded_shard_id is None:
|
||||||
# Loaded weight is already packed.
|
# Loaded weight is already packed.
|
||||||
@ -473,14 +496,14 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
]
|
]
|
||||||
packed_dim = getattr(param, "packed_dim", None)
|
packed_dim = getattr(param, "packed_dim", None)
|
||||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||||
|
# Special case for Quantized Weights.
|
||||||
# If quantized, we need to adjust the offset and size to account
|
# If quantized, we need to adjust the offset and size to account
|
||||||
# for the packing.
|
# for the packing.
|
||||||
if packed_dim == output_dim:
|
if packed_dim == output_dim:
|
||||||
shard_size = shard_size // param.pack_factor
|
shard_size = shard_size // param.pack_factor
|
||||||
shard_offset = shard_offset // param.pack_factor
|
shard_offset = shard_offset // param.pack_factor
|
||||||
|
|
||||||
# If marlin, we need to adjust the offset and size to
|
# Special case for Marlin.
|
||||||
# account for the tiling.
|
|
||||||
shard_size, shard_offset = adjust_marlin_shard(
|
shard_size, shard_offset = adjust_marlin_shard(
|
||||||
param, shard_size, shard_offset)
|
param, shard_size, shard_offset)
|
||||||
|
|
||||||
@ -502,6 +525,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
shard_offset = (self.num_heads +
|
shard_offset = (self.num_heads +
|
||||||
self.num_kv_heads) * self.head_size
|
self.num_kv_heads) * self.head_size
|
||||||
shard_size = self.num_kv_heads * self.head_size
|
shard_size = self.num_kv_heads * self.head_size
|
||||||
|
# Special case for Quantized Weights.
|
||||||
# If quantized, we need to adjust the offset and size to account
|
# If quantized, we need to adjust the offset and size to account
|
||||||
# for the packing.
|
# for the packing.
|
||||||
packed_dim = getattr(param, "packed_dim", None)
|
packed_dim = getattr(param, "packed_dim", None)
|
||||||
@ -509,8 +533,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
shard_size = shard_size // param.pack_factor
|
shard_size = shard_size // param.pack_factor
|
||||||
shard_offset = shard_offset // param.pack_factor
|
shard_offset = shard_offset // param.pack_factor
|
||||||
|
|
||||||
# If marlin, we need to adjust the offset and size to
|
# Special case for Marlin.
|
||||||
# account for the tiling.
|
|
||||||
shard_size, shard_offset = adjust_marlin_shard(
|
shard_size, shard_offset = adjust_marlin_shard(
|
||||||
param, shard_size, shard_offset)
|
param, shard_size, shard_offset)
|
||||||
|
|
||||||
@ -523,12 +546,17 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
start_idx = shard_id * shard_size
|
start_idx = shard_id * shard_size
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||||
shard_size)
|
shard_size)
|
||||||
|
# Special case for for AQLM codebooks.
|
||||||
elif is_metadata:
|
elif is_metadata:
|
||||||
# metadata indicates fixed size concatenated along dim 0
|
# metadata indicates fixed size concatenated along dim 0
|
||||||
shard_size = loaded_weight.shape[0]
|
shard_size = loaded_weight.shape[0]
|
||||||
shard_index = ["q", "k", "v"].index(loaded_shard_id)
|
shard_index = ["q", "k", "v"].index(loaded_shard_id)
|
||||||
param_data = param_data.narrow(0, shard_index * shard_size,
|
param_data = param_data.narrow(0, shard_index * shard_size,
|
||||||
shard_size)
|
shard_size)
|
||||||
|
# Special case for Fp8 scales.
|
||||||
|
elif fp8_scales_shard_indexer is not None:
|
||||||
|
param_data, loaded_weight = fp8_scales_shard_indexer(
|
||||||
|
param_data, loaded_weight, loaded_shard_id)
|
||||||
else:
|
else:
|
||||||
ignore_warning = getattr(param, "ignore_warning", False)
|
ignore_warning = getattr(param, "ignore_warning", False)
|
||||||
if not ignore_warning:
|
if not ignore_warning:
|
||||||
@ -611,6 +639,10 @@ class RowParallelLinear(LinearBase):
|
|||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||||
|
# Special case for Fp8 scales.
|
||||||
|
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||||
|
None)
|
||||||
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
input_dim = getattr(param, "input_dim", None)
|
input_dim = getattr(param, "input_dim", None)
|
||||||
param_data = param.data
|
param_data = param.data
|
||||||
@ -619,6 +651,12 @@ class RowParallelLinear(LinearBase):
|
|||||||
start_idx = tp_rank * shard_size
|
start_idx = tp_rank * shard_size
|
||||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
|
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
|
||||||
shard_size)
|
shard_size)
|
||||||
|
# Special case for Fp8 scales.
|
||||||
|
elif fp8_scales_shard_indexer is not None:
|
||||||
|
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
|
||||||
|
loaded_weight,
|
||||||
|
shard_id=0)
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert param_data.shape == loaded_weight.shape
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|||||||
@ -1,23 +1,36 @@
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig, QuantizeMethodBase)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
|
||||||
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Fp8Config(QuantizationConfig):
|
class Fp8Config(QuantizationConfig):
|
||||||
"""Config class for FP8."""
|
"""Config class for FP8."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
is_checkpoint_fp8_serialized: bool = False,
|
||||||
activation_scheme: str = "dynamic",
|
activation_scheme: str = "dynamic",
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||||
|
if is_checkpoint_fp8_serialized:
|
||||||
|
logger.warning("Detected fp8 checkpoint. Please note that the "
|
||||||
|
"format is experimental and subject to change.")
|
||||||
|
if activation_scheme not in ACTIVATION_SCHEMES:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported activation scheme {activation_scheme}")
|
||||||
self.activation_scheme = activation_scheme
|
self.activation_scheme = activation_scheme
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -30,10 +43,7 @@ class Fp8Config(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
# TODO: PyTorch 2.3.0+ is required to run FP8 on
|
return 89
|
||||||
# SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
|
|
||||||
# be included: https://github.com/pytorch/pytorch/pull/118881
|
|
||||||
return 90
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_filenames(cls) -> List[str]:
|
def get_config_filenames(cls) -> List[str]:
|
||||||
@ -41,11 +51,14 @@ class Fp8Config(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
|
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
|
||||||
|
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||||
|
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
|
||||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||||
return cls(activation_scheme)
|
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||||
|
activation_scheme=activation_scheme)
|
||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
|
self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
return Fp8LinearMethod(self)
|
return Fp8LinearMethod(self)
|
||||||
return None
|
return None
|
||||||
@ -56,8 +69,12 @@ class Fp8Config(QuantizationConfig):
|
|||||||
|
|
||||||
class Fp8LinearMethod(LinearMethodBase):
|
class Fp8LinearMethod(LinearMethodBase):
|
||||||
"""Linear method for FP8.
|
"""Linear method for FP8.
|
||||||
We now support common FP16/BF16 model checkpoints ONLY. The weight
|
Supports loading FP8 checkpoints with static weight scale and
|
||||||
scaling factor will be initialized after the model weights are loaded.
|
dynamic/static activation scale.
|
||||||
|
|
||||||
|
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||||
|
activation scaling. The weight scaling factor will be initialized after
|
||||||
|
the model weights are loaded.
|
||||||
|
|
||||||
Limitations:
|
Limitations:
|
||||||
1. Only support per-tensor quantization due to torch._scaled_mm support.
|
1. Only support per-tensor quantization due to torch._scaled_mm support.
|
||||||
@ -71,6 +88,24 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
def __init__(self, quant_config: Fp8Config):
|
def __init__(self, quant_config: Fp8Config):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
def _create_scale_param(
|
||||||
|
self,
|
||||||
|
scale_name: str,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
output_partition_sizes: List[int],
|
||||||
|
**extra_weight_attrs,
|
||||||
|
) -> None:
|
||||||
|
scale = Parameter(torch.empty(len(output_partition_sizes),
|
||||||
|
dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter(scale_name, scale)
|
||||||
|
set_weight_attrs(
|
||||||
|
scale, {
|
||||||
|
**extra_weight_attrs,
|
||||||
|
"fp8_scales_shard_indexer":
|
||||||
|
self.scales_shard_indexer,
|
||||||
|
})
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -81,46 +116,150 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
|
del input_size, output_size
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
|
|
||||||
|
layer.process_after_load = True
|
||||||
|
layer.logical_widths = output_partition_sizes
|
||||||
|
|
||||||
|
# WEIGHT
|
||||||
|
weight_dtype = (torch.float8_e4m3fn
|
||||||
|
if self.quant_config.is_checkpoint_fp8_serialized else
|
||||||
|
params_dtype)
|
||||||
weight = Parameter(torch.empty(output_size_per_partition,
|
weight = Parameter(torch.empty(output_size_per_partition,
|
||||||
input_size_per_partition,
|
input_size_per_partition,
|
||||||
dtype=params_dtype),
|
dtype=weight_dtype),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.register_parameter("weight", weight)
|
layer.register_parameter("weight", weight)
|
||||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
set_weight_attrs(weight, {
|
||||||
set_weight_attrs(weight, extra_weight_attrs)
|
**extra_weight_attrs,
|
||||||
|
"input_dim": 1,
|
||||||
|
"output_dim": 0,
|
||||||
|
})
|
||||||
|
|
||||||
w_scale = Parameter(
|
# If checkpoint is serialized fp8, load them.
|
||||||
torch.empty(1, dtype=torch.float32),
|
# Otherwise, wait until process_weights_after_loading.
|
||||||
requires_grad=False,
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
)
|
# WEIGHT SCALE
|
||||||
layer.register_parameter("weight_scaling_factor", w_scale)
|
self._create_scale_param(
|
||||||
|
scale_name="weight_scale",
|
||||||
|
layer=layer,
|
||||||
|
output_partition_sizes=output_partition_sizes,
|
||||||
|
**extra_weight_attrs)
|
||||||
|
|
||||||
|
# ACTIVATION SCALE
|
||||||
|
if self.quant_config.activation_scheme == "static":
|
||||||
|
self._create_scale_param(
|
||||||
|
scale_name="act_scale",
|
||||||
|
layer=layer,
|
||||||
|
output_partition_sizes=output_partition_sizes,
|
||||||
|
**extra_weight_attrs)
|
||||||
|
|
||||||
|
def scales_shard_indexer(
|
||||||
|
self, param: torch.Tensor, loaded_weight: torch.Tensor,
|
||||||
|
shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
qkv_idxs = {"q": 0, "k": 1, "v": 2}
|
||||||
|
|
||||||
|
if isinstance(shard_id, int):
|
||||||
|
pass
|
||||||
|
elif isinstance(shard_id, str):
|
||||||
|
if shard_id not in qkv_idxs:
|
||||||
|
raise ValueError(f"Unknown shard_id: {shard_id}")
|
||||||
|
shard_id = qkv_idxs[shard_id]
|
||||||
|
else:
|
||||||
|
ValueError(f"Shard id must be int or str but got {type(shard_id)}")
|
||||||
|
|
||||||
|
return param[shard_id], loaded_weight
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
# Although the quant_method is propagated to all layers,
|
if (not hasattr(layer, "process_after_load")
|
||||||
# only linear layers invoke "create_weights". So we check
|
or not layer.process_after_load):
|
||||||
# whether "weight_scaling_facor" is registered to determine
|
|
||||||
# whether the layer is a linear layer that requires quantization.
|
|
||||||
if not hasattr(layer, "weight_scaling_factor"):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
|
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
|
||||||
# torch._scaled_mm requires column-major in the second
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
# input (weight), so we transpose the quantized weight.
|
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
|
||||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
scale=None)
|
||||||
layer.weight_scaling_factor.data.copy_(weight_scale)
|
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||||
|
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||||
|
layer.logical_widths = None
|
||||||
|
layer.act_scale = None
|
||||||
|
return
|
||||||
|
|
||||||
|
# If checkpoint is fp8, requantize the separately quantized logical
|
||||||
|
# weights into a single fp8 weight with a single weight scale.
|
||||||
|
else:
|
||||||
|
# WEIGHT_SCALE / WEIGHT
|
||||||
|
# Loop over logical weights, requantizing with single scale.
|
||||||
|
max_w_scale = layer.weight_scale.max()
|
||||||
|
start = 0
|
||||||
|
for idx, logical_width in enumerate(layer.logical_widths):
|
||||||
|
end = start + logical_width
|
||||||
|
weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
|
||||||
|
layer.weight_scale[idx])
|
||||||
|
|
||||||
|
layer.weight[start:end, :] = per_tensor_quantize(
|
||||||
|
weight_dq, layer.weight_scale.max())
|
||||||
|
start = end
|
||||||
|
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||||
|
|
||||||
|
# WEIGHT
|
||||||
|
# Transpose weight for passing to torch._scaled_mm
|
||||||
|
weight = layer.weight
|
||||||
|
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||||
|
|
||||||
|
# ACT_SCALE
|
||||||
|
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
|
||||||
|
# Static: set to max of the act_scales (since they are equal).
|
||||||
|
if self.quant_config.activation_scheme == "dynamic":
|
||||||
|
layer.act_scale = None
|
||||||
|
elif self.quant_config.activation_scheme == "static":
|
||||||
|
if not all_close_1d(layer.act_scale):
|
||||||
|
raise ValueError(
|
||||||
|
"All the act_scales for the logical weights of a layer "
|
||||||
|
f"must be equal. But got {layer.act_scale}")
|
||||||
|
layer.act_scale = Parameter(layer.act_scale.max(),
|
||||||
|
requires_grad=False)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown scheme {self.quant_config.activation_scheme}")
|
||||||
|
|
||||||
def apply(self,
|
def apply(self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
qinput, x_scale = ops.scaled_fp8_quant(x)
|
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||||
|
# If dynamic, layer.act_scale is None and x_scale computed from x.
|
||||||
|
# If static, layer.act_scale is scalar and x_scale set to act_scale.
|
||||||
|
qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
|
||||||
|
|
||||||
|
# Fused GEMM_DQ
|
||||||
output, _ = torch._scaled_mm(
|
output, _ = torch._scaled_mm(
|
||||||
qinput,
|
qinput,
|
||||||
layer.weight,
|
layer.weight,
|
||||||
out_dtype=x.dtype,
|
out_dtype=x.dtype,
|
||||||
scale_a=x_scale,
|
scale_a=x_scale,
|
||||||
scale_b=layer.weight_scaling_factor,
|
scale_b=layer.weight_scale,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def all_close_1d(x: torch.Tensor) -> bool:
|
||||||
|
assert len(x.shape) == 1
|
||||||
|
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def per_tensor_quantize(tensor: torch.Tensor,
|
||||||
|
inv_scale: float) -> torch.Tensor:
|
||||||
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
|
||||||
|
return qweight.to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
|
||||||
|
def per_tensor_dequantize(tensor: torch.Tensor,
|
||||||
|
inv_scale: float) -> torch.Tensor:
|
||||||
|
fake_qweight = tensor.to(torch.float16)
|
||||||
|
dq_weight = fake_qweight * inv_scale
|
||||||
|
return dq_weight
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user