From 016b8d1b7f1d51a46306ebf0d62f0499bce2a36d Mon Sep 17 00:00:00 2001
From: Ruheena Suhani Shaik
Date: Tue, 15 Jul 2025 08:56:08 +0530
Subject: [PATCH] Enabled BnB NF4 inference on Gaudi (#20172)

Signed-off-by: Ruheena Suhani Shaik
---
 .../layers/quantization/bitsandbytes.py  | 12 ++++++------
 .../model_loader/bitsandbytes_loader.py  | 14 ++++++++++++--
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index 92a46ad65cb8..a96f3ee5c301 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 
@@ -390,12 +391,11 @@ def _apply_bnb_4bit_fake(
 
 
 try:
-    direct_register_custom_op(
-        op_name="apply_bnb_4bit",
-        op_func=_apply_bnb_4bit,
-        mutates_args=["out"],
-        fake_impl=_apply_bnb_4bit_fake,
-    )
+    direct_register_custom_op(op_name="apply_bnb_4bit",
+                              op_func=_apply_bnb_4bit,
+                              mutates_args=["out"],
+                              fake_impl=_apply_bnb_4bit_fake,
+                              dispatch_key=current_platform.dispatch_key)
     apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
 
 except AttributeError as error:
diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index d22b1e7b67d4..907bc3c13619 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -199,6 +199,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         if self.pre_quant:
             if self.load_8bit:
+                if current_platform.is_hpu():
+                    raise ValueError(
+                        "currently hpu supports 4bit quantization only")
+
                 return self._quantized_8bit_generator(
                     hf_weights_files, use_safetensors,
                     quant_state_dict), quant_state_dict
@@ -302,6 +306,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                         in temp_state_dict):
                     quant_state = _parse_quant_state(mapped_weight_name,
                                                      temp_state_dict)
+                    if current_platform.is_hpu():
+                        assert quant_state.quant_type == "nf4", (
+                            "currently hpu supports nf4 quant_type only")
+
                     quant_state_dict[mapped_weight_name] = quant_state
                     yield org_weight_name, weight_tensor
                 else:
@@ -372,10 +380,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                                                   ...]
 
             # bitsandbytes requires data in GPU
-            if weight_sub_tensor.is_cuda:
+            if (weight_sub_tensor.is_cuda
+                    or weight_sub_tensor.device.type == "hpu"):
                 loaded_weight = weight_sub_tensor
             else:
-                loaded_weight = weight_sub_tensor.cuda()
+                loaded_weight = weight_sub_tensor.to(
+                    device=current_platform.device_type)
 
             # remove the following after the issue is fixed:
             # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
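
For reference (not part of the patch), a minimal offline-inference sketch of how the NF4 path enabled here might be exercised on Gaudi. The model name is only a placeholder, and on recent vLLM builds load_format="bitsandbytes" may already be implied by quantization="bitsandbytes"; this is a hedged usage sketch, not a definitive recipe.

# Illustrative only: in-flight NF4 (4-bit) quantization via bitsandbytes.
# On HPU, the loader above keeps the weights on the Gaudi device and the
# 4-bit custom op is registered under the platform's dispatch key.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
    quantization="bitsandbytes",               # routes through BitsAndBytesModelLoader
    load_format="bitsandbytes",                # needed on older vLLM versions
    dtype="bfloat16",
)

outputs = llm.generate(["Hello, Gaudi!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)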