Skip to content

Commit

Permalink
Add min changes to sample.py allowing for cpu inf
Browse files Browse the repository at this point in the history
  • Loading branch information
gkielian committed Nov 7, 2023
1 parent c8f558a commit dbbcfc1
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,30 @@
import torch
import tiktoken
from model import GPTConfig, GPT
import argparse

def parseargs():
parser = argparse.ArgumentParser(description='')
parser.add_argument("-d",
"--device",
type=str, help="device to run inference, e.g. 'cpu' or 'cuda' or 'cuda:0', 'cuda:1', etc...")
parser.add_argument("-o",
"--out_dir",
type=str, help="directory to load checkpoint from")

return parser.parse_args()

# -----------------------------------------------------------------------------
args = parseargs()
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out' # ignored if init_from is not 'resume'
out_dir = args.out_dir # ignored if init_from is not 'resume'
start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 10 # number of samples to draw
max_new_tokens = 500 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
device = args.device
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
Expand Down

0 comments on commit dbbcfc1

Please sign in to comment.