From c31246800c98272ab28b7d454d6b72f38e396972 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 1 Oct 2025 16:39:29 -0700
Subject: [PATCH] Support RL online quantization with torchao (#23014)

Signed-off-by: Jerry Zhang
---
 tests/quantization/test_torchao.py            | 126 +++++++++-
 .../layers/quantization/torchao.py            |  72 +++++-
 .../model_loader/default_loader.py            |  31 ++-
 .../model_loader/online_quantization.py       | 217 ++++++++++++++++++
 vllm/model_executor/model_loader/utils.py     |   9 +-
 .../model_loader/weight_utils.py              |  26 +++
 6 files changed, 465 insertions(+), 16 deletions(-)
 create mode 100644 vllm/model_executor/model_loader/online_quantization.py

diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py
index 8e68f6a2e019..37cf7ef8417b 100644
--- a/tests/quantization/test_torchao.py
+++ b/tests/quantization/test_torchao.py
@@ -20,7 +20,6 @@ def test_pre_quantized_model(vllm_runner):
         output = llm.generate_greedy(["The capital of France is"],
                                      max_tokens=32)
     assert output
-    print(output)
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@@ -42,7 +41,6 @@ def test_opt_125m_int8wo_model_loading_with_params(vllm_runner,
                                      max_tokens=32)
 
     assert output
-    print(output)
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@@ -57,7 +55,6 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
                                      max_tokens=32)
 
     assert output
-    print(output)
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@@ -72,7 +69,6 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
                                      max_tokens=32)
 
     assert output
-    print(output)
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@@ -92,7 +88,127 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
                                      max_tokens=32)
 
     assert output
-    print(output)
+
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_on_the_fly_quant_config_dict_json(vllm_runner):
+    """Testing on-the-fly quantization via the load_weights integration
+    point, with the config dict serialized to a json string
+    """
+    torch._dynamo.reset()
+    model_name = "facebook/opt-125m"
+
+    import json
+
+    from torchao.core.config import config_to_dict
+    from torchao.quantization import (
+        Float8DynamicActivationFloat8WeightConfig, PerRow)
+
+    torchao_quant_config = Float8DynamicActivationFloat8WeightConfig(
+        granularity=PerRow())
+    hf_overrides = {
+        "quantization_config_dict_json":
+        json.dumps(config_to_dict(torchao_quant_config))
+    }
+    with vllm_runner(model_name=model_name,
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0",
+                     quantization="torchao",
+                     hf_overrides=hf_overrides) as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_on_the_fly_quant_config_file(vllm_runner):
+    """Testing on-the-fly quantization via the load_weights integration
+    point, with a config file
+    """
+    torch._dynamo.reset()
+    model_name = "facebook/opt-125m"
+    import json
+    from tempfile import NamedTemporaryFile
+
+    from torchao.core.config import config_to_dict
+    from torchao.quantization import (
+        Float8DynamicActivationFloat8WeightConfig, PerRow)
+
+    config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+
+    with NamedTemporaryFile(mode="w", delete=False) as f:
+        f.write(json.dumps(config_to_dict(config)))
+        # close the file to save it
+        f.close()
+        config_file_name = str(f.name)
+
+    hf_overrides = {"quantization_config_file": config_file_name}
+    with vllm_runner(model_name=model_name,
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0",
+                     quantization="torchao",
+                     hf_overrides=hf_overrides) as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_reload_weights():
+    import json
+
+    from torchao.core.config import config_to_dict
+    from torchao.quantization import (
+        Float8DynamicActivationFloat8WeightConfig, PerRow)
+
+    from vllm import LLM, SamplingParams
+
+    torchao_quant_config = Float8DynamicActivationFloat8WeightConfig(
+        granularity=PerRow())
+
+    hf_overrides = {
+        "quantization_config_dict_json":
+        json.dumps(config_to_dict(torchao_quant_config))
+    }
+
+    llm = LLM(
+        model="Qwen/Qwen3-0.6B",
+        dtype="bfloat16",
+        load_format="dummy",
+        enforce_eager=True,
+        quantization="torchao",
+        hf_overrides=hf_overrides,
+    )
+    # Update load format from `dummy` to `auto`
+    llm.collective_rpc("update_config",
+                       args=({
+                           "load_config": {
+                               "load_format": "auto"
+                           }
+                       }, ))
+    # Now reload real weights in place
+    llm.collective_rpc("reload_weights")
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    # Create a sampling params object.
+    sampling_params = SamplingParams(temperature=0, top_p=0.95)
+    outputs = llm.generate(prompts, sampling_params)
+    # make sure it runs
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        assert generated_text
+        # can also uncomment locally to make sure the generated
+        # output makes sense
+        # prompt = output.prompt
+        # print(f"Prompt: {prompt!r}")
+        # print(f"Output: {generated_text!r}")
+        # print("-" * 60)
 
 
 if __name__ == "__main__":
diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py
index 2efb605f203f..7e38304ad6d9 100644
--- a/vllm/model_executor/layers/quantization/torchao.py
+++ b/vllm/model_executor/layers/quantization/torchao.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json
 from typing import Any, Optional
 
 import torch
@@ -40,7 +41,8 @@ class TorchAOConfig(QuantizationConfig):
 
     def __init__(self,
                  torchao_config,
-                 skip_modules: Optional[list[str]] = None) -> None:
+                 skip_modules: Optional[list[str]] = None,
+                 is_checkpoint_torchao_serialized: bool = False) -> None:
         """
         # TorchAO quantization relies on tensor subclasses. In order,
         # to enable proper caching this needs standalone compile
@@ -58,9 +60,11 @@ class TorchAOConfig(QuantizationConfig):
         super().__init__()
         self.torchao_config = torchao_config
         self.skip_modules = skip_modules or []
+        self.is_checkpoint_torchao_serialized = is_checkpoint_torchao_serialized
 
     def __repr__(self) -> str:
-        return f"TorchAOConfig({self.torchao_config})"
+        return f"TorchAOConfig({self.torchao_config=}, {self.skip_modules=}, " \
+            f"{self.is_checkpoint_torchao_serialized=})"
 
     def get_name(self) -> QuantizationMethods:
         return "torchao"
@@ -74,7 +78,10 @@ class TorchAOConfig(QuantizationConfig):
 
     @staticmethod
     def get_config_filenames() -> list[str]:
-        return ["config.json"]
+        """torchao doesn't require additional config files, we use
+        `config.json` from huggingface: `model_config.hf_config`
+        """
+        return []
 
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
@@ -87,6 +94,10 @@ class TorchAOConfig(QuantizationConfig):
                 "`pip install torchao>=0.10.0` to use torchao quantization."
             ) from err
 
+        quant_method = cls.get_from_keys_or(config, ["quant_method"], None)
+        is_checkpoint_torchao_serialized = (quant_method is not None
+                                            and "torchao" in quant_method)
+
         hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
         assert hf_config is not None, "quant_type must be specified"
         assert len(hf_config) == 1 and "default" in hf_config, (
@@ -110,7 +121,38 @@ class TorchAOConfig(QuantizationConfig):
             if layer_cfg is None:
                 skip_modules.append(layer)
 
-        return cls(ao_config, skip_modules)
+        return cls(ao_config, skip_modules, is_checkpoint_torchao_serialized)
+
+    @classmethod
+    def from_config_file(cls, config_file: str) -> "TorchAOConfig":
+        """Initialize class from a config file. Example:
+        ```
+        config = (
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+        fn = "torchao_config.json"
+
+        with open(fn, "w") as f:
+            f.write(json.dumps(config_to_dict(config)))
+        ```
+        """
+        with open(config_file) as f:
+            f.seek(0)
+            f_read = f.read()
+            config_dict = json.loads(f_read)
+
+        hf_config = {"quant_type": {"default": config_dict}}
+        return cls.from_config(hf_config)
+
+    @classmethod
+    def from_config_dict_json(cls, config_dict_json: str) -> "TorchAOConfig":
+        """Initialize class from a config_dict json string, obtained from
+        torchao_config_object = some AOBaseConfig object
+        json.dumps(config_to_dict(torchao_config_object))
+        """
+        config_dict = json.loads(config_dict_json)
+        hf_config = {"quant_type": {"default": config_dict}}
+        return cls.from_config(hf_config)
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
@@ -128,7 +170,9 @@ class TorchAOConfig(QuantizationConfig):
             c = module_fqn_to_config.get(
                 module_fqn) or module_fqn_to_config.get("_default", None)
             if c is not None:
-                current_torchao_config = TorchAOConfig(c, self.skip_modules)
+                current_torchao_config = TorchAOConfig(
+                    c, self.skip_modules,
+                    self.is_checkpoint_torchao_serialized)
                 return TorchAOLinearMethod(current_torchao_config)
             else:
                 return UnquantizedLinearMethod()
@@ -172,7 +216,7 @@ class TorchAOLinearMethod(LinearMethodBase):
     """Linear method for torchao.
 
     Args:
