mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 06:25:01 +08:00
[Feature]: Support NVIDIA ModelOpt HF FP8 variants FP8_PER_CHANNEL_PER_TOKEN and FP8_PB_WO in vLLM (#30957)
This commit is contained in:
parent
097978a15d
commit
19cc9468fd
@ -8,6 +8,16 @@ We recommend installing the library with:
|
||||
pip install nvidia-modelopt
|
||||
```
|
||||
|
||||
## Supported ModelOpt checkpoint formats
|
||||
|
||||
vLLM detects ModelOpt checkpoints via `hf_quant_config.json` and supports the
|
||||
following `quantization.quant_algo` values:
|
||||
|
||||
- `FP8`: per-tensor weight scale (+ optional static activation scale).
|
||||
- `FP8_PER_CHANNEL_PER_TOKEN`: per-channel weight scale and dynamic per-token activation quantization.
|
||||
- `FP8_PB_WO` (ModelOpt may emit `fp8_pb_wo`): block-scaled FP8 weight-only (typically 128×128 blocks).
|
||||
- `NVFP4`: ModelOpt NVFP4 checkpoints (use `quantization="modelopt_fp4"`).
|
||||
|
||||
## Quantizing HuggingFace Models with PTQ
|
||||
|
||||
You can quantize HuggingFace models using the example scripts provided in the Model Optimizer repository. The primary script for LLM PTQ is typically found within the `examples/llm_ptq` directory.
|
||||
@ -80,3 +90,24 @@ The quantized checkpoint can then be deployed with vLLM. As an example, the foll
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
## Running the OpenAI-compatible server
|
||||
|
||||
To serve a local ModelOpt checkpoint via the OpenAI-compatible API:
|
||||
|
||||
```bash
|
||||
vllm serve <path_to_exported_checkpoint> \
|
||||
--quantization modelopt \
|
||||
--host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
## Testing (local checkpoints)
|
||||
|
||||
vLLM's ModelOpt unit tests are gated by local checkpoint paths and are skipped
|
||||
by default in CI. To run the tests locally:
|
||||
|
||||
```bash
|
||||
export VLLM_TEST_MODELOPT_FP8_PC_PT_MODEL_PATH=<path_to_fp8_pc_pt_checkpoint>
|
||||
export VLLM_TEST_MODELOPT_FP8_PB_WO_MODEL_PATH=<path_to_fp8_pb_wo_checkpoint>
|
||||
pytest -q tests/quantization/test_modelopt.py
|
||||
```
|
||||
|
||||
@ -6,6 +6,7 @@ Run `pytest tests/quantization/test_modelopt.py`.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import NoReturn
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -19,6 +20,28 @@ def enable_pickle(monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
def _skip(msg: str) -> NoReturn:
|
||||
pytest.skip(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
def _snapshot_download_or_skip(model_id: str) -> str:
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
except Exception as e: # pragma: no cover
|
||||
_skip(f"huggingface_hub is required to download {model_id}: {e}")
|
||||
|
||||
try:
|
||||
return snapshot_download(
|
||||
repo_id=model_id,
|
||||
repo_type="model",
|
||||
# These checkpoints are already small; download full repo for simplicity.
|
||||
allow_patterns=["*"],
|
||||
)
|
||||
except Exception as e:
|
||||
_skip(f"Failed to download {model_id} from the HF Hub: {e}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("modelopt"),
|
||||
reason="ModelOpt FP8 is not supported on this GPU type.",
|
||||
@ -91,3 +114,121 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
|
||||
output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
|
||||
assert output
|
||||
print(f"ModelOpt FP8 output: {output}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("modelopt"),
|
||||
reason="ModelOpt FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
|
||||
"""Test ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoint setup."""
|
||||
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pc-pt"
|
||||
model_path = _snapshot_download_or_skip(model_id)
|
||||
|
||||
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
o_proj = layer.self_attn.o_proj
|
||||
gate_up_proj = layer.mlp.gate_up_proj
|
||||
down_proj = layer.mlp.down_proj
|
||||
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptFp8PcPtLinearMethod,
|
||||
)
|
||||
|
||||
assert isinstance(qkv_proj.quant_method, ModelOptFp8PcPtLinearMethod)
|
||||
assert isinstance(o_proj.quant_method, ModelOptFp8PcPtLinearMethod)
|
||||
assert isinstance(gate_up_proj.quant_method, ModelOptFp8PcPtLinearMethod)
|
||||
assert isinstance(down_proj.quant_method, ModelOptFp8PcPtLinearMethod)
|
||||
|
||||
assert qkv_proj.weight.dtype == torch.float8_e4m3fn
|
||||
assert o_proj.weight.dtype == torch.float8_e4m3fn
|
||||
assert gate_up_proj.weight.dtype == torch.float8_e4m3fn
|
||||
assert down_proj.weight.dtype == torch.float8_e4m3fn
|
||||
|
||||
# Per-channel scales; activations are dynamically scaled per token.
|
||||
assert hasattr(qkv_proj, "weight_scale")
|
||||
assert qkv_proj.weight_scale.dtype == torch.float32
|
||||
assert qkv_proj.weight_scale.dim() == 1
|
||||
assert not hasattr(qkv_proj, "input_scale")
|
||||
|
||||
assert hasattr(o_proj, "weight_scale")
|
||||
assert o_proj.weight_scale.dtype == torch.float32
|
||||
assert o_proj.weight_scale.dim() == 1
|
||||
assert not hasattr(o_proj, "input_scale")
|
||||
|
||||
assert hasattr(gate_up_proj, "weight_scale")
|
||||
assert gate_up_proj.weight_scale.dtype == torch.float32
|
||||
assert gate_up_proj.weight_scale.dim() == 1
|
||||
assert not hasattr(gate_up_proj, "input_scale")
|
||||
|
||||
assert hasattr(down_proj, "weight_scale")
|
||||
assert down_proj.weight_scale.dtype == torch.float32
|
||||
assert down_proj.weight_scale.dim() == 1
|
||||
assert not hasattr(down_proj, "input_scale")
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
|
||||
assert output
|
||||
print(f"ModelOpt FP8_PER_CHANNEL_PER_TOKEN output: {output}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("modelopt"),
|
||||
reason="ModelOpt FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_modelopt_fp8_pb_wo_checkpoint_setup(vllm_runner):
|
||||
"""Test ModelOpt FP8_PB_WO checkpoint setup."""
|
||||
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pb-wo"
|
||||
model_path = _snapshot_download_or_skip(model_id)
|
||||
|
||||
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
o_proj = layer.self_attn.o_proj
|
||||
gate_up_proj = layer.mlp.gate_up_proj
|
||||
down_proj = layer.mlp.down_proj
|
||||
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptFp8PbWoLinearMethod,
|
||||
)
|
||||
|
||||
assert isinstance(qkv_proj.quant_method, ModelOptFp8PbWoLinearMethod)
|
||||
assert isinstance(o_proj.quant_method, ModelOptFp8PbWoLinearMethod)
|
||||
assert isinstance(gate_up_proj.quant_method, ModelOptFp8PbWoLinearMethod)
|
||||
assert isinstance(down_proj.quant_method, ModelOptFp8PbWoLinearMethod)
|
||||
|
||||
assert qkv_proj.weight.dtype == torch.float8_e4m3fn
|
||||
assert o_proj.weight.dtype == torch.float8_e4m3fn
|
||||
assert gate_up_proj.weight.dtype == torch.float8_e4m3fn
|
||||
assert down_proj.weight.dtype == torch.float8_e4m3fn
|
||||
|
||||
# Block scales; should be materialized as a 2D [out_blk, in_blk] tensor.
|
||||
assert hasattr(qkv_proj, "weight_scale")
|
||||
assert qkv_proj.weight_scale.dtype == torch.float32
|
||||
assert qkv_proj.weight_scale.dim() == 2
|
||||
|
||||
assert hasattr(o_proj, "weight_scale")
|
||||
assert o_proj.weight_scale.dtype == torch.float32
|
||||
assert o_proj.weight_scale.dim() == 2
|
||||
|
||||
assert hasattr(gate_up_proj, "weight_scale")
|
||||
assert gate_up_proj.weight_scale.dtype == torch.float32
|
||||
assert gate_up_proj.weight_scale.dim() == 2
|
||||
|
||||
assert hasattr(down_proj, "weight_scale")
|
||||
assert down_proj.weight_scale.dtype == torch.float32
|
||||
assert down_proj.weight_scale.dim() == 2
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
|
||||
assert output
|
||||
print(f"ModelOpt FP8_PB_WO output: {output}")
|
||||
|
||||
@ -843,12 +843,18 @@ class ModelConfig:
|
||||
producer_name = quant_cfg.get("producer", {}).get("name")
|
||||
if producer_name == "modelopt":
|
||||
quant_algo = quant_cfg.get("quantization", {}).get("quant_algo")
|
||||
if quant_algo == "FP8":
|
||||
quant_cfg["quant_method"] = "modelopt"
|
||||
elif quant_algo == "NVFP4":
|
||||
quant_cfg["quant_method"] = "modelopt_fp4"
|
||||
elif quant_algo is not None:
|
||||
raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}")
|
||||
if quant_algo is not None:
|
||||
quant_algo_upper = str(quant_algo).upper()
|
||||
if quant_algo_upper in {
|
||||
"FP8",
|
||||
"FP8_PER_CHANNEL_PER_TOKEN",
|
||||
"FP8_PB_WO",
|
||||
}:
|
||||
quant_cfg["quant_method"] = "modelopt"
|
||||
elif quant_algo_upper == "NVFP4":
|
||||
quant_cfg["quant_method"] = "modelopt_fp4"
|
||||
else:
|
||||
raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}")
|
||||
|
||||
return quant_cfg
|
||||
|
||||
|
||||
@ -53,6 +53,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
"GPTQLinearMethod",
|
||||
"FBGEMMFp8LinearMethod",
|
||||
"ModelOptFp8LinearMethod",
|
||||
"ModelOptFp8PcPtLinearMethod",
|
||||
"ModelOptFp8PbWoLinearMethod",
|
||||
"IPEXAWQLinearMethod",
|
||||
"IPEXGPTQLinearMethod",
|
||||
"HQQMarlinMethod",
|
||||
|
||||
@ -55,6 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
select_cutlass_fp8_gemm_impl,
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
get_marlin_input_dtype,
|
||||
)
|
||||
@ -72,9 +75,15 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
cutlass_block_fp8_supported,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_scaled_fp4_mm,
|
||||
@ -88,7 +97,16 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
QUANT_ALGOS = ["FP8", "NVFP4"]
|
||||
QUANT_ALGOS = [
|
||||
# FP8 (per-tensor weight + optional static activation scale).
|
||||
"FP8",
|
||||
# FP8 per-channel weight scale + per-token activation scale.
|
||||
"FP8_PER_CHANNEL_PER_TOKEN",
|
||||
# FP8 per-block weight-only (ModelOpt may emit this as lowercase).
|
||||
"FP8_PB_WO",
|
||||
# FP4
|
||||
"NVFP4",
|
||||
]
|
||||
KV_CACHE_QUANT_ALGOS = ["FP8"]
|
||||
|
||||
|
||||
@ -255,6 +273,9 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
if not quant_method:
|
||||
raise ValueError("Missing 'quant_algo' in quantization config")
|
||||
|
||||
# Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
|
||||
quant_method = str(quant_method).upper()
|
||||
|
||||
if kv_cache_quant_method is None:
|
||||
# No KV cache quantization, keep this branch just to have this comment
|
||||
pass
|
||||
@ -263,6 +284,8 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
f"kv_cache_quant_algo must be a string, got "
|
||||
f"{type(kv_cache_quant_method)}"
|
||||
)
|
||||
else:
|
||||
kv_cache_quant_method = kv_cache_quant_method.upper()
|
||||
|
||||
if not isinstance(exclude_modules, list):
|
||||
raise ValueError(
|
||||
@ -302,17 +325,34 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_method: str,
|
||||
is_checkpoint_fp8_serialized: bool,
|
||||
kv_cache_quant_method: str | None,
|
||||
exclude_modules: list[str],
|
||||
) -> None:
|
||||
super().__init__(exclude_modules)
|
||||
self.quant_method = quant_method
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
self.kv_cache_quant_method = kv_cache_quant_method
|
||||
if is_checkpoint_fp8_serialized:
|
||||
logger.warning(
|
||||
"Detected ModelOpt fp8 checkpoint. Please note that"
|
||||
" the format is experimental and could change."
|
||||
"Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
|
||||
"that the format is experimental and could change.",
|
||||
quant_method,
|
||||
)
|
||||
|
||||
# Select LinearMethod implementation based on quant_algo.
|
||||
if self.quant_method == "FP8":
|
||||
self.LinearMethodCls = ModelOptFp8LinearMethod
|
||||
elif self.quant_method == "FP8_PER_CHANNEL_PER_TOKEN":
|
||||
self.LinearMethodCls = ModelOptFp8PcPtLinearMethod
|
||||
elif self.quant_method == "FP8_PB_WO":
|
||||
self.LinearMethodCls = ModelOptFp8PbWoLinearMethod
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported ModelOpt FP8 quant_algo for vLLM: "
|
||||
f"{self.quant_method}. Supported: FP8 / "
|
||||
"FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
|
||||
)
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
@ -346,13 +386,13 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
|
||||
if "quantization" in hf_quant_cfg:
|
||||
quant_config = hf_quant_cfg["quantization"]
|
||||
if isinstance(quant_config, dict):
|
||||
quant_algo = quant_config.get("quant_algo", "")
|
||||
if "FP8" in quant_algo:
|
||||
quant_algo = str(quant_config.get("quant_algo", ""))
|
||||
if "FP8" in quant_algo.upper():
|
||||
return "modelopt"
|
||||
else:
|
||||
# Check for compressed-tensors style config with specific quant_algo
|
||||
quant_algo = hf_quant_cfg.get("quant_algo", "")
|
||||
if isinstance(quant_algo, str) and "FP8" in quant_algo:
|
||||
quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
|
||||
if "FP8" in quant_algo.upper():
|
||||
return "modelopt"
|
||||
|
||||
return None
|
||||
@ -369,7 +409,12 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
|
||||
) -> "ModelOptFp8Config":
|
||||
is_checkpoint_fp8_serialized = "FP8" in quant_method
|
||||
|
||||
return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)
|
||||
return cls(
|
||||
quant_method,
|
||||
is_checkpoint_fp8_serialized,
|
||||
kv_cache_quant_method,
|
||||
exclude_modules,
|
||||
)
|
||||
|
||||
|
||||
class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
@ -464,6 +509,203 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
|
||||
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
|
||||
"""Linear method for ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoints.
|
||||
|
||||
Expected checkpoint structure (per Linear):
|
||||
- weight: fp8-e4m3fn, shape [out, in]
|
||||
- weight_scale: fp32, shape [out] (per-output-channel)
|
||||
- no input_scale (activations are dynamically quantized per-token)
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del input_size, output_size
|
||||
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
raise ValueError(
|
||||
"FP8_PER_CHANNEL_PER_TOKEN currently only supports "
|
||||
"FP8-serialized checkpoints."
|
||||
)
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty(output_size_per_partition, dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=None,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
|
||||
class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
|
||||
"""Linear method for ModelOpt FP8_PB_WO checkpoints.
|
||||
|
||||
ModelOpt exports `weight_scale` as a 4D tensor:
|
||||
[out_blk, 1, in_blk, 1]
|
||||
where block size is typically 128 for both dims.
|
||||
|
||||
vLLM executes it as FP8 GEMM with *dynamic per-token* activation quant.
|
||||
"""
|
||||
|
||||
_WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
block_n, block_k = self._WEIGHT_BLOCK_SIZE
|
||||
self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(block_n, block_k),
|
||||
act_quant_group_shape=GroupShape(1, block_k),
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
|
||||
use_aiter_and_is_supported=False,
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del input_size, output_size
|
||||
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
raise ValueError(
|
||||
"FP8_PB_WO currently only supports FP8-serialized checkpoints."
|
||||
)
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
|
||||
# Expose block size so the v2 weight loaders can translate offsets from
|
||||
# element-space -> block-space for BlockQuantScaleParameter.
|
||||
layer.weight_block_size = self.weight_block_size
|
||||
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
block_n, block_k = self._WEIGHT_BLOCK_SIZE
|
||||
if output_size_per_partition % block_n != 0:
|
||||
raise ValueError(
|
||||
"ModelOpt FP8_PB_WO requires out_features divisible by "
|
||||
f"{block_n}, got {output_size_per_partition}."
|
||||
)
|
||||
if input_size_per_partition % block_k != 0:
|
||||
raise ValueError(
|
||||
"ModelOpt FP8_PB_WO requires in_features divisible by "
|
||||
f"{block_k}, got {input_size_per_partition}."
|
||||
)
|
||||
|
||||
out_blks = output_size_per_partition // block_n
|
||||
in_blks = input_size_per_partition // block_k
|
||||
|
||||
# Match ModelOpt's exported shape so weight loading works without a
|
||||
# custom loader: [out_blk, 1, in_blk, 1]
|
||||
weight_scale = BlockQuantScaleParameter(
|
||||
data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32),
|
||||
input_dim=2,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
scale = layer.weight_scale
|
||||
if scale.dim() == 4:
|
||||
# [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
|
||||
scale = scale.squeeze(1).squeeze(-1)
|
||||
elif scale.dim() != 2:
|
||||
raise ValueError(
|
||||
"Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
|
||||
f"{tuple(scale.shape)}."
|
||||
)
|
||||
|
||||
layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=None,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
|
||||
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for ModelOpt FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user