[Hardware][Neuron] Simplify model load for transformers-neuronx library (#9380)

Shashwat Srijan 2024-10-17 15:39:39 -07:00 committed by GitHub
parent d615b5c9f8
commit bb76538bbd


@@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
-import transformers
 from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
@@ -108,39 +107,11 @@ class NeuronCasualLM(nn.Module):
         neuronx_module = importlib.import_module(neuronx_module_path)
         neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
 
-        split_model_dir = f"{model_name_or_path}-split"
-        if _is_pretrained_neuron_checkpoint(model_name_or_path):
-            split_model_dir = model_name_or_path
-        elif not os.path.exists(f"{model_name_or_path}-split"):
-            hf_model_cls = getattr(transformers, hf_model_cls_name)
-            from transformers_neuronx.module import save_pretrained_split
-
-            hf_model = hf_model_cls.from_pretrained(model_name_or_path,
-                                                    low_cpu_mem_usage=True)
-            save_pretrained_split(hf_model, f"{model_name_or_path}-split")
-
-        self.model = neuronx_model_cls.from_pretrained(split_model_dir,
+        self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
                                                        **kwargs)
         self.model.to_neuron()
 
 
-def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
-    # Checking if the neuron checkpoint is saved in the old format.
-    if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")):
-        return True
-    # Checking if the neuron checkpoint is saved in the new format.
-    pretrained_split_files = ["config.json", "generation_config.json"]
-    pretrained_split_format = ".safetensors"
-    for file in pretrained_split_files:
-        file_path = os.path.join(model_name_or_path, file)
-        if not os.path.isfile(file_path):
-            return False
-    for file in os.listdir(model_name_or_path):
-        if file.endswith(pretrained_split_format):
-            return True
-    return False
-
-
 def _get_model_architecture(config: PretrainedConfig) -> str:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
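
With the split-checkpoint plumbing removed, `from_pretrained` is handed the Hugging Face checkpoint path directly. Below is a minimal sketch of the same load path driven through transformers-neuronx on its own, assuming an ordinary Llama checkpoint; the model id, `tp_degree`, and `amp` values are illustrative and not taken from this commit:

# Minimal sketch, assuming transformers-neuronx is installed and the
# checkpoint is a standard Hugging Face Llama model.
from transformers_neuronx.llama.model import LlamaForSampling

# from_pretrained consumes the HF checkpoint directly; no
# save_pretrained_split() pass or "<model>-split" directory is needed.
model = LlamaForSampling.from_pretrained(
    "openlm-research/open_llama_3b",  # hypothetical example checkpoint
    batch_size=1,
    tp_degree=2,   # shard weights across two NeuronCores
    amp="f16",     # run inference in fp16
)
model.to_neuron()  # compile and load weights onto the Neuron device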