mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-03 11:35:18 +08:00
[Hardware][Neuron] Simplify model load for transformers-neuronx library (#9380)
This commit is contained in:
parent
d615b5c9f8
commit
bb76538bbd
@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
|
||||
@ -108,39 +107,11 @@ class NeuronCasualLM(nn.Module):
|
||||
neuronx_module = importlib.import_module(neuronx_module_path)
|
||||
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
|
||||
|
||||
split_model_dir = f"{model_name_or_path}-split"
|
||||
if _is_pretrained_neuron_checkpoint(model_name_or_path):
|
||||
split_model_dir = model_name_or_path
|
||||
elif not os.path.exists(f"{model_name_or_path}-split"):
|
||||
hf_model_cls = getattr(transformers, hf_model_cls_name)
|
||||
from transformers_neuronx.module import save_pretrained_split
|
||||
|
||||
hf_model = hf_model_cls.from_pretrained(model_name_or_path,
|
||||
low_cpu_mem_usage=True)
|
||||
save_pretrained_split(hf_model, f"{model_name_or_path}-split")
|
||||
|
||||
self.model = neuronx_model_cls.from_pretrained(split_model_dir,
|
||||
self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
|
||||
**kwargs)
|
||||
self.model.to_neuron()
|
||||
|
||||
|
||||
def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
    """Report whether a path looks like an already-split Neuron checkpoint.

    Two on-disk layouts are recognised:

    * old format: ``<path>/pytorch_model.bin`` exists and is a directory;
    * new format: ``config.json`` and ``generation_config.json`` are both
      regular files under ``<path>`` and at least one entry of ``<path>``
      has a name ending in ``.safetensors``.

    Args:
        model_name_or_path: Filesystem path to inspect.

    Returns:
        True when either layout is detected, False otherwise.
    """
    # Old format: the weights live in a directory named "pytorch_model.bin".
    legacy_weights = os.path.join(model_name_or_path, "pytorch_model.bin")
    if os.path.isdir(legacy_weights):
        return True

    # New format: both metadata files must be present before the directory
    # listing is consulted.
    required_names = ("config.json", "generation_config.json")
    has_all_metadata = all(
        os.path.isfile(os.path.join(model_name_or_path, name))
        for name in required_names)
    if not has_all_metadata:
        return False

    return any(
        entry.endswith(".safetensors")
        for entry in os.listdir(model_name_or_path))
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> str:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user