Skip to content

Commit

Permalink
[imitation] Prevent overflow in total_frames.
Browse files Browse the repository at this point in the history
  • Loading branch information
vladfi1 committed Jan 20, 2025
1 parent 36dca31 commit 598b331
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions slippi_ai/train_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,8 @@ def train(config: Config):
train_stats, _ = train_manager.step()
logging.info('loss initial: %f', _get_loss(train_stats))

step = tf.Variable(0, trainable=False, name="step")
with tf.device('/cpu:0'):
step = tf.Variable(0, trainable=False, name="step", dtype=tf.int64)

# saving and restoring
tf_state = dict(
Expand Down Expand Up @@ -400,7 +401,7 @@ def maybe_log(train_stats: dict):
print()

def maybe_eval():
total_steps = step.numpy()
total_steps = int(step.numpy())
if total_steps % runtime.eval_every_n != 0:
return

Expand Down

0 comments on commit 598b331

Please sign in to comment.