From da94c7c0eb8dabea9c500dbd70fa042497497689 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 18 Nov 2025 16:52:41 -0800
Subject: [PATCH] Move online quantization to `model.load_weights` (#26327)

Signed-off-by: Jerry Zhang
---
 examples/offline_inference/rlhf.py          |   2 +-
 .../offline_inference/rlhf_online_quant.py  | 162 +++++++++++++++
 .../model_loader/default_loader.py          |  46 +----
 .../model_loader/online_quantization.py     | 195 +++++++++++-------
 vllm/model_executor/model_loader/utils.py   |   8 +
 vllm/model_executor/models/utils.py         |   4 +
 6 files changed, 309 insertions(+), 108 deletions(-)
 create mode 100644 examples/offline_inference/rlhf_online_quant.py

diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py
index 0c09e603271de..6f05968ce065e 100644
--- a/examples/offline_inference/rlhf.py
+++ b/examples/offline_inference/rlhf.py
@@ -62,7 +62,7 @@ ray.init()
 # Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
 # Learn more about Ray placement groups:
-# https://docs.ray.io/en/latest/placement-groups.html
+# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
 pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
 ray.get(pg_inference.ready())
 scheduling_inference = PlacementGroupSchedulingStrategy(
diff --git a/examples/offline_inference/rlhf_online_quant.py b/examples/offline_inference/rlhf_online_quant.py
new file mode 100644
index 0000000000000..2d98ad22c589e
--- /dev/null
+++ b/examples/offline_inference/rlhf_online_quant.py
@@ -0,0 +1,162 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
+
+The script separates training and inference workloads onto distinct GPUs
+so that Ray can manage process placement and inter-process communication.
+A Hugging Face Transformer model occupies GPU 0 for training, whereas a
+tensor-parallel vLLM inference engine occupies GPUs 1–2.
+
+The example performs the following steps:
+
+* Load the training model on GPU 0.
+* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
+  and Ray placement groups.
+* Generate text from a list of prompts using the inference engine.
+* Update the weights of the training model and broadcast the updated weights
+  to the inference engine by using a Ray collective RPC group. Note that
+  for demonstration purposes we simply zero out the weights.
+
+For a production-ready implementation that supports multiple training and
+inference replicas, see the OpenRLHF framework:
+https://github.com/OpenRLHF/OpenRLHF
+
+This example assumes a single-node cluster with three GPUs, but Ray
+supports multi-node clusters. vLLM expects that the GPUs are used only for
+vLLM workloads; residual GPU activity interferes with vLLM memory profiling
+and causes unexpected behavior.
+""" + +import json +import os + +import ray +import torch +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from rlhf_utils import stateless_init_process_group +from torchao.core.config import config_to_dict +from torchao.quantization import ( + Float8DynamicActivationFloat8WeightConfig, + PerRow, +) +from transformers import AutoModelForCausalLM + +from vllm import LLM, SamplingParams +from vllm.utils.network_utils import get_ip, get_open_port + + +class MyLLM(LLM): + """Configure the vLLM worker for Ray placement group execution.""" + + def __init__(self, *args, **kwargs): + # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray + # so that vLLM can manage its own device placement within the worker. + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + super().__init__(*args, **kwargs) + + +# Load the OPT-125M model onto GPU 0 for the training workload. +train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") +train_model.to("cuda:0") + +# Initialize Ray and set the visible devices. The vLLM engine will +# be placed on GPUs 1 and 2. +os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" +ray.init() + +# Create a placement group that reserves GPU 1–2 for the vLLM inference engine. +# Learn more about Ray placement groups: +# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html +pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) +ray.get(pg_inference.ready()) +scheduling_inference = PlacementGroupSchedulingStrategy( + placement_group=pg_inference, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=0, +) + +# Launch the vLLM inference engine. The `enforce_eager` flag reduces +# start-up latency. + +# generate torchao quantization config for RL rollout +# see https://github.com/vllm-project/vllm/pull/23014 for instructions to +# use serialized config files instead of passing around json string +config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + +json_str = json.dumps(config_to_dict(config)) + +llm = ray.remote( + num_cpus=0, + num_gpus=0, + scheduling_strategy=scheduling_inference, +)(MyLLM).remote( + model="facebook/opt-125m", + hf_overrides={"quantization_config_dict_json": json_str}, + enforce_eager=True, + worker_extension_cls="rlhf_utils.WorkerExtension", + tensor_parallel_size=2, + distributed_executor_backend="ray", +) + +# Generate text from the prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +sampling_params = SamplingParams(temperature=0) + +outputs = ray.get(llm.generate.remote(prompts, sampling_params)) + +print("-" * 50) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + +# Set up the communication channel between the training process and the +# inference engine. +master_address = get_ip() +master_port = get_open_port() + +handle = llm.collective_rpc.remote( + "init_weight_update_group", args=(master_address, master_port, 1, 3) +) + +model_update_group = stateless_init_process_group( + master_address, master_port, 0, 3, torch.device("cuda:0") +) +ray.get(handle) + +# Simulate a training step by zeroing out all model weights. +# In a real RLHF training loop the weights would be updated using the gradient +# from an RL objective such as PPO on a reward model. 
+for name, p in train_model.named_parameters():
+    p.data.zero_()
+
+# Synchronize the updated weights to the inference engine.
+for name, p in train_model.named_parameters():
+    dtype_name = str(p.dtype).split(".")[-1]
+    handle = llm.collective_rpc.remote(
+        "update_weight", args=(name, dtype_name, p.shape)
+    )
+    model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
+    ray.get(handle)
+
+# Verify that the inference weights have been updated.
+assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
+
+# Generate text with the updated model. The output is expected to be nonsense
+# because the weights are zero.
+outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
+print("-" * 50)
+for output in outputs_updated:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+    print("-" * 50)
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index c06ac550a94ae..b80026741781f 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -22,6 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     fastsafetensors_weights_iterator,
     filter_duplicate_safetensors_files,
     filter_files_not_needed_for_inference,
+    get_quant_config,
     maybe_download_from_modelscope,
     multi_thread_pt_weights_iterator,
     multi_thread_safetensors_weights_iterator,
@@ -273,42 +274,17 @@ class DefaultModelLoader(BaseModelLoader):
             )
 
     def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-        if model_config.quantization == "torchao" and torchao_version_at_least(
-            "0.14.0"
-        ):
-            self.load_config.safetensors_load_strategy = "torchao"
+        if model_config.quantization == "torchao":
+            quant_config = get_quant_config(model_config, self.load_config)
+            if (
+                hasattr(quant_config, "is_checkpoint_torchao_serialized")
+                and quant_config.is_checkpoint_torchao_serialized
+                and torchao_version_at_least("0.14.0")
+            ):
+                self.load_config.safetensors_load_strategy = "torchao"
+
         weights_to_load = {name for name, _ in model.named_parameters()}
-
-        # if we don't have `model.weight_metadata_and_attr_saved` defined and
-        # set to True, it means that this is either offline quantization case
-        # or the first run of online quantization
-        # see online_quantization.py for detailed notes
-        offline_quantization_or_first_run_of_online_quantization = not getattr(
-            model, "weight_metadata_and_attr_saved", False
-        )
-
-        if model_config.quantization is None:
-            # model is not quantized
-            loaded_weights = model.load_weights(
-                self.get_all_weights(model_config, model)
-            )
-        elif offline_quantization_or_first_run_of_online_quantization:
-            # case 1: offline quantized checkpoint
-            # case 2: Step I1 first run of weight loading with
-            #         online quantization
-            # see online_quantization.py for detailed notes
-            loaded_weights = model.load_weights(
-                self.get_all_weights(model_config, model)
-            )
-        else:
-            # to avoid circular dependency
-            from vllm.model_executor.model_loader.online_quantization import (
-                load_weights_and_online_quantize,
-            )
-
-            # subsequent runs of weight loading with online
-            # quantization
-            loaded_weights = load_weights_and_online_quantize(self, model, model_config)
+        loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
 
         self.counter_after_loading_weights = time.perf_counter()
         logger.info_once(
diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py
index 890dd7231a0e1..f330af85bbe8b 100644
--- a/vllm/model_executor/model_loader/online_quantization.py
+++ b/vllm/model_executor/model_loader/online_quantization.py
@@ -2,13 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import types
+from collections.abc import Iterable
 
 import torch
 from torch import nn
 
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
 from vllm.model_executor.model_loader.utils import process_weights_after_loading
 
 logger = init_logger(__name__)
@@ -56,6 +56,9 @@ logger = init_logger(__name__)
 # R4. quantize weights (by calling process_weights_after_loading),
 #     also set `process_weights_after_loading_already_called` to
 #     True to stop it from running again
+# R5. (workaround for cudagraph) restore the weight params to the original
+#     quantized weight params and use original_weight_param.copy_(updated_weight_param)
+#     so that the weight update works well with cudagraph
 # process_weights_after_loading (if called):
 #   this will be skipped since it's already ran in
 #   load_weights
@@ -69,14 +72,6 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
     if model_config.quantization != "torchao":
         return
 
-    if getattr(model, "process_weights_after_loading_already_called", False):
-        # In case `process_weights_after_loading` is called multiple times
-        # we'll skip it at later times
-        logger.warning(
-            "process_weights_after_loading already called for model %s", model
-        )
-        return
-
     from vllm.model_executor.model_loader.weight_utils import get_quant_config
 
     quant_config = get_quant_config(model_config, None)
@@ -137,6 +132,7 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
         else:
             model.recorded_weight_attr[name][key] = attr
 
     # mark the metadata and attributes saved so we don't run it again
+    model._model_config = model_config
     model.weight_metadata_and_attr_saved = True
 
 
@@ -148,77 +144,132 @@ def _bond_method_to_cls(func, obj):
     return types.MethodType(func, obj)
 
 
-def load_weights_and_online_quantize(
-    model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig
-) -> set[str]:
+def support_quantized_model_reload_from_hp_weights(original_load_weights):
+    """Decorator for `AutoWeightsLoader.load_weights` that supports reloading
+    high precision (bfloat16/float16/float32) weights into an already quantized
+    model. This restores the weights to their original high precision form and
+    then online quantizes them again.
+    """
     # online quantization, right now only enabled for
     # torchao
-    # R1, R2, R3, R4 in the Notes
+    # R1, R2, R3, R4, R5 in the Notes
 
-    # TODO: Add fp8 support
-    assert model_config.quantization == "torchao", (
-        "online quantization is only enabled for torchao currently"
-    )
-    # TODO: use create_weights to restore the weights to original state
+    def patched_model_load_weights(
+        auto_weight_loader, weights: Iterable[tuple[str, torch.Tensor]], *, mapper=None
+    ) -> set[str]:
+        model = auto_weight_loader.module
+        offline_quantization_or_first_run_of_online_quantization = not getattr(
+            model, "weight_metadata_and_attr_saved", False
+        )
 
-    # Step R1: First restore the quantized weights to original bfloat16
-    # weights, with original metadata (shape, dtype, device)
-    # and attributes, so that bfloat16 weights can be loaded properly
-    existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
-    named_modules = dict(model.named_modules(remove_duplicate=False))
-    model_device = None
+        # if we don't have `model.weight_metadata_and_attr_saved` defined and
+        # set to True, it means that this is either offline quantization case
+        # or the first run of online quantization
+        # see Notes in this file for more details
+        if offline_quantization_or_first_run_of_online_quantization:
+            # case 1: offline quantized checkpoint
+            # case 2: Step I1 first run of weight loading with
+            #         online quantization
+            return original_load_weights(auto_weight_loader, weights, mapper=mapper)
 
-    # Step R2: recover the parameter to the state before first loading
-    for name, d in model.original_weights_rebuild_keys.items():
-        _shape = d["shape"]
-        _dtype = d["dtype"]
-        _device = d["device"]
+        model_config = model._model_config
+
+        # TODO: Add fp8 support
+        assert model_config.quantization == "torchao", (
+            "online quantization is only enabled for torchao currently"
+        )
+        # TODO: use create_weights to restore the weights to original state
+
+        # Step R1: First restore the quantized weights to original bfloat16
+        # weights, with original metadata (shape, dtype, device)
+        # and attributes, so that bfloat16 weights can be loaded properly
+        # TODO: maybe set remove_duplicate to True?
+        original_quantized_weight_dict = dict(
+            model.named_parameters(remove_duplicate=False)
+        )
+        named_modules = dict(model.named_modules(remove_duplicate=False))
+        model_device = None
+
+        for name, d in model.original_weights_rebuild_keys.items():
+            _shape = d["shape"]
+            _dtype = d["dtype"]
+            _device = d["device"]
+            if model_device is not None:
+                assert model_device == _device, (
+                    "Expecting all weights "
+                    "to be in the same device for now, got both: "
+                    f"{model_device} and {_device}"
+                )
+            else:
+                model_device = _device
+
+            if name in original_quantized_weight_dict:
+                module_name, weight_name = name.rsplit(".", 1)
+                module = named_modules[module_name]
+                setattr(
+                    module,
+                    weight_name,
+                    torch.nn.Parameter(
+                        torch.empty(_shape, dtype=_dtype, device=_device),
+                        requires_grad=False,
+                    ),
+                )
+
+        # Step R2: recover the weight attributes to the state before first loading
+        # recorded_weight_attr is
+        # {"weight_name": {"weight_attr_key": attr}}
+        # e.g.
+        # {
+        #     {
+        #         "layer.0.weight": {
+        #             "weight_loader": weight_loader_function_object,
+        #             "input_dim": 0, ...
+        #         },
+        #         "layer.1.weight": ...,
+        #     }
+        # }
+        for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
+            for attr_name, attr in weight_attr_dict.items():
+                module_name, weight_name = full_weight_name.rsplit(".", 1)
+                module = named_modules[module_name]
+                weight = getattr(module, weight_name)
+                if not hasattr(weight, attr_name):
+                    setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
+
+        # Step R3: reload bfloat16 / high precision weights
+        updated_params = original_load_weights(
+            auto_weight_loader, weights, mapper=mapper
+        )
+
+        # Step R4: online quantize the weights
+        # manually process weights after loading
+        model.process_weights_after_loading_already_called = False
         if model_device is not None:
-            assert model_device == _device, (
-                "Expecting all weights "
-                "to be in the same device for now, got both: "
-                f"{model_device} and {_device}"
-            )
+            process_weights_after_loading(model, model_config, model_device)
         else:
-            model_device = _device
-
-        if name in existing_param_names:
-            module_name, weight_name = name.rsplit(".", 1)
-            module = named_modules[module_name]
-            setattr(
-                module,
-                weight_name,
-                torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)),
+            logger.warning_once(
+                "model_device is None, skip calling process_weights_after_loading"
             )
 
-    # recorded_weight_attr is
-    # {"weight_name": {"weight_attr_key": attr}}
-    # e.g.
-    # {
-    #     {
-    #         "layer.0.weight": {
-    #             "weight_loader": weight_loader_function_object,
-    #             "input_dim": 0, ...
-    #         },
-    #         "layer.1.weight": ...,
-    #     }
-    # }
-    for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
-        for attr_name, attr in weight_attr_dict.items():
-            module_name, weight_name = full_weight_name.rsplit(".", 1)
-            module = named_modules[module_name]
-            weight = getattr(module, weight_name)
-            if not hasattr(weight, attr_name):
-                setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
+        # Step R5 (workaround for cudagraph): restore the original quantized weights
+        # and copy_ the current weights into the original weights
+        updated_quantized_weights = dict(model.named_parameters(remove_duplicate=False))
+        for name in model.original_weights_rebuild_keys:
+            if name in original_quantized_weight_dict:
+                original_quantized_weight = original_quantized_weight_dict[name]
+                updated_quantized_weight = updated_quantized_weights[name]
 
-    # Step I1: reload bfloat16 / high precision weights
-    loaded_weights = model.load_weights(
-        model_loader.get_all_weights(model_config, model)
-    )
+                module_name, weight_name = name.rsplit(".", 1)
+                module = named_modules[module_name]
+                setattr(module, weight_name, original_quantized_weight)
+                with torch.no_grad():
+                    original_quantized_weight.copy_(updated_quantized_weight)
 
-    # Step I2: online quantize the weights
-    # manually process weights after loading
-    model.process_weights_after_loading_already_called = False
-    process_weights_after_loading(model, model_config, model_device)
-    model.process_weights_after_loading_already_called = True
-    return loaded_weights
+        del original_quantized_weight_dict
+        del named_modules
+        del updated_quantized_weight
+
+        model.process_weights_after_loading_already_called = True
+        return updated_params
+
+    return patched_model_load_weights
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index ba708a098c0da..e74434e9d12cb 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -88,6 +88,14 @@ def initialize_model(
 def process_weights_after_loading(
     model: nn.Module, model_config: ModelConfig, target_device: torch.device
 ) -> None:
+    if getattr(model, "process_weights_after_loading_already_called", False):
+        # In case `process_weights_after_loading` is called multiple times,
+        # we'll skip it on later calls.
+        logger.debug_once(
+            "process_weights_after_loading already called for model %s", model
+        )
+        return
+
     # to avoid circular dependency
     from vllm.model_executor.model_loader.online_quantization import (
         maybe_save_metadata_and_attributes_for_weight_reloading,
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index ca5af358e2eed..ccefd7e66697f 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -21,6 +21,9 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
 )
+from vllm.model_executor.model_loader.online_quantization import (
+    support_quantized_model_reload_from_hp_weights,
+)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import supports_any_eagle
 from vllm.multimodal import NestedTensors
@@ -316,6 +319,7 @@ class AutoWeightsLoader:
                 )
             raise ValueError(msg)
 
+    @support_quantized_model_reload_from_hp_weights
     def load_weights(
         self,
         weights: Iterable[tuple[str, torch.Tensor]],
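
The decorated load path above restores high precision parameters, reloads them, re-quantizes, and then copies the result back into the original quantized parameter objects so their identity survives reloads. The standalone sketch below mirrors that pattern for a single linear layer; it is illustrative only and not part of the patch: `fake_quantize` is a stand-in for torchao quantization (here just a cast to float16), and `rebuild_key` plays the role of `original_weights_rebuild_keys`.

import torch
from torch import nn


def fake_quantize(weight: torch.Tensor) -> nn.Parameter:
    # Stand-in for process_weights_after_loading / torchao quantization.
    return nn.Parameter(weight.to(torch.float16), requires_grad=False)


layer = nn.Linear(4, 4, bias=False)

# First load (Steps I1/I2): record how to rebuild the high precision weight,
# then quantize in place.
rebuild_key = {
    "shape": layer.weight.shape,
    "dtype": layer.weight.dtype,
    "device": layer.weight.device,
}
layer.weight = fake_quantize(layer.weight)
original_quantized_weight = layer.weight  # object identity must survive reloads

# Reload (Steps R1-R5): restore an empty high precision parameter, load new
# weights into it, re-quantize, then copy the result back into the original
# parameter object so a CUDA graph that captured it stays valid.
layer.weight = nn.Parameter(
    torch.empty(
        rebuild_key["shape"], dtype=rebuild_key["dtype"], device=rebuild_key["device"]
    ),
    requires_grad=False,
)
with torch.no_grad():
    layer.weight.uniform_(-0.1, 0.1)  # R3: stand-in for reloading new weights
requantized = fake_quantize(layer.weight)  # R4: quantize again
layer.weight = original_quantized_weight  # R5: reinstall the original param
with torch.no_grad():
    original_quantized_weight.copy_(requantized)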