From 5c6c54d67a4d7d08f1db8bcc80612d44595d1b4f Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Fri, 9 Aug 2024 17:23:46 -0400
Subject: [PATCH] [Bugfix] Fix `PerTensorScaleParameter` weight loading for
 fused models (#7376)

---
 vllm/model_executor/layers/linear.py | 23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 646839ff303ee..e574062e4636b 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -14,7 +14,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.parameter import (BasevLLMParameter,
-                                           PackedvLLMParameter)
+                                           PackedvLLMParameter,
+                                           PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 
 logger = init_logger(__name__)
@@ -573,11 +574,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                          param: BasevLLMParameter,
                          loaded_weight: torch.Tensor,
                          loaded_shard_id: Optional[int] = None):
-        param_data = param.data
         if loaded_shard_id is None:
-            if param.output_dim is None:
-                assert param_data.shape == loaded_weight.shape
-                param_data.copy_(loaded_weight)
+            if isinstance(param, PerTensorScaleParameter):
+                param.load_merged_column_weight(loaded_weight=loaded_weight,
+                                                shard_id=0)
+                return
+            elif type(param) is BasevLLMParameter:
+                param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return
@@ -720,11 +723,13 @@ class QKVParallelLinear(ColumnParallelLinear):
                          param: BasevLLMParameter,
                          loaded_weight: torch.Tensor,
                          loaded_shard_id: Optional[str] = None):
-        param_data = param.data
         if loaded_shard_id is None:  # special case for certain models
-            if param.output_dim is None:
-                assert param_data.shape == loaded_weight.shape
-                param_data.copy_(loaded_weight)
+            if isinstance(param, PerTensorScaleParameter):
+                param.load_merged_column_weight(loaded_weight=loaded_weight,
+                                                shard_id=0)
+                return
+            elif type(param) is BasevLLMParameter:
+                param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return
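
---

Reviewer note: a self-contained sketch of the dispatch this patch introduces.
The classes below are toy stand-ins with hypothetical internals, not vLLM's
real `BasevLLMParameter` or `PerTensorScaleParameter`; only the logic in
`load_unsharded` mirrors the patched `loaded_shard_id is None` branch.

    import torch

    class BasevLLMParameter:
        """Toy stand-in: the checkpoint tensor covers the whole buffer."""

        def __init__(self, data):
            self.data = data

        def load_merged_column_weight(self, loaded_weight, **kwargs):
            assert self.data.shape == loaded_weight.shape
            self.data.copy_(loaded_weight)

    class PerTensorScaleParameter(BasevLLMParameter):
        """Toy stand-in: one scalar scale per logical shard of a fused module."""

        def load_merged_column_weight(self, loaded_weight, shard_id=0, **kwargs):
            # A fused checkpoint carries a single scale and no shard id, so
            # write it into the slot for the given shard. Copying it over the
            # whole buffer (the old behavior) tripped the shape assert.
            self.data[shard_id] = loaded_weight

    def load_unsharded(param, loaded_weight):
        # Mirrors the fixed `loaded_shard_id is None` branch of weight_loader_v2.
        if isinstance(param, PerTensorScaleParameter):
            param.load_merged_column_weight(loaded_weight=loaded_weight,
                                            shard_id=0)
        elif type(param) is BasevLLMParameter:
            param.load_merged_column_weight(loaded_weight=loaded_weight)

    scales = PerTensorScaleParameter(torch.zeros(2))  # e.g. gate/up shards
    load_unsharded(scales, torch.tensor(0.5))         # unsharded checkpoint scale
    print(scales.data)                                # tensor([0.5000, 0.0000])

Note that `type(param) is BasevLLMParameter` is an exact type check, so
subclasses with their own shard-aware loading (such as
`PerTensorScaleParameter`) never fall through to the plain full-tensor copy.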