-        quant_config: The torchao quantization config, a string that encodes
+        quant_config: The torchao quantization config, a string that encodes
        the type of quantization and all relevant arguments.
     """
 
@@ -197,8 +241,9 @@ class TorchAOLinearMethod(LinearMethodBase):
             ),
             requires_grad=False,
         )
-        weight = torchao_quantize_param_data(weight,
-                                             self.quant_config.torchao_config)
+        if self.quant_config.is_checkpoint_torchao_serialized:
+            weight = torchao_quantize_param_data(
+                weight, self.quant_config.torchao_config)
 
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
 
@@ -212,3 +257,14 @@ class TorchAOLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return F.linear(x, layer.weight, bias)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.quant_config.is_checkpoint_torchao_serialized:
+            return
+
+        # quantize the weight on the fly if the checkpoint is not already
+        # quantized by torchao
+        weight = torchao_quantize_param_data(layer.weight,
+                                             self.quant_config.torchao_config)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        layer.register_parameter("weight", weight)
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index 4b7bcd37d4bc..8e2db9292ff8 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -261,8 +261,35 @@ class DefaultModelLoader(BaseModelLoader):
     def load_weights(self, model: nn.Module,
                      model_config: ModelConfig) -> None:
         weights_to_load = {name for name, _ in model.named_parameters()}
-        loaded_weights = model.load_weights(
-            self.get_all_weights(model_config, model))
+
+        # if we don't have `model.weight_metadata_and_attr_saved` defined and
+        # set to True, it means that this is either the offline quantization
+        # case or the first run of online quantization
+        # see online_quantization.py for detailed notes
+        offline_quantization_or_first_run_of_online_quantization = not getattr(
+            model, "weight_metadata_and_attr_saved", False)
+
+        if model_config.quantization is None:
+            # model is not quantized
+            loaded_weights = model.load_weights(
+                self.get_all_weights(model_config, model))
+        elif offline_quantization_or_first_run_of_online_quantization:
+            # case 1: offline quantized checkpoint
+            # case 2: Step I1, first run of weight loading with
+            # online quantization
+            # see online_quantization.py for detailed notes
+            loaded_weights = model.load_weights(
+                self.get_all_weights(model_config, model))
+        else:
+            # to avoid circular dependency
+            from vllm.model_executor.model_loader.online_quantization import (
+                load_weights_and_online_quantize)
+
+            # subsequent runs of weight loading with online
+            # quantization
+            loaded_weights = load_weights_and_online_quantize(
+                self, model, model_config)
+
         self.counter_after_loading_weights = time.perf_counter()
         logger.info(
             "Loading weights took %.2f seconds",
diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py
new file mode 100644
index 000000000000..beec2d20ad69
--- /dev/null
+++ b/vllm/model_executor/model_loader/online_quantization.py
@@ -0,0 +1,217 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import types
+
+import torch
+from torch import nn
+
+from vllm.config import ModelConfig
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
+from vllm.model_executor.model_loader.utils import (
+    process_weights_after_loading)
+
+logger = init_logger(__name__)
+
+# Notes for Online Quantization
+# In terms of state of checkpoints, quantization config and their
+# correspondence to online quantization:
+# | Use Case       | Checkpoints     | model_config.quantization |
+# | no quant       | high precision  | None                      |
+# | offline quant  | quantized       | fp8, torchao etc.         |
+# | online quant   | high precision  | torchao etc.              |
+#
+# The process for loading non-quantized checkpoint
+# 1. load non-quantized weights (load_weights)
+# 2. do any additional post processing (process_weights_after_loading)
+#
+# The process for loading offline quantized checkpoint
+# 1. load offline-quantized weights (load_weights)
+# 2. do any additional post processing (process_weights_after_loading)
+
+# The process for unquantized model reloading
+# (repeated run in RL training loop)
+# first run
+# UI1. load_weights: load bfloat16 weights
+# UI2. process_weights_after_loading: any additional post processing
+# subsequent run
+# UC1: load_weights: load bfloat16 weights
+#      (shouldn't be any issues since we didn't change any attributes
+#      of the weights)
+# UC2: process_weights_after_loading: any additional post processing
+
+# The process for weight reloading with online quantization
+# (repeated run in RL training loop)
+# first run
+# I1. load_weights: load bfloat16 weights
+# I2. process_weights_after_loading:
+#     record weight metadata and attributes for R1 and R2
+#     quantize weights to fp8
+# subsequent run
+# (beginning model weight is in fp8)
+# load_weights:
+#   R1. restore bfloat16 model weight metadata
+#   R2. restore the model weight attributes
+#   R3. reload bfloat16 weights
+#   R4. quantize weights (by calling process_weights_after_loading),
+#       also set `process_weights_after_loading_already_called` to
+#       True to stop it from running again
+# process_weights_after_loading (if called):
+#   this will be skipped since it's already run in
+#   load_weights
+
+
+def maybe_save_metadata_and_attributes_for_weight_reloading(
+        model: nn.Module, model_config: ModelConfig):
+    # the following is to support on-the-fly quantization, currently only
+    # supported for torchao
+    if model_config.quantization != "torchao":
+        return
+
+    if getattr(model, "process_weights_after_loading_already_called", False):
+        # In case `process_weights_after_loading` is called multiple times
+        # we'll skip it at later times
+        logger.warning(
+            "process_weights_after_loading already called for model %s", model)
+        return
+
+    from vllm.model_executor.model_loader.weight_utils import get_quant_config
+    quant_config = get_quant_config(model_config, None)
+
+    # If the checkpoint is already torchao serialized, this is the
+    # pre-quantized case, so we'll skip saving the metadata
+    # Otherwise, this is Step I2 of the initialization steps of
+    # online quantization
+    # This step records the weight metadata and attributes so we can
+    # restore the bfloat16 model weights during the reload step (R1 and R2)
+    # see the Notes above for more details
+    if not (hasattr(quant_config, "is_checkpoint_torchao_serialized") and \
+            not quant_config.is_checkpoint_torchao_serialized):
+        return
+
+    # This is the I2 step of online quantization that saves
+    # metadata and attributes of weights so they can be used in the R1 and
+    # R2 steps, note that we only save these during initialization
+
+    # Includes two things
+    # 1. save floating point metadata (shape, dtype, device) for init
+    # 2. save weight attributes, e.g. `output_dim`, `weight_loader` for init
+
+    if getattr(model, "weight_metadata_and_attr_saved", False):
+        return
+
+    # save the dtype, shape and device for each model parameter, used for
+    # restoring the model high precision parameters before
+    # reloading the weights
+    assert not hasattr(model, "original_weights_rebuild_keys")
+    model.original_weights_rebuild_keys = {}
+    for name, p in model.named_parameters():
+        model.original_weights_rebuild_keys[name] = {
+            "shape": p.shape,
+            "dtype": p.dtype,
+            "device": p.device,
+        }
+
+    # record the weight attributes (loader functions etc.)
+    # so these can be recovered later when we reload the weights
+    # structure: {"weight_name": {"weight_attr_key": attr}}
+    assert not hasattr(model, "recorded_weight_attr")
+    model.recorded_weight_attr = {}
+    for name, param in model.named_parameters():
+        model.recorded_weight_attr[name] = {}
+        for key in param.__dict__:
+            if hasattr(param, key):
+                attr = getattr(param, key)
+                if not callable(attr):
+                    model.recorded_weight_attr[name][key] = attr
+                elif hasattr(attr, "__self__") and param is attr.__self__:
+                    # if attr is a bound method of an instance, and
+                    # attr.__self__ points to the instance (param)
+                    # we'll record the underlying function object
+                    model.recorded_weight_attr[name][key] = attr.__func__
+                else:
+                    model.recorded_weight_attr[name][key] = attr
+    # mark the metadata and attributes saved so we don't run it again
+    model.weight_metadata_and_attr_saved = True
+
+
+def _bond_method_to_cls(func, obj):
+    if hasattr(func, "__self__") or not callable(func):
+        # If the function is already bound to an instance, return it as is
+        return func
+    else:
+        return types.MethodType(func, obj)
+
+
+def load_weights_and_online_quantize(model_loader: DefaultModelLoader,
+                                     model: nn.Module,
+                                     model_config: ModelConfig) -> set[str]:
+    # online quantization, right now only enabled for
+    # torchao
+    # R1, R2, R3, R4 in the Notes
+
+    # TODO: Add fp8 support
+    assert model_config.quantization == "torchao", "online " \
+        "quantization is only enabled for torchao currently"
+    # TODO: use create_weights to restore the weights to original state
+
+    # Steps R1 and R2: first restore the quantized weights to the original
+    # bfloat16 weights, with the original metadata (shape, dtype, device)
+    # and attributes, so that bfloat16 weights can be loaded properly
+    existing_param_names = dict(
+        model.named_parameters(remove_duplicate=False)).keys()
+    named_modules = dict(model.named_modules(remove_duplicate=False))
+    model_device = None
+
+    # Step R1: recover the parameters to the state before the first load
+    for name, d in model.original_weights_rebuild_keys.items():
+        _shape = d["shape"]
+        _dtype = d["dtype"]
+        _device = d["device"]
+        if model_device is not None:
+            assert model_device == _device, "Expecting all weights " \
+                "to be in the same device for now, got both: " \
+                f"{model_device} and {_device}"
+        else:
+            model_device = _device
+
+        if name in existing_param_names:
+            module_name, weight_name = name.rsplit(".", 1)
+            module = named_modules[module_name]
+            setattr(
+                module, weight_name,
+                torch.nn.Parameter(
+                    torch.empty(_shape, dtype=_dtype, device=_device)))
+
+    # Step R2: restore the recorded weight attributes
+    # recorded_weight_attr is
+    # {"weight_name": {"weight_attr_key": attr}}
+    # e.g.
+    # {
+    #     "layer.0.weight": {
+    #         "weight_loader": weight_loader_function_object,
+    #         "input_dim": 0, ...
+ # }, + # "layer.1.weight": ..., + # } + # } + for full_weight_name, weight_attr_dict in \ + model.recorded_weight_attr.items(): + for attr_name, attr in weight_attr_dict.items(): + module_name, weight_name = full_weight_name.rsplit(".", 1) + module = named_modules[module_name] + weight = getattr(module, weight_name) + if not hasattr(weight, attr_name): + setattr(weight, attr_name, _bond_method_to_cls(attr, weight)) + + # Step I1: reload bfloat16 / high precision weights + loaded_weights = model.load_weights( + model_loader.get_all_weights(model_config, model)) + + # Step I2: online quantize the weights + # manually process weights after loading + model.process_weights_after_loading_already_called = False + process_weights_after_loading(model, model_config, model_device) + model.process_weights_after_loading_already_called = True + return loaded_weights diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 03202e13c280..293edadcc240 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -95,6 +95,13 @@ def initialize_model( def process_weights_after_loading(model: nn.Module, model_config: ModelConfig, target_device: torch.device) -> None: + + # to avoid circular dependency + from vllm.model_executor.model_loader.online_quantization import ( + maybe_save_metadata_and_attributes_for_weight_reloading) + maybe_save_metadata_and_attributes_for_weight_reloading( + model, model_config) + for _, module in model.named_modules(): if isinstance(module, QKVCrossParallelLinear): # NOTE(Isotr0py): special case for cross QKV layer because @@ -243,7 +250,7 @@ def get_architecture_class_name(model_config: ModelConfig) -> str: class ParamMapping: """ A class to handle parameter mapping for model weight loading. - It creates a bidirectional mapping between packed parameters and their + It creates a bidirectional mapping between packed parameters and their constituent parts. 
""" packed_mapping: dict[str, list[str]] diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index f52d9dd2f534..bbed43b17543 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -246,8 +246,34 @@ def get_quant_config(model_config: ModelConfig, # compressed-tensors uses a compressions_config hf_quant_config = getattr(model_config.hf_config, "compression_config", None) + if hf_quant_config is not None: return quant_cls.from_config(hf_quant_config) + + # if hf_quant_config is None, we will try to get config from + # hf_overrides + hf_overrides = model_config.hf_overrides + quantization_config_file = hf_overrides.get("quantization_config_file", + None) + if quantization_config_file is not None: + if hasattr(quant_cls, "from_config_file"): + return quant_cls.from_config_file(quantization_config_file) + else: + raise NotImplementedError( + "from_config_file is specified in hf_override config, " + "but quant_cls.from_config_file is not implemented in " + f"{quant_cls}") + quantization_config_json = hf_overrides.get( + "quantization_config_dict_json", None) + if quantization_config_json is not None: + if hasattr(quant_cls, "from_config_dict_json"): + return quant_cls.from_config_dict_json(quantization_config_json) + else: + raise NotImplementedError( + "from_config_dict_json is specified in hf_override config, " + "but quant_cls.from_config_dict_json is not implemented in " + f"{quant_cls}") + # Inflight BNB quantization if model_config.quantization == "bitsandbytes": return quant_cls.from_config({})