diff --git a/inference/generate.py b/inference/generate.py
index deb594e..7e9bffe 100644
--- a/inference/generate.py
+++ b/inference/generate.py
@@ -8,7 +8,7 @@
 import torch.distributed as dist
 from transformers import AutoTokenizer
 from safetensors.torch import load_model
-from model import Transformer, ModelArgs, set_global_args
+from model import Transformer, ModelArgs
 
 
 def sample(logits, temperature: float = 1.0):
@@ -110,8 +110,7 @@ def main(
     torch.set_num_threads(8)
     torch.manual_seed(965)
     with open(config) as f:
-        config_dict = json.load(f)
-        args = ModelArgs(**config_dict)
+        args = ModelArgs(**json.load(f))
     print(args)
     with torch.device("cuda"):
         model = Transformer(args)