keep improving

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2025-08-26 18:09:40 +08:00
parent 348e741a11
commit 21b2dfe172

View File

@ -8,7 +8,7 @@ import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs, set_global_args
from model import Transformer, ModelArgs
def sample(logits, temperature: float = 1.0):
@ -110,8 +110,7 @@ def main(
torch.set_num_threads(8)
torch.manual_seed(965)
with open(config) as f:
config_dict = json.load(f)
args = ModelArgs(**config_dict)
args = ModelArgs(**json.load(f))
print(args)
with torch.device("cuda"):
model = Transformer(args)