diff --git a/official/recommendation/constants.py b/official/recommendation/constants.py index 33230e75e6c..d4faf5fd74a 100644 --- a/official/recommendation/constants.py +++ b/official/recommendation/constants.py @@ -58,6 +58,7 @@ def __init__(self, data_dir, cache_id=None): CYCLES_TO_BUFFER = 3 # The number of train cycles worth of data to "run ahead" # of the main training loop. +READY_FILE_TEMP = "ready.json.temp" READY_FILE = "ready.json" TRAIN_RECORD_TEMPLATE = "train_{}.tfrecords" diff --git a/official/recommendation/data_async_generation.py b/official/recommendation/data_async_generation.py index 30432bb5268..16ec7a792e1 100644 --- a/official/recommendation/data_async_generation.py +++ b/official/recommendation/data_async_generation.py @@ -282,11 +282,17 @@ def _construct_training_records( raise ValueError("Error detected: point counts do not match: {} vs. {}" .format(num_pts, written_pts)) - with tf.gfile.Open(os.path.join(record_dir, rconst.READY_FILE), "w") as f: + # We write to a temp file then atomically rename it to the final file, because + # writing directly to the final file can cause the main process to read a + # partially written JSON file. + ready_file_temp = os.path.join(record_dir, rconst.READY_FILE_TEMP) + with tf.gfile.Open(ready_file_temp, "w") as f: json.dump({ "batch_size": train_batch_size, "batch_count": batch_count, }, f) + ready_file = os.path.join(record_dir, rconst.READY_FILE) + tf.gfile.Rename(ready_file_temp, ready_file) log_msg("Cycle {} complete. Total time: {:.1f} seconds" .format(train_cycle, timeit.default_timer() - st))