Enabled BnB NF4 inference on Gaudi (#20172)
Signed-off-by: Ruheena Suhani Shaik <rsshaik@habana.ai>
This commit is contained in:
parent 80305c1b24
commit 016b8d1b7f
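With this change, a bitsandbytes NF4 pre-quantized checkpoint should load on Gaudi (HPU) the same way it does on CUDA devices. A minimal usage sketch, not part of the commit: the model name and sampling settings are only examples, and depending on the vLLM version `load_format` may be inferred automatically.

```python
# Hypothetical example: serve a BnB NF4 pre-quantized checkpoint on Gaudi.
from vllm import LLM, SamplingParams

llm = LLM(
    model="unsloth/llama-3-8b-bnb-4bit",  # example NF4 checkpoint, not prescribed by the commit
    quantization="bitsandbytes",          # route weights through the BitsAndBytes loader
    load_format="bitsandbytes",           # may be optional on newer vLLM versions
)

out = llm.generate(["Hello from Gaudi!"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```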
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 
@@ -390,12 +391,11 @@ def _apply_bnb_4bit_fake(
 
 
 try:
-    direct_register_custom_op(
-        op_name="apply_bnb_4bit",
-        op_func=_apply_bnb_4bit,
-        mutates_args=["out"],
-        fake_impl=_apply_bnb_4bit_fake,
-    )
+    direct_register_custom_op(op_name="apply_bnb_4bit",
+                              op_func=_apply_bnb_4bit,
+                              mutates_args=["out"],
+                              fake_impl=_apply_bnb_4bit_fake,
+                              dispatch_key=current_platform.dispatch_key)
     apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
 
 except AttributeError as error:
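For context, `direct_register_custom_op` now receives `dispatch_key=current_platform.dispatch_key`, so the `apply_bnb_4bit` op is registered under the active backend's dispatch key ("CUDA" on NVIDIA, "HPU" on Gaudi) rather than a CUDA-only default. A minimal sketch of the same pattern with a toy op; the op name and kernel body are illustrative, not from vLLM.

```python
import torch
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op

def _double_inplace(x: torch.Tensor, out: torch.Tensor) -> None:
    # toy kernel: writes 2 * x into the caller-provided output buffer
    out.copy_(x * 2)

def _double_inplace_fake(x: torch.Tensor, out: torch.Tensor) -> None:
    # fake (meta) impl used for tracing/compilation; no real compute
    return

direct_register_custom_op(op_name="double_inplace_demo",   # illustrative name
                          op_func=_double_inplace,
                          mutates_args=["out"],             # "out" is written in place
                          fake_impl=_double_inplace_fake,
                          dispatch_key=current_platform.dispatch_key)

double_inplace_demo = torch.ops.vllm.double_inplace_demo
```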
@@ -199,6 +199,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         if self.pre_quant:
             if self.load_8bit:
+                if current_platform.is_hpu():
+                    raise ValueError(
+                        "currently hpu supports 4bit quantization only")
                 return self._quantized_8bit_generator(
                     hf_weights_files, use_safetensors,
                     quant_state_dict), quant_state_dict
@@ -302,6 +306,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                         in temp_state_dict):
                     quant_state = _parse_quant_state(mapped_weight_name,
                                                      temp_state_dict)
+                    if current_platform.is_hpu():
+                        assert quant_state.quant_type == "nf4", (
+                            "currently hpu supports nf4 quant_type only")
                     quant_state_dict[mapped_weight_name] = quant_state
                     yield org_weight_name, weight_tensor
                 else:
@@ -372,10 +380,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     ...]
 
                 # bitsandbytes requires data in GPU
-                if weight_sub_tensor.is_cuda:
+                if (weight_sub_tensor.is_cuda
+                        or weight_sub_tensor.device.type == "hpu"):
                     loaded_weight = weight_sub_tensor
                 else:
-                    loaded_weight = weight_sub_tensor.cuda()
+                    loaded_weight = weight_sub_tensor.to(
+                        device=current_platform.device_type)
 
                 # remove the following after the issue is fixed:
                 # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
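The loader change above replaces the hard-coded `.cuda()` move with a device-agnostic one. A minimal sketch of the same pattern, assuming only that `current_platform.device_type` names the active accelerator (e.g. "cuda" or "hpu"); the tensor below is a placeholder, not a real weight shard.

```python
import torch
from vllm.platforms import current_platform

weight_sub_tensor = torch.empty(4096, 1024, dtype=torch.bfloat16)  # placeholder weight shard

# bitsandbytes kernels need the data on the accelerator, so move it only
# if it is not already there (CUDA or HPU), mirroring the hunk above.
if weight_sub_tensor.is_cuda or weight_sub_tensor.device.type == "hpu":
    loaded_weight = weight_sub_tensor
else:
    loaded_weight = weight_sub_tensor.to(device=current_platform.device_type)
```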