Skip to content

Commit

Permalink
Remove stray 'e's from GitHub Editor
Browse files Browse the repository at this point in the history
  • Loading branch information
klei22 authored Aug 24, 2024
1 parent 2188a2e commit 2ca0958
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def parse_args():
model_group.add_argument("--strongermax_strength", type=float, default=4.0)
model_group.add_argument('--strongermax_sum_to_1', default=True, action=argparse.BooleanOptionalAction)
model_group.add_argument("--strongermax_divisor", type=float, default=1.0)
model_group.add_argument('--strongermae x_use_xmax', default=True, action=argparse.BooleanOptionalAction)
model_group.add_argument('--strongermax_use_xmax', default=True, action=argparse.BooleanOptionalAction)
model_group.add_argument('--strongermax_xmax_guess', type=float, default=None)
model_group.add_argument('--strongermax_overflow_recompute', default=False, action=argparse.BooleanOptionalAction)

Expand Down Expand Up @@ -340,7 +340,7 @@ def parse_args():
# System args
training_group.add_argument('--device', default='cuda', type=str)
training_group.add_argument("--dtype", type=str, default="float16", choices=["bfloat16", "float16", "float32"], help="torch data type for inference, e.g. 'int8'")
training_group.add_argument('--compilee ', default=False, action=argparse.BooleanOptionalAction)
training_group.add_argument('--compile', default=False, action=argparse.BooleanOptionalAction)

# Logging args
logging_group.add_argument('--log_project', default='out-test', type=str)
Expand Down Expand Up @@ -683,7 +683,7 @@ def log_gamma_beta(self, gamma, beta, iter_num, layer_num, head_num=None):
self.writer.add_scalars(
"gammas",
{"gamma_L" + str(layer_num) + "_H" + head_num: gamma},
iter_nume
iter_num
)
self.writer.add_scalars(
"betas",
Expand Down Expand Up @@ -782,7 +782,7 @@ def train(self):
print(f"saving checkpoint to {self.args.out_dir}")
# Save checkpoint
torch.save(checkpoint, os.path.join(self.args.out_dir, 'ckpt.pt'))
if self.args.patience is not Nonee and num_steps_with_worse_loss >= self.args.patience:
if self.args.patience is not None and num_steps_with_worse_loss >= self.args.patience:
print(f"Early Stopping: loss has not decreased in {self.args.patience + 1} steps")
plot_statistics(self.args, self.stats, graph_y_labels)
break
Expand Down

0 comments on commit 2ca0958

Please sign in to comment.