[Kernel][FP8] Initial support with dynamic per-tensor scaling (#4118)
Provides initial support for FP8 computation. This PR is inspired by HuggingFace TGI: huggingface/text-generation-inference#1726. The feature can be enabled with --quantization fp8 or -q fp8 when launching an engine.

Algorithm: We still load a model checkpoint in FP16/BF16. After the weights are loaded, Fp8LinearMethod calculates the per-tensor scaling factor of the weights and quantizes the weights accordingly. The scaling factor is then stored for future use. Meanwhile, the per-tensor scaling factor for activations is calculated in every forward pass.

Initial results: Currently tested with Mistral-7B on 1xH100. With a prompt length of ~5 and a decoding length of 128: BF16: 1.47s, FP8: 1.66s. I'll try larger models and look for more performance bottlenecks. Meanwhile, you're welcome to try this code.
parent 682789d402
commit a22cdea371
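Below is a minimal usage sketch of the new option through the Python API (illustrative only: the model is the one used in the unit test below, and the prompt and sampling parameters are arbitrary; an SM 90 GPU and this commit of vLLM are assumed):

from vllm import LLM, SamplingParams

# Load an FP16 checkpoint and quantize its linear weights to FP8 on the fly.
llm = LLM(model="facebook/opt-125m", quantization="fp8")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)

When launching an engine or server, the equivalent flag is --quantization fp8 (or -q fp8), as noted in the commit message.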
tests/quantization/test_fp8.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""Tests whether FP8 computation is enabled correctly.

Run `pytest tests/quantization/test_fp8.py --forked`.
"""
import pytest
import torch

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]


@pytest.mark.skipif(
    capability < QUANTIZATION_METHODS["fp8"].get_min_capability(),
    reason="FP8 is not supported on this GPU type.")
def test_load_fp16_model(vllm_runner) -> None:
    llm = vllm_runner("facebook/opt-125m", quantization="fp8")

    model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
    fc1 = model.model.decoder.layers[0].fc1
    assert isinstance(fc1.linear_method, Fp8LinearMethod)
    assert fc1.weight.dtype == torch.float8_e4m3fn
@@ -42,10 +42,11 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         quantization: The method used to quantize the model weights. Currently,
-            we support "awq", "gptq" and "squeezellm". If None, we first check
-            the `quantization_config` attribute in the model config file. If
-            that is None, we assume the model weights are not quantized and use
-            `dtype` to determine the data type of the weights.
+            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
+            If None, we first check the `quantization_config` attribute in the
+            model config file. If that is None, we assume the model weights are
+            not quantized and use `dtype` to determine the data type of
+            the weights.
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id.
         tokenizer_revision: The specific tokenizer version to use. It can be a
@@ -3,6 +3,7 @@ from typing import List, Optional
 
 import torch
 import torch.nn.functional as F
+from torch import nn
 from torch.nn.parameter import Parameter
 
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -48,6 +49,13 @@ class LinearMethodBase(ABC):
         Expects create_weights to have been called before on the layer."""
         raise NotImplementedError
 
+    def process_weights_after_loading(self, layer: nn.Module) -> None:
+        """Process the weight after loading.
+
+        This can be used for example, to transpose weights for computation.
+        """
+        return
+
 
 class UnquantizedLinearMethod(LinearMethodBase):
     """Linear method without quantization.
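The new process_weights_after_loading hook gives every linear method a single place to post-process its registered parameters once the checkpoint has been read; the model loader change further down calls it on each module. A hypothetical illustration (not part of this commit; the subclass name and the bfloat16 cast are made up for the example):

import torch
from torch import nn
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import UnquantizedLinearMethod


class Bf16CastLinearMethod(UnquantizedLinearMethod):
    """Hypothetical example: cast the loaded weight to bfloat16 in one
    place instead of scattering dtype handling across the model code."""

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        layer.weight = Parameter(layer.weight.data.to(torch.bfloat16),
                                 requires_grad=False)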
@@ -3,12 +3,14 @@ from typing import Type
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.fp8 import FP8Config
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 
 QUANTIZATION_METHODS = {
     "awq": AWQConfig,
+    "fp8": FP8Config,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
     "marlin": MarlinConfig,
vllm/model_executor/layers/quantization/fp8.py (new file, 138 lines)
@@ -0,0 +1,138 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class FP8Config(QuantizationConfig):
    """Config class for FP8."""

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # TODO: PyTorch 2.3.0+ is required to run FP8 on
        # SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
        # be included: https://github.com/pytorch/pytorch/pull/118881
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
        return cls()

    def get_linear_method(self) -> "Fp8LinearMethod":
        return Fp8LinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    We now support common FP16/BF16 model checkpoints ONLY. The weight
    scaling factor will be initialized after the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: FP8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        set_weight_attrs(weight, extra_weight_attrs)

        w_scale = Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("weight_scaling_factor", w_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        # Although the linear_method is propagated to all layers,
        # only linear layers invoke "create_weights". So we check
        # whether "weight_scaling_factor" is registered to determine
        # whether the layer is a linear layer that requires quantization.
        if not hasattr(layer, "weight_scaling_factor"):
            return

        qweight, weight_scale = per_tensor_quantize(layer.weight)
        # torch._scaled_mm requires column-major in the second
        # input (weight), so we transpose the quantized weight.
        layer.weight = Parameter(qweight.t(), requires_grad=False)
        layer.weight_scaling_factor.data.copy_(weight_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qinput, x_scale = per_tensor_quantize(x)
        output, _ = torch._scaled_mm(
            qinput,
            layer.weight,
            out_dtype=x.dtype,
            scale_a=x_scale,
            scale_b=layer.weight_scaling_factor,
            bias=bias,
        )
        return output


def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]:
    """Quantize a tensor using per-tensor static scaling factor.

    Args:
        tensor: The input tensor.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Calculate the scale as dtype max divided by absmax.
    # Since .abs() creates a new tensor, we use aminmax to get
    # the min and max first and then calculate the absmax.
    min_val, max_val = tensor.aminmax()
    amax = min_val.abs().max(max_val.abs())
    scale = finfo.max / amax.clamp(min=1e-12)
    # Scale and clamp the tensor to bring it to
    # the representative range of the float8 data type
    # (as the default cast is unsaturated).
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(torch.float8_e4m3fn)
    scale = scale.float().reciprocal()
    return qweight, scale
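As a quick sanity check of what per_tensor_quantize produces, the round trip below dequantizes with the returned inverse scale and compares against the input (a minimal sketch: the tensor shape is arbitrary, this commit of vLLM is assumed to be installed, and PyTorch 2.1+ is assumed for the float8_e4m3fn dtype):

import torch

from vllm.model_executor.layers.quantization.fp8 import per_tensor_quantize

x = torch.randn(4, 8, dtype=torch.float32)
qx, inv_scale = per_tensor_quantize(x)
assert qx.dtype == torch.float8_e4m3fn
# Dequantize with the inverse scale; the result should match the input
# to within FP8 E4M3 rounding error.
x_hat = qx.to(torch.float32) * inv_scale
print((x - x_hat).abs().max())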
@@ -228,6 +228,10 @@ class DefaultModelLoader(BaseModelLoader):
                                                model,
                                                "fall_back_to_pt_during_load",
                                                True)), )
+            for _, module in model.named_modules():
+                linear_method = getattr(module, "linear_method", None)
+                if linear_method is not None:
+                    linear_method.process_weights_after_loading(module)
         return model.eval()
@@ -134,11 +134,18 @@ def get_quant_config(model_config: ModelConfig,
                                      tqdm_class=DisabledTqdm)
     else:
         hf_folder = model_name_or_path
+
+    possible_config_filenames = quant_cls.get_config_filenames()
+
+    # If the quantization config is not found, use the default config.
+    if not possible_config_filenames:
+        return quant_cls()
+
     config_files = glob.glob(os.path.join(hf_folder, "*.json"))
 
     quant_config_files = [
         f for f in config_files if any(
-            f.endswith(x) for x in quant_cls.get_config_filenames())
+            f.endswith(x) for x in possible_config_filenames)
     ]
     if len(quant_config_files) == 0:
         raise ValueError(
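Because FP8Config.get_config_filenames() returns an empty list, this early return means FP8 does not require a quantization config JSON next to the checkpoint; the default config is constructed instead. A small illustration of that lookup (a sketch; assumes this commit of vLLM is importable):

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

quant_cls = QUANTIZATION_METHODS["fp8"]
# FP8 ships no quantization config file, so get_quant_config falls back
# to the default config, i.e. quant_cls().
assert quant_cls.get_config_filenames() == []
quant_config = quant_cls()
print(quant_config.get_name())  # "fp8"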