[Bugfix] Fix broken deepseek fp8 TP weights loading (#24367)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
commit 00a4e56d8d
parent 0eadaeff7e
@@ -262,7 +262,7 @@ class LinearBase(CustomOp):
         self.tp_size = (get_tensor_model_parallel_world_size()
                         if not disable_tp else 1)

-    def __post_init__(self):
+    def update_param_tp_status(self):
         for param in self.parameters():
             if isinstance(param, BasevLLMParameter):
                 param.tp_rank = self.tp_rank
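
Editor's note on the first hunk: LinearBase is an nn.Module subclass, so a method named __post_init__ is never invoked automatically; that hook only exists for dataclasses. The TP-metadata propagation it contained therefore never ran on its own, which is consistent with the reported breakage in DeepSeek fp8 TP weight loading. Renaming it to update_param_tp_status and calling it explicitly (see the later hunks) makes the propagation deterministic. A minimal, hypothetical sketch (Demo is not vLLM code) of the silent no-op:

    # Hypothetical sketch, not vLLM code: a __post_init__ method on a plain
    # nn.Module subclass is never called automatically, so any TP metadata it
    # was supposed to propagate silently stays unset.
    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            self.tp_metadata_propagated = False

        def __post_init__(self):
            # Only dataclasses invoke this hook; nn.Module never does.
            self.tp_metadata_propagated = True

    m = Demo()
    print(m.tp_metadata_propagated)  # False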
@@ -459,6 +459,7 @@ class ColumnParallelLinear(LinearBase):
             })
         else:
             self.register_parameter("bias", None)
+        self.update_param_tp_status()

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
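
Why the missing propagation matters for weight loading: the per-parameter tp_rank stamped by update_param_tp_status is what lets a weight loader pick the correct shard out of the full checkpoint tensor. A hypothetical, self-contained illustration of that slicing (load_column_shard is made up for this note and is not the actual vLLM weight_loader):

    # Hypothetical illustration: selecting a column-parallel shard from a full
    # checkpoint tensor requires the parameter's TP rank; if tp_rank is never
    # set, every rank loads the same slice or trips shape checks.
    import torch

    def load_column_shard(full_weight: torch.Tensor, tp_rank: int,
                          tp_size: int) -> torch.Tensor:
        shard_size = full_weight.shape[0] // tp_size
        return full_weight.narrow(0, tp_rank * shard_size, shard_size)

    full = torch.arange(8.0).reshape(8, 1)
    print(load_column_shard(full, tp_rank=1, tp_size=4))  # rows 2 and 3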
@@ -1250,6 +1251,7 @@ class RowParallelLinear(LinearBase):
             })
         else:
             self.register_parameter("bias", None)
+        self.update_param_tp_status()

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         input_dim = getattr(param, "input_dim", None)
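
The second and third hunks apply the same one-line fix to ColumnParallelLinear and RowParallelLinear: once all parameters (including the optional bias) are registered, update_param_tp_status is called explicitly at the end of __init__. A simplified sketch of the pattern, with made-up names (ShardedParameter and TinyLinear stand in for BasevLLMParameter and the real linear layers):

    # Simplified sketch of the pattern the diff introduces; names are
    # illustrative, not the actual vLLM classes.
    import torch
    import torch.nn as nn

    class ShardedParameter(nn.Parameter):
        """Stand-in for BasevLLMParameter: carries TP metadata for loading."""
        tp_rank: int = 0

    class TinyLinear(nn.Module):
        def __init__(self, tp_rank: int):
            super().__init__()
            self.tp_rank = tp_rank
            self.weight = ShardedParameter(torch.empty(4, 4))
            # The fix: call the helper explicitly at the end of __init__
            # instead of relying on a __post_init__ hook that never fires.
            self.update_param_tp_status()

        def update_param_tp_status(self):
            for param in self.parameters():
                if isinstance(param, ShardedParameter):
                    param.tp_rank = self.tp_rank

    layer = TinyLinear(tp_rank=2)
    print(layer.weight.tp_rank)  # 2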
@@ -270,7 +270,8 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.weight_block_size = None

         if self.block_quant:
-            tp_size = get_tensor_model_parallel_world_size()
+            tp_size = getattr(layer, "tp_size",
+                              get_tensor_model_parallel_world_size())
             assert self.quant_config.weight_block_size is not None
             layer.weight_block_size = self.quant_config.weight_block_size
             block_n, block_k = (
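
The final hunk is the fp8 side of the fix: Fp8LinearMethod previously sized block-quantized weights using the global tensor-parallel world size unconditionally. With the linear layers now carrying an accurate tp_size (1 when a layer is built with disable_tp, per the first hunk's context, which is presumably the DeepSeek fp8 path the commit title refers to), reading getattr(layer, "tp_size", ...) lets replicated layers be treated as unsharded while layers without the attribute keep the old behaviour. A stand-alone sketch of that fallback (effective_tp_size and the SimpleNamespace layers are illustrative only):

    # Illustrative sketch of the fallback, not Fp8LinearMethod itself:
    # prefer the TP size recorded on the layer, fall back to the world size.
    from types import SimpleNamespace

    def effective_tp_size(layer, world_size: int) -> int:
        # Mirrors: getattr(layer, "tp_size", get_tensor_model_parallel_world_size())
        return getattr(layer, "tp_size", world_size)

    replicated = SimpleNamespace(tp_size=1)  # e.g. a layer built with disable_tp
    legacy = SimpleNamespace()               # a layer without the attribute
    print(effective_tp_size(replicated, world_size=8))  # 1 -> no sharding math
    print(effective_tp_size(legacy, world_size=8))      # 8 -> old behaviour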