diff --git a/sample.py b/sample.py
index 78d8c4e88..8432ed91f 100644
--- a/sample.py
+++ b/sample.py
@@ -7,17 +7,30 @@
 import torch
 import tiktoken
 from model import GPTConfig, GPT
+import argparse
+
+def parseargs():
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument("-d",
+                        "--device",
+                        type=str, help="device to run inference, e.g. 'cpu' or 'cuda' or 'cuda:0', 'cuda:1', etc...")
+    parser.add_argument("-o",
+                        "--out_dir",
+                        type=str, help="directory to load checkpoint from")
+
+    return parser.parse_args()
 
 # -----------------------------------------------------------------------------
+args = parseargs()
 init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
-out_dir = 'out' # ignored if init_from is not 'resume'
+out_dir = args.out_dir # ignored if init_from is not 'resume'
 start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
 num_samples = 10 # number of samples to draw
 max_new_tokens = 500 # number of tokens generated in each sample
 temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
 top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
 seed = 1337
-device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
+device = args.device
 dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
 compile = False # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
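
Note on the new flags (not part of the patch above): neither argument defines a default, so running sample.py with no flags leaves device and out_dir set to None. A minimal sketch of how the defaults from the removed lines ('cuda' and 'out') could be carried over, assuming nothing else in sample.py changes; the description string and help wording here are illustrative, not taken from the patch:

    import argparse

    def parseargs():
        # the default= values are an assumed addition; they mirror the values the patch removes
        parser = argparse.ArgumentParser(description='Sample from a trained GPT checkpoint')
        parser.add_argument("-d", "--device", type=str, default='cuda',
                            help="device to run inference on, e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:1'")
        parser.add_argument("-o", "--out_dir", type=str, default='out',
                            help="directory to load the checkpoint from")
        return parser.parse_args()

Example invocations, assuming the defaults above:

    python sample.py                               # behaves like the old hard-coded 'cuda' / 'out'
    python sample.py -d cuda:1 -o out-shakespeare  # override device and checkpoint directory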