Skip to content

Commit

Permalink
Update sample.py with eval parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
gkielian committed Sep 5, 2024
1 parent 37ca368 commit 457ce68
Showing 1 changed file with 48 additions and 3 deletions.
51 changes: 48 additions & 3 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def parse_args():
parser.add_argument('--token_boundary', type=str, default=None, help="optional separator between emitted tokens")
parser.add_argument('--print_model_info', default=True, action=argparse.BooleanOptionalAction, help="print info about model before infernece")

parser.add_argument("--eval_only", action=argparse.BooleanOptionalAction, help="Enable evaluation only mode to calculate and print validation loss")
parser.add_argument("--eval_iters", type=int, default=250, help="iterations for evaluation")
parser.add_argument("--eval_dataset", type=str, default=None, help="dataset for evaluation")


return parser.parse_args()


Expand Down Expand Up @@ -104,6 +109,32 @@ def save_quantized_data(state_dict, out_file):
with open(f"{out_file}.pkl", 'wb') as f:
pickle.dump(to_save, f)

def load_validation_data(block_size, eval_dataset):
    """Memory-map the validation split for *eval_dataset*.

    Parameters
    ----------
    block_size : int
        Accepted for signature compatibility with the caller; the whole
        split is mapped regardless, so this value is currently unused.
    eval_dataset : str
        Dataset directory name under ``data/`` containing ``val.bin``.

    Returns
    -------
    numpy.memmap
        Read-only uint16 token array (same on-disk format as train data).

    Raises
    ------
    FileNotFoundError
        If ``data/<eval_dataset>/val.bin`` does not exist.
    """
    val_path = os.path.join('data', eval_dataset, 'val.bin')
    # Raise explicitly instead of `assert`: asserts are stripped under -O,
    # which would turn a missing file into an obscure memmap error.
    if not os.path.exists(val_path):
        raise FileNotFoundError(f"Validation data file {val_path} not found.")
    # memmap avoids loading the whole split into RAM.
    return np.memmap(val_path, dtype=np.uint16, mode='r')

def get_batch(data, block_size, device):
    """Sample one random contiguous (input, target) pair from *data*.

    Targets are the inputs shifted one position to the right, matching
    the standard next-token-prediction setup.
    """
    # One random start offset; keeping a 1-element tensor preserves the
    # leading batch dimension produced by torch.stack below.
    starts = torch.randint(len(data) - block_size, (1,))
    inputs = torch.stack(
        [torch.from_numpy(data[s:s + block_size].astype(np.int64)) for s in starts]
    )
    targets = torch.stack(
        [torch.from_numpy(data[s + 1:s + 1 + block_size].astype(np.int64)) for s in starts]
    )
    return inputs.to(device), targets.to(device)

def calculate_validation_loss(model, val_data, block_size, eval_iters, device, dtype):
    """Average next-token loss of *model* over *eval_iters* random batches.

    Switches the model to eval mode, samples batches from *val_data* via
    ``get_batch``, runs each forward pass under autocast with *dtype*,
    and returns the mean per-batch loss as a float.
    """
    model.eval()
    batch_losses = []
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        for _ in range(eval_iters):
            inputs, targets = get_batch(val_data, block_size, device)
            # Mixed-precision forward pass on the target device.
            with torch.amp.autocast(device_type=device, dtype=dtype):
                _, batch_loss = model(inputs, targets)
            batch_losses.append(batch_loss.item())
    return np.mean(batch_losses)

def main():
args = parse_args()

Expand Down Expand Up @@ -147,18 +178,32 @@ def main():
model.eval()
model.to(args.device)

# Inference with different block size (note: for this one cannot use abs pos embeddings)
if args.block_size:
model.update_block_size(args.block_size)

# Print the model summary
if args.print_model_info:
print_summary(model)
print_model_blocks(model)
print_module_structure(model)

if args.eval_only:
print("Running in eval_only mode...")
# Load the validation dataset
print(model.config.block_size)
val_data = load_validation_data(model.config.block_size,
args.eval_dataset)
# Calculate validation loss
val_loss = calculate_validation_loss(model, val_data,
model.config.block_size,
args.eval_iters, args.device, ptdtype)
print(f"Validation Loss: {val_loss:.4f}")
return

if args.compile:
model = torch.compile(model)

# Inference with different block size (note: for this one cannot use abs pos embeddings)
if args.block_size:
model.update_block_size(args.block_size)

# Inference with different number of angles
if args.sym_rot_num_angles:
Expand Down

0 comments on commit 457ce68

Please sign in to comment.