From 5a82cef6344a8e6d907b43ae80c478c2eda39905 Mon Sep 17 00:00:00 2001 From: stephen Date: Sun, 21 Jun 2020 12:31:40 +0100 Subject: [PATCH] Added variation flag to dataset_generator. --- tools/dataset_generator.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tools/dataset_generator.py b/tools/dataset_generator.py index d1c69ed7e..2cfeaaf39 100644 --- a/tools/dataset_generator.py +++ b/tools/dataset_generator.py @@ -34,6 +34,8 @@ 'The number of parallel processes during collection.') flags.DEFINE_integer('episodes_per_task', 10, 'The number of episodes to collect per task.') +flags.DEFINE_integer('variations', -1, + 'Number of variations to collect per task. -1 for all.') def check_and_make(dir): @@ -187,7 +189,10 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks): my_variation_count = variation_count.value t = tasks[task_index.value] task_env = rlbench_env.get_task(t) - if my_variation_count >= task_env.variation_count(): + var_target = task_env.variation_count() + if FLAGS.variations >= 0: + var_target = np.minimum(FLAGS.variations, var_target) + if my_variation_count >= var_target: # If we have reached the required number of variations for this # task, then move on to the next task. variation_count.value = my_variation_count = 0 @@ -250,7 +255,6 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks): break results[i] = tasks_with_problems - rlbench_env.shutdown()