Support RL online quantization with torchao (#23014)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>

parent 4134312b35
commit c31246800c
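For orientation, here is a minimal sketch of the workflow this commit enables, condensed from the tests below: serialize a torchao quantization config to a JSON string and pass it through `hf_overrides`, so vLLM quantizes a high-precision checkpoint on the fly at load time. Model name and prompt are taken from the tests; treat this as a usage sketch, not the only entry point.

import json

from torchao.core.config import config_to_dict
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig, PerRow)

from vllm import LLM

# Serialize an AOBaseConfig to the JSON string that get_quant_config
# (see weight_utils.py diff below) picks up from hf_overrides.
quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
hf_overrides = {
    "quantization_config_dict_json": json.dumps(config_to_dict(quant_config))
}

# The checkpoint stays high precision; weights are quantized on the fly.
llm = LLM(model="facebook/opt-125m",
          dtype="bfloat16",
          quantization="torchao",
          hf_overrides=hf_overrides)
print(llm.generate(["The capital of France is"]))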
@@ -20,7 +20,6 @@ def test_pre_quantized_model(vllm_runner):
         output = llm.generate_greedy(["The capital of France is"],
                                      max_tokens=32)
     assert output
-    print(output)
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@@ -42,7 +41,6 @@ def test_opt_125m_int8wo_model_loading_with_params(vllm_runner,
                                      max_tokens=32)
 
     assert output
-    print(output)
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@@ -57,7 +55,6 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
                                      max_tokens=32)
 
     assert output
-    print(output)
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@@ -72,7 +69,6 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
                                      max_tokens=32)
 
     assert output
-    print(output)
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@@ -92,7 +88,127 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
                                      max_tokens=32)
 
     assert output
-    print(output)
 
 
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_on_the_fly_quant_config_dict_json(vllm_runner):
+    """Testing on the fly quantization, load_weights integration point,
+    with config dict serialized to a json string
+    """
+    torch._dynamo.reset()
+    model_name = "facebook/opt-125m"
+
+    import json
+
+    from torchao.core.config import config_to_dict
+    from torchao.quantization import (
+        Float8DynamicActivationFloat8WeightConfig, PerRow)
+
+    torchao_quant_config = Float8DynamicActivationFloat8WeightConfig(
+        granularity=PerRow())
+    hf_overrides = {
+        "quantization_config_dict_json":
+        json.dumps(config_to_dict(torchao_quant_config))
+    }
+    with vllm_runner(model_name=model_name,
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0",
+                     quantization="torchao",
+                     hf_overrides=hf_overrides) as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_on_the_fly_quant_config_file(vllm_runner):
+    """Testing on the fly quantization, load_weights integration point,
+    with a config file
+    """
+    torch._dynamo.reset()
+    model_name = "facebook/opt-125m"
+    import json
+    from tempfile import NamedTemporaryFile
+
+    from torchao.core.config import config_to_dict
+    from torchao.quantization import (
+        Float8DynamicActivationFloat8WeightConfig, PerRow)
+
+    config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+
+    with NamedTemporaryFile(mode="w", delete=False) as f:
+        f.write(json.dumps(config_to_dict(config)))
+        # close the file to save it
+        f.close()
+        config_file_name = str(f.name)
+
+    hf_overrides = {"quantization_config_file": config_file_name}
+    with vllm_runner(model_name=model_name,
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0",
+                     quantization="torchao",
+                     hf_overrides=hf_overrides) as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_reload_weights():
+    import json
+
+    from torchao.core.config import config_to_dict
+    from torchao.quantization import (
+        Float8DynamicActivationFloat8WeightConfig, PerRow)
+
+    from vllm import LLM, SamplingParams
+
+    torchao_quant_config = Float8DynamicActivationFloat8WeightConfig(
+        granularity=PerRow())
+
+    hf_overrides = {
+        "quantization_config_dict_json":
+        json.dumps(config_to_dict(torchao_quant_config))
+    }
+
+    llm = LLM(
+        model="Qwen/Qwen3-0.6B",
+        dtype="bfloat16",
+        load_format="dummy",
+        enforce_eager=True,
+        quantization="torchao",
+        hf_overrides=hf_overrides,
+    )
+    # Update load format from `dummy` to `auto`
+    llm.collective_rpc("update_config",
+                       args=({
+                           "load_config": {
+                               "load_format": "auto"
+                           }
+                       }, ))
+    # Now reload real weights in place
+    llm.collective_rpc("reload_weights")
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    # Create a sampling params object.
+    sampling_params = SamplingParams(temperature=0, top_p=0.95)
+    outputs = llm.generate(prompts, sampling_params)
+    # make sure it runs
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        assert generated_text
+        # can also uncomment locally to make sure the generated
+        # output makes sense
+        # prompt = output.prompt
+        # print(f"Prompt: {prompt!r}")
+        # print(f"Output: {generated_text!r}")
+        # print("-" * 60)
+
+
 if __name__ == "__main__":
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json
 from typing import Any, Optional
 
 import torch
@@ -40,7 +41,8 @@ class TorchAOConfig(QuantizationConfig):
 
     def __init__(self,
                  torchao_config,
-                 skip_modules: Optional[list[str]] = None) -> None:
+                 skip_modules: Optional[list[str]] = None,
+                 is_checkpoint_torchao_serialized: bool = False) -> None:
         """
         # TorchAO quantization relies on tensor subclasses. In order,
         # to enable proper caching this needs standalone compile
@@ -58,9 +60,11 @@ class TorchAOConfig(QuantizationConfig):
         super().__init__()
         self.torchao_config = torchao_config
         self.skip_modules = skip_modules or []
+        self.is_checkpoint_torchao_serialized = is_checkpoint_torchao_serialized
 
     def __repr__(self) -> str:
-        return f"TorchAOConfig({self.torchao_config})"
+        return f"TorchAOConfig({self.torchao_config=}, {self.skip_modules=}, " \
+            f"{self.is_checkpoint_torchao_serialized=})"
 
     def get_name(self) -> QuantizationMethods:
         return "torchao"
@@ -74,7 +78,10 @@ class TorchAOConfig(QuantizationConfig):
 
     @staticmethod
     def get_config_filenames() -> list[str]:
-        return ["config.json"]
+        """torchao doesn't require additional config files, we use
+        `config.json` from huggingface: `model_config.hf_config`
+        """
+        return []
 
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
@@ -87,6 +94,10 @@ class TorchAOConfig(QuantizationConfig):
             "`pip install torchao>=0.10.0` to use torchao quantization."
         ) from err
 
+        quant_method = cls.get_from_keys_or(config, ["quant_method"], None)
+        is_checkpoint_torchao_serialized = (quant_method is not None
+                                            and "torchao" in quant_method)
+
         hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
         assert hf_config is not None, "quant_type must be specified"
         assert len(hf_config) == 1 and "default" in hf_config, (
@@ -110,7 +121,38 @@ class TorchAOConfig(QuantizationConfig):
             if layer_cfg is None:
                 skip_modules.append(layer)
 
-        return cls(ao_config, skip_modules)
+        return cls(ao_config, skip_modules, is_checkpoint_torchao_serialized)
+
+    @classmethod
+    def from_config_file(cls, config_file: str) -> "TorchAOConfig":
+        """Initialize class from a config file. Example:
+        ```
+        config = (
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+        fn = "torchao_config.json"
+
+        with open(fn, "w") as f:
+            f.write(json.dumps(config_to_dict(config)))
+        ```
+        """
+        with open(config_file) as f:
+            f.seek(0)
+            f_read = f.read()
+            config_dict = json.loads(f_read)
+
+        hf_config = {"quant_type": {"default": config_dict}}
+        return cls.from_config(hf_config)
+
+    @classmethod
+    def from_config_dict_json(cls, config_dict_json: str) -> "TorchAOConfig":
+        """Initialize class from a config_dict json string, obtained via
+        torchao_config_object = some AOBaseConfig object
+        json.dumps(config_to_dict(torchao_config_object))
+        """
+        config_dict = json.loads(config_dict_json)
+        hf_config = {"quant_type": {"default": config_dict}}
+        return cls.from_config(hf_config)
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
@@ -128,7 +170,9 @@ class TorchAOConfig(QuantizationConfig):
             c = module_fqn_to_config.get(
                 module_fqn) or module_fqn_to_config.get("_default", None)
             if c is not None:
-                current_torchao_config = TorchAOConfig(c, self.skip_modules)
+                current_torchao_config = TorchAOConfig(
+                    c, self.skip_modules,
+                    self.is_checkpoint_torchao_serialized)
                 return TorchAOLinearMethod(current_torchao_config)
             else:
                 return UnquantizedLinearMethod()
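As a side note on what `from_config_dict_json` consumes: torchao's config (de)serialization is a plain dict round-trip. A minimal sketch, assuming torchao >= 0.10 exposes `config_from_dict` alongside `config_to_dict` in `torchao.core.config`:

import json

from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig, PerRow)

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# config_to_dict produces a JSON-safe dict; json.dumps of it is exactly
# the string TorchAOConfig.from_config_dict_json expects.
payload = json.dumps(config_to_dict(config))

# Round-trip back to an AOBaseConfig instance.
restored = config_from_dict(json.loads(payload))
assert type(restored) is type(config)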
@@ -172,7 +216,7 @@ class TorchAOLinearMethod(LinearMethodBase):
     """Linear method for torchao.
 
     Args:
         quant_config: The torchao quantization config, a string that encodes
             the type of quantization and all relevant arguments.
     """
 
@@ -197,8 +241,9 @@ class TorchAOLinearMethod(LinearMethodBase):
             ),
             requires_grad=False,
         )
-        weight = torchao_quantize_param_data(weight,
-                                             self.quant_config.torchao_config)
+        if self.quant_config.is_checkpoint_torchao_serialized:
+            weight = torchao_quantize_param_data(
+                weight, self.quant_config.torchao_config)
 
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
 
@@ -212,3 +257,14 @@ class TorchAOLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return F.linear(x, layer.weight, bias)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.quant_config.is_checkpoint_torchao_serialized:
+            return
+
+        # quantize the weight on the fly if the checkpoint is not already
+        # quantized by torchao
+        weight = torchao_quantize_param_data(layer.weight,
+                                             self.quant_config.torchao_config)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        layer.register_parameter("weight", weight)
@@ -261,8 +261,35 @@ class DefaultModelLoader(BaseModelLoader):
     def load_weights(self, model: nn.Module,
                      model_config: ModelConfig) -> None:
         weights_to_load = {name for name, _ in model.named_parameters()}
-        loaded_weights = model.load_weights(
-            self.get_all_weights(model_config, model))
+        # if we don't have `model.weight_metadata_and_attr_saved` defined and
+        # set to True, it means that this is either the offline quantization
+        # case or the first run of online quantization
+        # see online_quantization.py for detailed notes
+        offline_quantization_or_first_run_of_online_quantization = not getattr(
+            model, "weight_metadata_and_attr_saved", False)
+
+        if model_config.quantization is None:
+            # model is not quantized
+            loaded_weights = model.load_weights(
+                self.get_all_weights(model_config, model))
+        elif offline_quantization_or_first_run_of_online_quantization:
+            # case 1: offline quantized checkpoint
+            # case 2: Step I1, first run of weight loading with
+            # online quantization
+            # see online_quantization.py for detailed notes
+            loaded_weights = model.load_weights(
+                self.get_all_weights(model_config, model))
+        else:
+            # to avoid circular dependency
+            from vllm.model_executor.model_loader.online_quantization import (
+                load_weights_and_online_quantize)
+
+            # subsequent runs of weight loading with online
+            # quantization
+            loaded_weights = load_weights_and_online_quantize(
+                self, model, model_config)
+
         self.counter_after_loading_weights = time.perf_counter()
         logger.info(
             "Loading weights took %.2f seconds",
vllm/model_executor/model_loader/online_quantization.py (new file, 217 lines)
@@ -0,0 +1,217 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import types
+
+import torch
+from torch import nn
+
+from vllm.config import ModelConfig
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
+from vllm.model_executor.model_loader.utils import (
+    process_weights_after_loading)
+
+logger = init_logger(__name__)
+
+# Notes for Online Quantization
+# The state of checkpoints, quantization config, and their
+# correspondence to online quantization:
+# | Use Case      | Checkpoints    | model_config.quantization |
+# | no quant      | high precision | None                      |
+# | offline quant | quantized      | fp8, torchao etc.         |
+# | online quant  | high precision | torchao etc.              |
+#
+# The process for loading a non-quantized checkpoint:
+# 1. load non-quantized weights (load_weights)
+# 2. do any additional post processing (process_weights_after_loading)
+#
+# The process for loading an offline quantized checkpoint:
+# 1. load offline-quantized weights (load_weights)
+# 2. do any additional post processing (process_weights_after_loading)
+
+# The process for unquantized model reloading
+# (repeated runs in an RL training loop)
+# first run
+#   UI1. load_weights: load bfloat16 weights
+#   UI2. process_weights_after_loading: any additional post processing
+# subsequent runs
+#   UC1. load_weights: load bfloat16 weights
+#        (shouldn't be any issues since we didn't change any attributes
+#        of the weights)
+#   UC2. process_weights_after_loading: any additional post processing
+
+# The process for weight reloading with online quantization
+# (repeated runs in an RL training loop)
+# first run
+#   I1. load_weights: load bfloat16 weights
+#   I2. process_weights_after_loading:
+#       record weight metadata and attributes for R1 and R2
+#       quantize weights to fp8
+# subsequent runs
+# (the model weights start out in fp8)
+#   load_weights:
+#     R1. restore bfloat16 model weight metadata
+#     R2. restore the model weight attributes
+#     R3. reload bfloat16 weights
+#     R4. quantize weights (by calling process_weights_after_loading),
+#         also set `process_weights_after_loading_already_called` to
+#         True to stop it from running again
+#   process_weights_after_loading (if called):
+#     this will be skipped since it already ran in load_weights
+
+
+def maybe_save_metadata_and_attributes_for_weight_reloading(
+        model: nn.Module, model_config: ModelConfig):
+    # the following is to support on-the-fly quantization, currently only
+    # supported for torchao
+    if model_config.quantization != "torchao":
+        return
+
+    if getattr(model, "process_weights_after_loading_already_called", False):
+        # In case `process_weights_after_loading` is called multiple times
+        # we'll skip it at later times
+        logger.warning(
+            "process_weights_after_loading already called for model %s",
+            model)
+        return
+
+    from vllm.model_executor.model_loader.weight_utils import get_quant_config
+    quant_config = get_quant_config(model_config, None)
+
+    # If the checkpoint is already torchao serialized, this is the
+    # pre-quantized case and we skip saving the metadata.
+    # Otherwise, this is Step I2 of the initialization steps of
+    # online quantization.
+    # This step records the weight metadata and weight attributes so we can
+    # restore the bfloat16 model weights during the reload steps (R1 and R2);
+    # see the notes above for more details.
+    if not (hasattr(quant_config, "is_checkpoint_torchao_serialized") and \
+            not quant_config.is_checkpoint_torchao_serialized):
+        return
+
+    # This is the I2 step of online quantization that saves
+    # metadata and attributes of weights so they can be used in the R1 and
+    # R2 steps; note that we only save these during initialization.
+
+    # Includes two things:
+    # 1. save floating point metadata (shape, dtype, device) for init
+    # 2. save weight attributes, e.g. `output_dim`, `weight_loader`, for init
+
+    if getattr(model, "weight_metadata_and_attr_saved", False):
+        return
+
+    # save the dtype, shape and device of each model parameter, used for
+    # restoring the model's high precision parameters before
+    # reloading the weights
+    assert not hasattr(model, "original_weights_rebuild_keys")
+    model.original_weights_rebuild_keys = {}
+    for name, p in model.named_parameters():
+        model.original_weights_rebuild_keys[name] = {
+            "shape": p.shape,
+            "dtype": p.dtype,
+            "device": p.device,
+        }
+
+    # record the weight attributes (loader functions etc.)
+    # so these can be recovered later when we reload the weights
+    # structure: {"weight_name": {"weight_attr_key": attr}}
+    assert not hasattr(model, "recorded_weight_attr")
+    model.recorded_weight_attr = {}
+    for name, param in model.named_parameters():
+        model.recorded_weight_attr[name] = {}
+        for key in param.__dict__:
+            if hasattr(param, key):
+                attr = getattr(param, key)
+                if not callable(attr):
+                    model.recorded_weight_attr[name][key] = attr
+                elif hasattr(attr, "__self__") and param is attr.__self__:
+                    # if attr is a method bound to an instance, and
+                    # attr.__self__ points to the instance (param),
+                    # we record the underlying function object
+                    model.recorded_weight_attr[name][key] = attr.__func__
+                else:
+                    model.recorded_weight_attr[name][key] = attr
+    # mark the metadata and attributes saved so we don't run this again
+    model.weight_metadata_and_attr_saved = True
+
+
+def _bond_method_to_cls(func, obj):
+    if hasattr(func, "__self__") or not callable(func):
+        # If the function is already bound to an instance, or is not
+        # callable, return it as is
+        return func
+    else:
+        return types.MethodType(func, obj)
+
+
+def load_weights_and_online_quantize(model_loader: DefaultModelLoader,
+                                     model: nn.Module,
+                                     model_config: ModelConfig) -> set[str]:
+    # online quantization, right now only enabled for torchao
+    # Steps R1, R2, R3, R4 in the notes above
+
+    # TODO: Add fp8 support
+    assert model_config.quantization == "torchao", "online " \
+        "quantization is only enabled for torchao currently"
+    # TODO: use create_weights to restore the weights to their original state
+
+    # Step R1: first restore the quantized weights to the original bfloat16
+    # weights, with the original metadata (shape, dtype, device)
+    # and attributes, so that bfloat16 weights can be loaded properly
+    existing_param_names = dict(
+        model.named_parameters(remove_duplicate=False)).keys()
+    named_modules = dict(model.named_modules(remove_duplicate=False))
+    model_device = None
+
+    # Step R2: recover the parameters to the state before the first loading
+    for name, d in model.original_weights_rebuild_keys.items():
+        _shape = d["shape"]
+        _dtype = d["dtype"]
+        _device = d["device"]
+        if model_device is not None:
+            assert model_device == _device, "Expecting all weights " \
+                "to be on the same device for now, got both: " \
+                f"{model_device} and {_device}"
+        else:
+            model_device = _device
+
+        if name in existing_param_names:
+            module_name, weight_name = name.rsplit(".", 1)
+            module = named_modules[module_name]
+            setattr(
+                module, weight_name,
+                torch.nn.Parameter(
+                    torch.empty(_shape, dtype=_dtype, device=_device)))
+
+    # recorded_weight_attr is
+    # {"weight_name": {"weight_attr_key": attr}}
+    # e.g.
+    # {
+    #     "layer.0.weight": {
+    #         "weight_loader": weight_loader_function_object,
+    #         "input_dim": 0, ...
+    #     },
+    #     "layer.1.weight": ...,
+    # }
+    for full_weight_name, weight_attr_dict in \
+            model.recorded_weight_attr.items():
+        for attr_name, attr in weight_attr_dict.items():
+            module_name, weight_name = full_weight_name.rsplit(".", 1)
+            module = named_modules[module_name]
+            weight = getattr(module, weight_name)
+            if not hasattr(weight, attr_name):
+                setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
+
+    # Step R3: reload bfloat16 / high precision weights
+    loaded_weights = model.load_weights(
+        model_loader.get_all_weights(model_config, model))
+
+    # Step R4: online quantize the weights
+    # manually process weights after loading
+    model.process_weights_after_loading_already_called = False
+    process_weights_after_loading(model, model_config, model_device)
+    model.process_weights_after_loading_already_called = True
+    return loaded_weights
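The subtlest part of the file above is saving and restoring per-parameter attributes such as `weight_loader`, which are often methods bound to the parameter object itself; recording them bound would pin the old (quantized) parameter. A minimal standalone illustration of the `__func__`/`types.MethodType` round-trip behind `_bond_method_to_cls`; the `Param` and `weight_loader` names here are hypothetical stand-ins:

import types


class Param:
    """Stand-in for a torch.nn.Parameter that carries a weight_loader."""


def weight_loader(self, loaded_weight):
    # hypothetical loader; real vLLM loaders copy into self.data
    return loaded_weight


old_param = Param()
# attributes like weight_loader are attached bound to the parameter instance
old_param.weight_loader = types.MethodType(weight_loader, old_param)

# Saving: strip the binding so the record doesn't keep old_param alive
saved = old_param.weight_loader.__func__

# Restoring: re-bind the plain function onto the freshly created parameter
new_param = Param()
new_param.weight_loader = types.MethodType(saved, new_param)
assert new_param.weight_loader.__self__ is new_param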
@@ -95,6 +95,13 @@ def initialize_model(
 
 def process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
                                   target_device: torch.device) -> None:
+
+    # to avoid circular dependency
+    from vllm.model_executor.model_loader.online_quantization import (
+        maybe_save_metadata_and_attributes_for_weight_reloading)
+    maybe_save_metadata_and_attributes_for_weight_reloading(
+        model, model_config)
+
     for _, module in model.named_modules():
         if isinstance(module, QKVCrossParallelLinear):
             # NOTE(Isotr0py): special case for cross QKV layer because
@@ -243,7 +250,7 @@ def get_architecture_class_name(model_config: ModelConfig) -> str:
 class ParamMapping:
     """
     A class to handle parameter mapping for model weight loading.
     It creates a bidirectional mapping between packed parameters and their
     constituent parts.
     """
     packed_mapping: dict[str, list[str]]
@@ -246,8 +246,34 @@ def get_quant_config(model_config: ModelConfig,
     # compressed-tensors uses a compressions_config
     hf_quant_config = getattr(model_config.hf_config, "compression_config",
                               None)
+
     if hf_quant_config is not None:
         return quant_cls.from_config(hf_quant_config)
+
+    # if hf_quant_config is None, we will try to get the config from
+    # hf_overrides
+    hf_overrides = model_config.hf_overrides
+    quantization_config_file = hf_overrides.get("quantization_config_file",
+                                                None)
+    if quantization_config_file is not None:
+        if hasattr(quant_cls, "from_config_file"):
+            return quant_cls.from_config_file(quantization_config_file)
+        else:
+            raise NotImplementedError(
+                "quantization_config_file is specified in the hf_overrides "
+                "config, but quant_cls.from_config_file is not implemented "
+                f"in {quant_cls}")
+    quantization_config_json = hf_overrides.get(
+        "quantization_config_dict_json", None)
+    if quantization_config_json is not None:
+        if hasattr(quant_cls, "from_config_dict_json"):
+            return quant_cls.from_config_dict_json(quantization_config_json)
+        else:
+            raise NotImplementedError(
+                "quantization_config_dict_json is specified in the "
+                "hf_overrides config, but quant_cls.from_config_dict_json "
+                f"is not implemented in {quant_cls}")
+
     # Inflight BNB quantization
     if model_config.quantization == "bitsandbytes":
         return quant_cls.from_config({})
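Complementing the JSON-string override shown near the top, the `quantization_config_file` path accepted above can be exercised the same way; a sketch condensed from `test_on_the_fly_quant_config_file` in the tests diff:

import json
from tempfile import NamedTemporaryFile

from torchao.core.config import config_to_dict
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig, PerRow)

from vllm import LLM

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# Write the serialized config where get_quant_config can find it.
with NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
    f.write(json.dumps(config_to_dict(config)))
    config_file_name = f.name

llm = LLM(model="facebook/opt-125m",
          dtype="bfloat16",
          quantization="torchao",
          hf_overrides={"quantization_config_file": config_file_name})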