mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-25 12:34:30 +08:00
[Bugfix][Model] Fix baichuan model loader for tp (#18597)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
parent
fbb13a2c15
commit
4ce64e2df4
@ -42,7 +42,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, row_parallel_weight_loader)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@ -384,7 +385,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
lora_config = vllm_config.lora_config
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.quant_config = quant_config
|
||||
self.model = BaiChuanModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
@ -438,8 +439,10 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
is_baichuan2 = self.config.vocab_size == 125696
|
||||
if is_baichuan2:
|
||||
loaded_weight = torch.nn.functional.normalize(loaded_weight)
|
||||
|
||||
default_weight_loader(param, loaded_weight)
|
||||
if self.tp_size > 1:
|
||||
row_parallel_weight_loader(param, loaded_weight)
|
||||
else:
|
||||
default_weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user