diff --git a/nbs/05_environments.ipynb b/nbs/05_environments.ipynb index 496c30be..de85aa32 100644 --- a/nbs/05_environments.ipynb +++ b/nbs/05_environments.ipynb @@ -59,7 +59,7 @@ "from pct.webots import WebotsHelper\n", "# from pct.yaw_module import YawEnv\n", "from pct.arc import ARCEnv\n", - "from pct.helpers import ListChecker, ChallengesDataManager" + "from pct.helpers import ListChecker, DataManagerSingleton" ] }, { @@ -1558,9 +1558,12 @@ "\n", " def set_properties(self, props: dict) -> None:\n", "\n", - " data = props['data']\n", + " # data = props['data']\n", + "\n", + " data_mgr = DataManagerSingleton.get_instance(folder = 'c:/tmp/arc-prize-2024', prefix = 'arc-agi_simple_', show_timing=True)\n", + " data = data_mgr.get_data_for_code(props['code'])\n", + " props['test_output_array'] = data_mgr.get_solutions_for_code(props['code'])\n", "\n", - " # props['data']=data\n", " self.env.initialise(props, data)\n", " self.fitness = self.env.fitness\n", " self.history = props.get('history', 5)\n", @@ -1667,7 +1670,6 @@ " self.done, details = ListChecker.check_list_unchanged(self.boxcar, rel_tol =get_rel_tol('ARC-change'), abs_tol=get_abs_tol('ARC-change'), gradient_abs_tol=get_abs_tol('ARC-gradient'))\n", " if self.done:\n", " self.env.add_to_gradient_list(details['gradient_range']) \n", - " # self.env.fitness_isclose_to_zero = ListChecker.check_float_list_close_to_zero(self.boxcar, rel_tol = 0, abs_tol=get_abs_tol('ARC-zero'), gradient_abs_tol=get_abs_tol('ARC-gradient'))\n", "\n", " if self.done:\n", " self.env.add_to_fitness_list(max(self.boxcar) )\n", @@ -1747,51 +1749,31 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_actions': 1, 'grid_shape': 'equal', 'dims': 1}\n", - "ARC ARC | [4] | links constant \n", - "ARC ARC | [5] | links constant \n", - "ARC ARC | [6] | links constant \n", - "ARC ARC | [7] | links constant \n", - "ARC ARC | [8] | links constant \n", - "ARC ARC | [9] | links constant \n", - "{'type': 'ARC', 'name': 'ARC', 'value': [9], 'links': {0: 'constant'}, 'env_name': 'ARC'}\n", - "\n", - "1.000 \n", - "9.000 0 True {'num_actions': 1, 'grid_shape': 'equal', 'dims': 1}\n", - "\n", - "[9]\n" - ] - } - ], + "outputs": [], "source": [ "#| gui\n", "env = ARC()\n", "env.add_link(Constant(1))\n", "# env.add_link(Constant(0))\n", "properties = { 'dir': 'C:\\\\packages\\\\arc-prize-2024', 'file_prefix':'arc-agi_training_', 'code':'007bbfb7', 'dataset': 'train', 'control_set': ['dims'], 'input_set': ['env']}\n", - "file_name = os.path.join(properties['dir'], properties['file_prefix']) + 'challenges.json' \n", - "challenges_manager = ChallengesDataManager(file_name)\n", - "data = challenges_manager.get_data_for_key(properties['code'])\n", - "properties['data']=data\n", - "\n", - "env.set_properties(properties)\n", - "env.set_render(True)\n", - "env.reset()\n", - "print(env.env.info)\n", - "for i in range(6):\n", - " state = env()\n", - " env.summary() \n", - " # print()\n", - "print(env.get_config())\n", - "print()\n", - "print(env.output_string()) \n", - "print()\n", - "print(state)\n", + "# file_name = os.path.join(properties['dir'], properties['file_prefix']) + 'challenges.json' \n", + "# challenges_manager = ChallengesDataManager(file_name)\n", + "# data = challenges_manager.get_data_for_key(properties['code'])\n", + "# properties['data']=data\n", + "\n", + "# env.set_properties(properties)\n", + "# env.set_render(True)\n", + "# env.reset()\n", + "# print(env.env.info)\n", + 
"# for i in range(6):\n", + "# state = env()\n", + "# env.summary() \n", + "# # print()\n", + "# print(env.get_config())\n", + "# print()\n", + "# print(env.output_string()) \n", + "# print()\n", + "# print(state)\n", "\n", "# env.close()" ] diff --git a/nbs/14_helpers.ipynb b/nbs/14_helpers.ipynb index 4a24600f..155a124a 100644 --- a/nbs/14_helpers.ipynb +++ b/nbs/14_helpers.ipynb @@ -16,7 +16,16 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -50,7 +59,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## ListChecker" + "# ListChecker" ] }, { @@ -61,8 +70,6 @@ "source": [ "#| export\n", "class ListChecker:\n", - "\n", - "\n", " \n", " @staticmethod\n", " def check_list_unchanged(float_list, rel_tol=1e-9, abs_tol=0.0, gradient_abs_tol=0.0):\n", @@ -157,38 +164,52 @@ "\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## check_float_list_close_to_zero" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "@staticmethod\n", - "def check_float_list_close_to_zero(float_list, rel_tol=1e-9, abs_tol=0.0, gradient_abs_tol=0.0):\n", - " \"\"\"\n", - " Checks if the values in the float list are close to zero within the specified tolerance\n", - " and if the gradient (difference between consecutive values) is close to zero within the specified gradient tolerance.\n", - "\n", - " Returns:\n", - " bool: True if all values are close to zero within the specified tolerance and the gradient of all consecutive values is close to zero within the specified gradient tolerance.\n", - " \"\"\"\n", - " if not float_list:\n", - " return True\n", - " \n", - " values_close_to_zero = all(\n", - " math.isclose(value, 0, rel_tol=rel_tol, abs_tol=abs_tol)\n", - " for value in float_list\n", - " )\n", - " \n", - " if len(float_list) == 1:\n", - " return values_close_to_zero\n", - " \n", - " gradients_close_to_zero = all(\n", - " math.isclose(float_list[i] - float_list[i - 1], 0, rel_tol=0, abs_tol=gradient_abs_tol)\n", - " for i in range(1, len(float_list))\n", - " )\n", - " \n", - " return values_close_to_zero and gradients_close_to_zero" + "# @staticmethod\n", + "# def check_float_list_close_to_zero(float_list, rel_tol=1e-9, abs_tol=0.0, gradient_abs_tol=0.0):\n", + "# \"\"\"\n", + "# Checks if the values in the float list are close to zero within the specified tolerance\n", + "# and if the gradient (difference between consecutive values) is close to zero within the specified gradient tolerance.\n", + "\n", + "# Returns:\n", + "# bool: True if all values are close to zero within the specified tolerance and the gradient of all consecutive values is close to zero within the specified gradient tolerance.\n", + "# \"\"\"\n", + "# if not float_list:\n", + "# return True\n", + " \n", + "# values_close_to_zero = all(\n", + "# math.isclose(value, 0, rel_tol=rel_tol, abs_tol=abs_tol)\n", + "# for value in float_list\n", + "# )\n", + " \n", + "# if len(float_list) == 1:\n", + "# return values_close_to_zero\n", + " \n", + "# gradients_close_to_zero = all(\n", + "# math.isclose(float_list[i] - float_list[i - 1], 0, rel_tol=0, abs_tol=gradient_abs_tol)\n", + "# for i in range(1, len(float_list))\n", + "# )\n", + " \n", + "# return values_close_to_zero and 
gradients_close_to_zero" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## check_float_list_close_to_zero" ] }, { @@ -197,32 +218,39 @@ "metadata": {}, "outputs": [], "source": [ - "@staticmethod\n", - "def check_float_list_close_to_zero(float_list, rel_tol=1e-9, abs_tol=0.0, gradient_abs_tol=0.0):\n", - " \"\"\"\n", - " Checks if the values in the float list are close to zero within the specified tolerance\n", - " and if the gradient (difference between consecutive values) is close to zero within the specified gradient tolerance.\n", - "\n", - " Returns:\n", - " bool: True if all values are close to zero within the specified tolerance and the gradient of all consecutive values is close to zero within the specified gradient tolerance.\n", - " \"\"\"\n", - " if not float_list:\n", - " return True\n", - " \n", - " values_close_to_zero = all(\n", - " math.isclose(value, 0, rel_tol=rel_tol, abs_tol=abs_tol)\n", - " for value in float_list\n", - " )\n", - " \n", - " if len(float_list) == 1:\n", - " return values_close_to_zero\n", - " \n", - " gradients_close_to_zero = all(\n", - " math.isclose(float_list[i] - float_list[i - 1], 0, rel_tol=0, abs_tol=gradient_abs_tol)\n", - " for i in range(1, len(float_list))\n", - " )\n", - " \n", - " return values_close_to_zero and gradients_close_to_zero" + "# @staticmethod\n", + "# def check_float_list_close_to_zero(float_list, rel_tol=1e-9, abs_tol=0.0, gradient_abs_tol=0.0):\n", + "# \"\"\"\n", + "# Checks if the values in the float list are close to zero within the specified tolerance\n", + "# and if the gradient (difference between consecutive values) is close to zero within the specified gradient tolerance.\n", + "\n", + "# Returns:\n", + "# bool: True if all values are close to zero within the specified tolerance and the gradient of all consecutive values is close to zero within the specified gradient tolerance.\n", + "# \"\"\"\n", + "# if not float_list:\n", + "# return True\n", + " \n", + "# values_close_to_zero = all(\n", + "# math.isclose(value, 0, rel_tol=rel_tol, abs_tol=abs_tol)\n", + "# for value in float_list\n", + "# )\n", + " \n", + "# if len(float_list) == 1:\n", + "# return values_close_to_zero\n", + " \n", + "# gradients_close_to_zero = all(\n", + "# math.isclose(float_list[i] - float_list[i - 1], 0, rel_tol=0, abs_tol=gradient_abs_tol)\n", + "# for i in range(1, len(float_list))\n", + "# )\n", + " \n", + "# return values_close_to_zero and gradients_close_to_zero" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## check_float_list_close_to_zero" ] }, { @@ -231,32 +259,39 @@ "metadata": {}, "outputs": [], "source": [ - "@staticmethod\n", - "def check_float_list_close_to_zero(float_list, rel_tol=1e-9, abs_tol=0.0, gradient_abs_tol=0.0):\n", - " \"\"\"\n", - " Checks if the values in the float list are close to zero within the specified tolerance\n", - " and if the gradient (difference between consecutive values) is close to zero within the specified gradient tolerance.\n", - "\n", - " Returns:\n", - " bool: True if all values are close to zero within the specified tolerance and the gradient of all consecutive values is close to zero within the specified gradient tolerance.\n", - " \"\"\"\n", - " if not float_list:\n", - " return True\n", - " \n", - " values_close_to_zero = all(\n", - " math.isclose(value, 0, rel_tol=rel_tol, abs_tol=abs_tol)\n", - " for value in float_list\n", - " )\n", - " \n", - " if len(float_list) == 1:\n", - " return values_close_to_zero\n", - " \n", - " 
gradients_close_to_zero = all(\n", - " math.isclose(float_list[i] - float_list[i - 1], 0, rel_tol=0, abs_tol=gradient_abs_tol)\n", - " for i in range(1, len(float_list))\n", - " )\n", - " \n", - " return values_close_to_zero and gradients_close_to_zero" + "# @staticmethod\n", + "# def check_float_list_close_to_zero(float_list, rel_tol=1e-9, abs_tol=0.0, gradient_abs_tol=0.0):\n", + "# \"\"\"\n", + "# Checks if the values in the float list are close to zero within the specified tolerance\n", + "# and if the gradient (difference between consecutive values) is close to zero within the specified gradient tolerance.\n", + "\n", + "# Returns:\n", + "# bool: True if all values are close to zero within the specified tolerance and the gradient of all consecutive values is close to zero within the specified gradient tolerance.\n", + "# \"\"\"\n", + "# if not float_list:\n", + "# return True\n", + " \n", + "# values_close_to_zero = all(\n", + "# math.isclose(value, 0, rel_tol=rel_tol, abs_tol=abs_tol)\n", + "# for value in float_list\n", + "# )\n", + " \n", + "# if len(float_list) == 1:\n", + "# return values_close_to_zero\n", + " \n", + "# gradients_close_to_zero = all(\n", + "# math.isclose(float_list[i] - float_list[i - 1], 0, rel_tol=0, abs_tol=gradient_abs_tol)\n", + "# for i in range(1, len(float_list))\n", + "# )\n", + " \n", + "# return values_close_to_zero and gradients_close_to_zero" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## check_float_list_close_to_zero" ] }, { @@ -265,32 +300,32 @@ "metadata": {}, "outputs": [], "source": [ - "@staticmethod\n", - "def check_float_list_close_to_zero(float_list, rel_tol=1e-9, abs_tol=0.0, gradient_abs_tol=0.0):\n", - " \"\"\"\n", - " Checks if the values in the float list are close to zero within the specified tolerance\n", - " and if the gradient (difference between consecutive values) is close to zero within the specified gradient tolerance.\n", - "\n", - " Returns:\n", - " bool: True if all values are close to zero within the specified tolerance and the gradient of all consecutive values is close to zero within the specified gradient tolerance.\n", - " \"\"\"\n", - " if not float_list:\n", - " return True\n", - " \n", - " values_close_to_zero = all(\n", - " math.isclose(value, 0, rel_tol=rel_tol, abs_tol=abs_tol)\n", - " for value in float_list\n", - " )\n", - " \n", - " if len(float_list) == 1:\n", - " return values_close_to_zero\n", - " \n", - " gradients_close_to_zero = all(\n", - " math.isclose(float_list[i] - float_list[i - 1], 0, rel_tol=0, abs_tol=gradient_abs_tol)\n", - " for i in range(1, len(float_list))\n", - " )\n", - " \n", - " return values_close_to_zero and gradients_close_to_zero" + "# @staticmethod\n", + "# def check_float_list_close_to_zero(float_list, rel_tol=1e-9, abs_tol=0.0, gradient_abs_tol=0.0):\n", + "# \"\"\"\n", + "# Checks if the values in the float list are close to zero within the specified tolerance\n", + "# and if the gradient (difference between consecutive values) is close to zero within the specified gradient tolerance.\n", + "\n", + "# Returns:\n", + "# bool: True if all values are close to zero within the specified tolerance and the gradient of all consecutive values is close to zero within the specified gradient tolerance.\n", + "# \"\"\"\n", + "# if not float_list:\n", + "# return True\n", + " \n", + "# values_close_to_zero = all(\n", + "# math.isclose(value, 0, rel_tol=rel_tol, abs_tol=abs_tol)\n", + "# for value in float_list\n", + "# )\n", + " \n", + "# if len(float_list) == 
1:\n", + "# return values_close_to_zero\n", + " \n", + "# gradients_close_to_zero = all(\n", + "# math.isclose(float_list[i] - float_list[i - 1], 0, rel_tol=0, abs_tol=gradient_abs_tol)\n", + "# for i in range(1, len(float_list))\n", + "# )\n", + " \n", + "# return values_close_to_zero and gradients_close_to_zero" ] }, { @@ -302,7 +337,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "(False, {'gradient': 1.0000000050247593e-08, 'mean': 1.0000000199999999, 'std_dev': 8.164965850304249e-09})\n", + "(False, {'gradient_range': 2.220446049250313e-16, 'mean': 1.0000000199999999, 'std_dev': 8.164965850304249e-09})\n", "True\n" ] } @@ -317,6 +352,13 @@ "print(ListChecker.check_integer_list_unchanged(int_list)) # Should print: True\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ARC Data" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -331,13 +373,12 @@ "outputs": [], "source": [ "#| export\n", - "\n", - "\n", "class JSONDataManager:\n", " def __init__(self, path: str, show_timing: bool = False):\n", " self.data = self.load_json(path)\n", " self.show_timing = show_timing\n", " \n", + " \n", " def load_json(self, path: str) -> Dict:\n", " with open(path, 'r') as file:\n", " return json.load(file)\n", @@ -351,7 +392,74 @@ " print(f\"Execution time of {method.__name__}: {end_time - start_time:.4f} seconds\")\n", " return result\n", " return timed_method\n", - "\n" + "\n", + " def reload_data(self, path: str):\n", + " self.data = self.load_json(path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# #| export\n", + "# class JSONDataManager:\n", + "# def __init__(self, path: str, show_timing: bool = False):\n", + "# self.data = self.load_json(path)\n", + "# self.show_timing = show_timing\n", + " \n", + "# @JSONDataManager.timing_decorator\n", + "# def load_json(self, path: str) -> Dict:\n", + "# with open(path, 'r') as file:\n", + "# return json.load(file)\n", + " \n", + "# def timing_decorator(method):\n", + "# def timed_method(self, *args, **kwargs):\n", + "# start_time = time.time()\n", + "# result = method(self, *args, **kwargs)\n", + "# end_time = time.time()\n", + "# if self.show_timing:\n", + "# print(f\"Execution time of {method.__name__}: {end_time - start_time:.4f} seconds\")\n", + "# return result\n", + "# return timed_method\n", + "\n", + "# def reload_data(self, path: str):\n", + "# self.data = self.load_json(path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# import json\n", + "# import time\n", + "# from typing import Dict\n", + "\n", + "# def timing_decorator(method):\n", + "# def timed_method(self, *args, **kwargs):\n", + "# start_time = time.time()\n", + "# result = method(self, *args, **kwargs)\n", + "# end_time = time.time()\n", + "# if self.show_timing:\n", + "# print(f\"Execution time of {method.__name__}: {end_time - start_time:.4f} seconds\")\n", + "# return result\n", + "# return timed_method\n", + "\n", + "# class JSONDataManager:\n", + "# def __init__(self, path: str, show_timing: bool = False):\n", + "# self.data = self.load_json(path)\n", + "# self.show_timing = show_timing\n", + " \n", + "# @timing_decorator\n", + "# def load_json(self, path: str) -> Dict:\n", + "# with open(path, 'r') as file:\n", + "# return json.load(file)\n", + "\n", + "# def reload_data(self, path: str):\n", + "# self.data = self.load_json(path)" ] }, { @@ -373,6 +481,11 @@ "class 
ChallengesDataManager(JSONDataManager):\n", " \n", " @JSONDataManager.timing_decorator\n", + " def __init__(self, path: str, show_timing: bool = False):\n", + " super().__init__(path, show_timing)\n", + " \n", + "\n", + " @JSONDataManager.timing_decorator\n", " def get_all_keys(self) -> List[str]:\n", " return list(self.data.keys())\n", " \n", @@ -411,9 +524,11 @@ " counts = Counter(len(value['train']) for value in self.data.values())\n", " return dict(counts)\n", " \n", - " @JSONDataManager.timing_decorator\n", + " # @JSONDataManager.timing_decorator\n", " def get_data_for_key(self, key: str) -> Dict[str, Any]:\n", - " return self.data.get(key, {})\n", + " if key not in self.data:\n", + " raise KeyError(f\"Key '{key}' not found in data.\")\n", + " return self.data[key]\n", " \n", " @JSONDataManager.timing_decorator\n", " def get_arrays_for_key(self, key: str, array_type: str) -> List:\n", @@ -458,7 +573,11 @@ " \"analysis\": analysis,\n", " \"counts\": {k: len(v) for k, v in analysis.items()}\n", " }\n", - "\n" + "\n", + "\n", + " @JSONDataManager.timing_decorator\n", + " def reload_data(self, path: str):\n", + " super().reload_data(path)\n" ] }, { @@ -477,7 +596,12 @@ "#| export\n", "\n", "class SolutionsDataManager(JSONDataManager):\n", + "\n", + " @JSONDataManager.timing_decorator\n", + " def __init__(self, path: str, show_timing: bool = False):\n", + " super().__init__(path, show_timing)\n", " \n", + "\n", " @JSONDataManager.timing_decorator\n", " def get_all_keys(self) -> List[str]:\n", " return list(self.data.keys())\n", @@ -486,18 +610,21 @@ " def count_all_keys(self) -> int:\n", " return len(self.data)\n", " \n", - " @JSONDataManager.timing_decorator\n", + " # @JSONDataManager.timing_decorator\n", " def get_data_for_key(self, key: str) -> Dict[str, Any]:\n", - " data = self.data.get(key, [])\n", - " return data[0] if data else {}\n", - " \n", + " if key not in self.data:\n", + " raise KeyError(f\"Key '{key}' not found in data.\")\n", + " return self.data[key][0] \n", + "\n", " @JSONDataManager.timing_decorator\n", " def get_arrays_for_key(self, key: str, array_type: str) -> List:\n", " if key not in self.data or array_type not in self.data[key]:\n", " return []\n", " return self.data[key][array_type]\n", "\n", - "\n" + " @JSONDataManager.timing_decorator\n", + " def reload_data(self, path: str):\n", + " super().reload_data(path)\n" ] }, { @@ -509,8 +636,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Execution time of challenges load: 0.0301 seconds\n", - "Execution time of solutions load: 0.0053 seconds\n", + "Execution time of challenges load: 0.0912 seconds\n", + "Execution time of solutions load: 0.0040 seconds\n", "['007bbfb7', '00d62c1b', '017c7c7b', '025d127b', '045e512c', '0520fde7', '05269061', '05f2a901', '06df4c85', '08ed6ac7', '09629e4f', '0962bcdd', '0a938d79', '0b148d64', '0ca9ddb6', '0d3d703e', '0dfd9992', '0e206a2e', '10fcaaa3', '11852cab', '1190e5a7', '137eaa0f', '150deff5', '178fcbfb', '1a07d186', '1b2d62fb', '1b60fb0c', '1bfc4729', '1c786137', '1caeab9d', '1cf80156', '1e0a9b12', '1e32b0e9', '1f0c79e5', '1f642eb9', '1f85a75f', '1f876c06', '1fad071e', '2013d3e2', '2204b7a8', '22168020', '22233c11', '2281f1f4', '228f6490', '22eb0ac0', '234bbc79', '23581191', '239be575', '23b5c85d', '253bf280', '25d487eb', '25d8a9c8', '25ff71a9', '264363fd', '272f95fa', '27a28665', '28bf18c6', '28e73c20', '29623171', '29c11459', '29ec7d0e', '2bcee788', '2bee17df', '2c608aff', '2dc579da', '2dd70a9a', '2dee498d', '31aa019c', '321b1fc6', '32597951', '3345333e', '3428a4f5', 
'3618c87e', '3631a71a', '363442ee', '36d67576', '36fdfd69', '3906de3d', '39a8645d', '39e1d7f9', '3aa6fb7a', '3ac3eb23', '3af2c5a8', '3bd67248', '3bdb4ada', '3befdf3e', '3c9b0459', '3de23699', '3e980e27', '3eda0437', '3f7978a0', '40853293', '4093f84a', '41e4d17e', '4258a5f9', '4290ef0e', '42a50994', '4347f46a', '444801d8', '445eab21', '447fd412', '44d8ac46', '44f52bb0', '4522001f', '4612dd53', '46442a0e', '469497ad', '46f33fce', '47c1f68c', '484b58aa', '48d8fb45', '4938f0c2', '496994bd', '49d1d64f', '4be741c5', '4c4377d9', '4c5c2cf0', '50846271', '508bd3b6', '50cb2852', '5117e062', '5168d44c', '539a4f51', '53b68214', '543a7ed5', '54d82841', '54d9e175', '5521c0d9', '5582e5ca', '5614dbcf', '56dc2b01', '56ff96f3', '57aa92db', '5ad4f10b', '5bd6f4ac', '5c0a986e', '5c2c9af4', '5daaa586', '60b61512', '6150a2bd', '623ea044', '62c24649', '63613498', '6430c8c4', '6455b5f5', '662c240a', '67385a82', '673ef223', '6773b310', '67a3c6ac', '67a423a3', '67e8384a', '681b3aeb', '6855a6e4', '68b16354', '694f12f3', '6a1e5592', '6aa20dc0', '6b9890af', '6c434453', '6cdd2623', '6cf79266', '6d0160f0', '6d0aefbc', '6d58a25d', '6d75e8bb', '6e02f1e3', '6e19193c', '6e82a1ae', '6ecd11f4', '6f8cd79b', '6fa7a44f', '72322fa7', '72ca375d', '73251a56', '7447852a', '7468f01a', '746b3537', '74dd1130', '75b8110e', '760b3cac', '776ffc46', '77fdfe62', '780d0b14', '7837ac64', '794b24be', '7b6016b9', '7b7f7511', '7c008303', '7ddcd7ec', '7df24a62', '7e0986d6', '7f4411dc', '7fe24cdd', '80af3007', '810b9b61', '82819916', '83302e8f', '834ec97d', '8403a5d5', '846bdb03', '855e0971', '85c4e7cd', '868de0fa', '8731374e', '88a10436', '88a62173', '890034e9', '8a004b2b', '8be77c9e', '8d5021e8', '8d510a79', '8e1813be', '8e5a5113', '8eb1be9a', '8efcae92', '8f2ea7aa', '90c28cc7', '90f3ed37', '913fb3ed', '91413438', '91714a58', '9172f3a0', '928ad970', '93b581b8', '941d9a10', '94f9d214', '952a094c', '9565186b', '95990924', '963e52fc', '97999447', '97a05b5b', '98cf29f8', '995c5fa3', '99b1bc43', '99fa7670', '9aec4887', '9af7a82c', '9d9215db', '9dfd6313', '9ecd008a', '9edfc990', '9f236235', 'a1570a43', 'a2fd1cf0', 'a3325580', 'a3df8b1e', 'a416b8f3', 'a48eeaf7', 'a5313dff', 'a5f85a15', 'a61ba2ce', 'a61f2674', 'a64e4611', 'a65b410d', 'a68b268e', 'a699fb00', 'a740d043', 'a78176bb', 'a79310a0', 'a85d4709', 'a87f7484', 'a8c38be5', 'a8d7556c', 'a9f96cdd', 'aabf363d', 'aba27056', 'ac0a08a4', 'ae3edfdc', 'ae4f1146', 'aedd82e4', 'af902bf9', 'b0c4d837', 'b190f7f5', 'b1948b0a', 'b230c067', 'b27ca6d3', 'b2862040', 'b527c5c6', 'b548a754', 'b60334d2', 'b6afb2da', 'b7249182', 'b775ac94', 'b782dc8a', 'b8825c91', 'b8cdaf2b', 'b91ae062', 'b94a9452', 'b9b7f026', 'ba26e723', 'ba97ae07', 'bb43febb', 'bbc9ae5d', 'bc1d5164', 'bd4472b8', 'bda2d7a6', 'bdad9b1f', 'be94b721', 'beb8660c', 'c0f76784', 'c1d99e64', 'c3e719e8', 'c3f564a4', 'c444b776', 'c59eb873', 'c8cbb738', 'c8f0f002', 'c909285e', 'c9e6f938', 'c9f8e694', 'caa06a1f', 'cbded52d', 'cce03e0d', 'cdecee7f', 'ce22a75a', 'ce4f8723', 'ce602527', 'ce9e57f2', 'cf98881b', 'd037b0a7', 'd06dbe63', 'd07ae81c', 'd0f5fe59', 'd10ecb37', 'd13f3404', 'd22278a0', 'd23f8c26', 'd2abd087', 'd364b489', 'd406998b', 'd43fd935', 'd4469b4b', 'd4a91cb9', 'd4f3cd78', 'd511f180', 'd5d6de2d', 'd631b094', 'd687bc17', 'd6ad076f', 'd89b689b', 'd8c310e9', 'd90796e8', 'd9f24cd1', 'd9fac9be', 'dae9d2b5', 'db3e9e38', 'db93a21d', 'dbc1a6ce', 'dc0a314f', 'dc1df850', 'dc433765', 'ddf7fa4f', 'de1cd16c', 'ded97339', 'e179c5f4', 'e21d9049', 'e26a3af2', 'e3497940', 'e40b9e2f', 'e48d4e1a', 'e5062a87', 'e509e548', 'e50d258f', 'e6721834', 'e73095fd', 'e76a88a6', 
'e8593010', 'e8dc4411', 'e9614598', 'e98196ab', 'e9afcf9a', 'ea32f347', 'ea786f4a', 'eb281b96', 'eb5a1d5d', 'ec883f72', 'ecdecbb3', 'ed36ccf7', 'ef135b50', 'f15e1fac', 'f1cefba8', 'f25fbde4', 'f25ffba3', 'f2829549', 'f35d900a', 'f5b8619d', 'f76d97a5', 'f8a8fe49', 'f8b3ba0a', 'f8c80d96', 'f8ff0b80', 'f9012d9b', 'fafffa47', 'fcb5c309', 'fcc82909', 'feca6190', 'ff28f65a', 'ff805c23']\n", "400\n", "(['00d62c1b', '025d127b', '045e512c', '0520fde7', '05269061', '05f2a901', '06df4c85', '08ed6ac7', '09629e4f', '0962bcdd', '0a938d79', '0ca9ddb6', '0d3d703e', '0dfd9992', '0e206a2e', '11852cab', '150deff5', '178fcbfb', '1a07d186', '1b2d62fb', '1b60fb0c', '1bfc4729', '1caeab9d', '1e0a9b12', '1e32b0e9', '1f0c79e5', '1f642eb9', '1f876c06', '2204b7a8', '22168020', '22233c11', '2281f1f4', '228f6490', '22eb0ac0', '234bbc79', '23581191', '253bf280', '25d487eb', '25d8a9c8', '25ff71a9', '264363fd', '272f95fa', '28e73c20', '29623171', '29c11459', '29ec7d0e', '2bcee788', '2bee17df', '2c608aff', '2dd70a9a', '2dee498d', '31aa019c', '321b1fc6', '32597951', '3345333e', '3618c87e', '3631a71a', '363442ee', '36d67576', '36fdfd69', '3906de3d', '39e1d7f9', '3aa6fb7a', '3ac3eb23', '3bd67248', '3bdb4ada', '3befdf3e', '3c9b0459', '3e980e27', '3eda0437', '40853293', '4093f84a', '41e4d17e', '4258a5f9', '42a50994', '4347f46a', '444801d8', '447fd412', '44d8ac46', '4612dd53', '484b58aa', '4938f0c2', '496994bd', '4c5c2cf0', '50846271', '508bd3b6', '50cb2852', '5168d44c', '543a7ed5', '54d82841', '54d9e175', '5521c0d9', '5582e5ca', '56dc2b01', '56ff96f3', '57aa92db', '5c0a986e', '5c2c9af4', '60b61512', '6150a2bd', '623ea044', '63613498', '6455b5f5', '67385a82', '673ef223', '67a3c6ac', '67a423a3', '6855a6e4', '68b16354', '694f12f3', '6a1e5592', '6aa20dc0', '6c434453', '6cdd2623', '6cf79266', '6d0160f0', '6d0aefbc', '6d58a25d', '6d75e8bb', '6e02f1e3', '6e19193c', '6e82a1ae', '6f8cd79b', '72322fa7', '73251a56', '7447852a', '74dd1130', '760b3cac', '776ffc46', '794b24be', '7b6016b9', '7ddcd7ec', '7df24a62', '7e0986d6', '7f4411dc', '810b9b61', '82819916', '83302e8f', '834ec97d', '8403a5d5', '855e0971', '85c4e7cd', '868de0fa', '88a10436', '890034e9', '8d510a79', '8e5a5113', '8eb1be9a', '8f2ea7aa', '90f3ed37', '913fb3ed', '91714a58', '928ad970', '93b581b8', '941d9a10', '952a094c', '9565186b', '95990924', '963e52fc', '97999447', '98cf29f8', '99fa7670', '9d9215db', '9dfd6313', '9edfc990', 'a1570a43', 'a2fd1cf0', 'a3df8b1e', 'a416b8f3', 'a48eeaf7', 'a5313dff', 'a5f85a15', 'a61f2674', 'a64e4611', 'a65b410d', 'a699fb00', 'a78176bb', 'a79310a0', 'a85d4709', 'a8d7556c', 'a9f96cdd', 'aabf363d', 'aba27056', 'ae3edfdc', 'aedd82e4', 'af902bf9', 'b1948b0a', 'b230c067', 'b27ca6d3', 'b2862040', 'b527c5c6', 'b548a754', 'b60334d2', 'b6afb2da', 'b7249182', 'b775ac94', 'b782dc8a', 'b8825c91', 'b8cdaf2b', 'ba26e723', 'ba97ae07', 'bb43febb', 'bd4472b8', 'bda2d7a6', 'bdad9b1f', 'beb8660c', 'c0f76784', 'c1d99e64', 'c3f564a4', 'c444b776', 'c8f0f002', 'c9e6f938', 'c9f8e694', 'caa06a1f', 'cbded52d', 'ce22a75a', 'ce9e57f2', 'cf98881b', 'd037b0a7', 'd06dbe63', 'd07ae81c', 'd22278a0', 'd23f8c26', 'd2abd087', 'd364b489', 'd406998b', 'd43fd935', 'd4a91cb9', 'd4f3cd78', 'd511f180', 'd5d6de2d', 'd687bc17', 'd6ad076f', 'd89b689b', 'd8c310e9', 'd90796e8', 'd9f24cd1', 'dae9d2b5', 'db3e9e38', 'db93a21d', 'dbc1a6ce', 'dc1df850', 'dc433765', 'ddf7fa4f', 'ded97339', 'e179c5f4', 'e21d9049', 'e26a3af2', 'e3497940', 'e40b9e2f', 'e48d4e1a', 'e5062a87', 'e509e548', 'e73095fd', 'e76a88a6', 'e8593010', 'e8dc4411', 'e9614598', 'e9afcf9a', 'ea32f347', 'ea786f4a', 'ec883f72', 
'ecdecbb3', 'ed36ccf7', 'ef135b50', 'f15e1fac', 'f1cefba8', 'f25ffba3', 'f2829549', 'f35d900a', 'f76d97a5', 'f8a8fe49', 'f8c80d96', 'fcc82909'], 274)\n", @@ -583,7 +710,9 @@ { "cell_type": "markdown", "metadata": {}, - "source": [] + "source": [ + "## DataManagerSingleton" + ] }, { "cell_type": "code", @@ -596,33 +725,117 @@ " _instance = None\n", "\n", " @staticmethod\n", - " def get_instance():\n", + " def get_instance(folder: str = None, prefix: str = None, show_timing: bool = False):\n", " if DataManagerSingleton._instance is None:\n", - " DataManagerSingleton._instance = DataManagerSingleton()\n", + " if folder is None or prefix is None:\n", + " raise ValueError(\"folder and prefix must be provided for the first instantiation\")\n", + " DataManagerSingleton._instance = DataManagerSingleton(folder, prefix, show_timing)\n", " return DataManagerSingleton._instance\n", "\n", - " def __init__(self, folder: str, prefix: str):\n", + " def __init__(self, folder: str, prefix: str, show_timing: bool = False):\n", + " if DataManagerSingleton._instance is not None:\n", + " raise Exception(\"This class is a singleton!\")\n", " self.folder = folder\n", " self.prefix = prefix\n", - " self.challenges_manager = ChallengesDataManager(f\"{self.folder}/{self.prefix}_challenges.json\")\n", - " self.solutions_manager = SolutionsDataManager(f\"{self.folder}/{self.prefix}_solutions.json\")\n", - " self.code = None\n", + " self.challenges_manager = ChallengesDataManager(f\"{self.folder}/{self.prefix}challenges.json\", show_timing=show_timing)\n", + " self.solutions_manager = SolutionsDataManager(f\"{self.folder}/{self.prefix}solutions.json\", show_timing=show_timing)\n", "\n", - " def load_data_for_code(self, code: str):\n", - " self.code = code\n", + " def get_data_for_code(self, code: str):\n", " data = self.challenges_manager.get_data_for_key(code)\n", - " # Process the data as needed\n", - " return data" + " return data\n", + " \n", + " def get_solutions_for_code(self, code: str):\n", + " solutions = self.solutions_manager.get_data_for_key(code)\n", + " return solutions\n", + "\n", + "\n", + " def reload_data(self, folder: str, prefix: str): \n", + " self.folder = folder\n", + " self.prefix = prefix\n", + " self.challenges_manager.reload_data(f\"{self.folder}/{self.prefix}challenges.json\")\n", + " self.solutions_manager.reload_data(f\"{self.folder}/{self.prefix}solutions.json\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Execution time of __init__: 0.0000 seconds\n", + "Execution time of __init__: 0.0011 seconds\n", + "Execution time of get_data_for_key: 0.0000 seconds\n", + "{'test': [{'input': [[2]]}], 'train': [{'input': [[3]], 'output': [[4]]}, {'input': [[7]], 'output': [[8]]}, {'input': [[1]], 'output': [[2]]}]}\n", + "Execution time of get_data_for_key: 0.0000 seconds\n", + "[[3]]\n" + ] + } + ], "source": [ - "#| hide\n", - "import nbdev; nbdev.nbdev_export()" + "#| gui\n", + "code = '00000001'\n", + "data_mgr = DataManagerSingleton.get_instance(folder = 'c:/tmp/arc-prize-2024', prefix = 'arc-agi_simple_', show_timing=True)\n", + "data = data_mgr.get_data_for_code(code)\n", + "print(data)\n", + "out = data_mgr.get_solutions_for_code(code)\n", + "print(out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Execution time of get_data_for_key: 0.0000 seconds\n", + 
"{'test': [{'input': [[2, 2], [2, 2]]}], 'train': [{'input': [[3, 3], [3, 3]], 'output': [[4, 4], [4, 4]]}, {'input': [[7, 7], [7, 7]], 'output': [[8, 8], [8, 8]]}, {'input': [[1, 1], [1, 1]], 'output': [[2, 2], [2, 2]]}]}\n", + "Execution time of get_data_for_key: 0.0000 seconds\n", + "[[3, 3], [3, 3]]\n" + ] + } + ], + "source": [ + "#| gui\n", + "code = '00000003'\n", + "# data_mgr = DataManagerSingleton.get_instance(folder = 'c:/tmp/arc-prize-2024', prefix = 'arc-agi_simple_')\n", + "data = data_mgr.get_data_for_code(code)\n", + "print(data)\n", + "out = data_mgr.get_solutions_for_code(code)\n", + "print(out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Execution time of reload_data: 0.0622 seconds\n", + "Execution time of reload_data: 0.0041 seconds\n", + "Execution time of get_data_for_key: 0.0000 seconds\n", + "{'test': [{'input': [[7, 0, 7], [7, 0, 7], [7, 7, 0]]}], 'train': [{'input': [[0, 7, 7], [7, 7, 7], [0, 7, 7]], 'output': [[0, 0, 0, 0, 7, 7, 0, 7, 7], [0, 0, 0, 7, 7, 7, 7, 7, 7], [0, 0, 0, 0, 7, 7, 0, 7, 7], [0, 7, 7, 0, 7, 7, 0, 7, 7], [7, 7, 7, 7, 7, 7, 7, 7, 7], [0, 7, 7, 0, 7, 7, 0, 7, 7], [0, 0, 0, 0, 7, 7, 0, 7, 7], [0, 0, 0, 7, 7, 7, 7, 7, 7], [0, 0, 0, 0, 7, 7, 0, 7, 7]]}, {'input': [[4, 0, 4], [0, 0, 0], [0, 4, 0]], 'output': [[4, 0, 4, 0, 0, 0, 4, 0, 4], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 4, 0, 0, 0, 0, 0, 4, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 4, 0, 4, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 4, 0, 0, 0, 0]]}, {'input': [[0, 0, 0], [0, 0, 2], [2, 0, 2]], 'output': [[0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 2], [0, 0, 0, 0, 0, 0, 2, 0, 2], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 2, 0, 0, 0, 0, 0, 2], [2, 0, 2, 0, 0, 0, 2, 0, 2]]}, {'input': [[6, 6, 0], [6, 0, 0], [0, 6, 6]], 'output': [[6, 6, 0, 6, 6, 0, 0, 0, 0], [6, 0, 0, 6, 0, 0, 0, 0, 0], [0, 6, 6, 0, 6, 6, 0, 0, 0], [6, 6, 0, 0, 0, 0, 0, 0, 0], [6, 0, 0, 0, 0, 0, 0, 0, 0], [0, 6, 6, 0, 0, 0, 0, 0, 0], [0, 0, 0, 6, 6, 0, 6, 6, 0], [0, 0, 0, 6, 0, 0, 6, 0, 0], [0, 0, 0, 0, 6, 6, 0, 6, 6]]}, {'input': [[2, 2, 2], [0, 0, 0], [0, 2, 2]], 'output': [[2, 2, 2, 2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 2, 2, 0, 2, 2], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 2, 2]]}]}\n", + "Execution time of get_data_for_key: 0.0000 seconds\n", + "[[7, 0, 7, 0, 0, 0, 7, 0, 7], [7, 0, 7, 0, 0, 0, 7, 0, 7], [7, 7, 0, 0, 0, 0, 7, 7, 0], [7, 0, 7, 0, 0, 0, 7, 0, 7], [7, 0, 7, 0, 0, 0, 7, 0, 7], [7, 7, 0, 0, 0, 0, 7, 7, 0], [7, 0, 7, 7, 0, 7, 0, 0, 0], [7, 0, 7, 7, 0, 7, 0, 0, 0], [7, 7, 0, 7, 7, 0, 0, 0, 0]]\n" + ] + } + ], + "source": [ + "#| gui\n", + "code = '007bbfb7'\n", + "data_mgr.reload_data(folder = 'c:/tmp/arc-prize-2024', prefix = 'arc-agi_training_')\n", + "data = data_mgr.get_data_for_code(code)\n", + "print(data)\n", + "out = data_mgr.get_solutions_for_code(code)\n", + "print(out)" ] }, { @@ -630,7 +843,10 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "#| hide\n", + "import nbdev; nbdev.nbdev_export()" + ] } ], "metadata": { diff --git a/nbs/16_environment_processing.ipynb b/nbs/16_environment_processing.ipynb index 188463f6..afb754e2 100644 --- 
a/nbs/16_environment_processing.ipynb +++ b/nbs/16_environment_processing.ipynb @@ -317,8 +317,8 @@ " for key, value in details.items():\n", " self.env_processing_details[key]=value\n", "\n", - " def enhanced_environment_properties(self, environment_properties=None):\n", - " pass\n", + " # def enhanced_environment_properties(self, environment_properties=None):\n", + " # pass\n", "\n", "\n", " def get_experiment(self):\n", @@ -504,35 +504,35 @@ " def get_workspace(self):\n", " return 'arc-challenge'\n", " \n", - " def enhanced_environment_properties(self, environment_properties=None):\n", - " enhanced_environment_properties = {}\n", - " if 'dir' in environment_properties:\n", - " dir = environment_properties['dir']\n", - " else:\n", - " dir = 'C:\\\\packages\\\\arc-prize-2024'\n", - " environment_properties['dir'] = dir \n", - "\n", - " if 'file_prefix' in environment_properties:\n", - " file_prefix = environment_properties['file_prefix']\n", - " else:\n", - " file_prefix = 'arc-agi_training_'\n", - " environment_properties['file_prefix'] = file_prefix\n", - "\n", - " file_name = path.join(dir, file_prefix) + 'challenges.json' \n", - " challenges_manager = ChallengesDataManager(file_name)\n", - " data = challenges_manager.get_data_for_key(environment_properties['code'])\n", + " # def enhanced_environment_properties(self, environment_properties=None):\n", + " # enhanced_environment_properties = {}\n", + " # if 'dir' in environment_properties:\n", + " # dir = environment_properties['dir']\n", + " # else:\n", + " # dir = 'C:\\\\packages\\\\arc-prize-2024'\n", + " # environment_properties['dir'] = dir \n", + "\n", + " # if 'file_prefix' in environment_properties:\n", + " # file_prefix = environment_properties['file_prefix']\n", + " # else:\n", + " # file_prefix = 'arc-agi_training_'\n", + " # environment_properties['file_prefix'] = file_prefix\n", + "\n", + " # file_name = path.join(dir, file_prefix) + 'challenges.json' \n", + " # challenges_manager = ChallengesDataManager(file_name)\n", + " # data = challenges_manager.get_data_for_key(environment_properties['code'])\n", " \n", - " self.number_of_challenges = 1\n", - " if 'index' not in environment_properties:\n", - " self.number_of_challenges = len(data['train'])\n", + " # self.number_of_challenges = 1\n", + " # if 'index' not in environment_properties:\n", + " # self.number_of_challenges = len(data['train'])\n", " \n", - " enhanced_environment_properties['data']=data\n", - " solutions_file = path.join(environment_properties['dir'], environment_properties['file_prefix']) + 'solutions.json' \n", - " solutions_manager = SolutionsDataManager(solutions_file)\n", - " test_output_array = solutions_manager.get_data_for_key(environment_properties['code'])\n", - " enhanced_environment_properties['test_output_array']=test_output_array\n", + " # enhanced_environment_properties['data']=data\n", + " # solutions_file = path.join(environment_properties['dir'], environment_properties['file_prefix']) + 'solutions.json' \n", + " # solutions_manager = SolutionsDataManager(solutions_file)\n", + " # test_output_array = solutions_manager.get_data_for_key(environment_properties['code'])\n", + " # enhanced_environment_properties['test_output_array']=test_output_array\n", "\n", - " return enhanced_environment_properties\n", + " # return enhanced_environment_properties\n", " \n", " def results(self, filepath=None, experiment=None):\n", " print(filepath)\n", @@ -545,7 +545,7 @@ " environment_properties['index'] = 0\n", " print(environment_properties)\n", "\n", - " 
enhanced_environment_properties = self.enhanced_environment_properties(environment_properties=environment_properties)\n", + " # enhanced_environment_properties = self.enhanced_environment_properties(environment_properties=environment_properties)\n", "\n", " verbose= self.args['verbosed']['hpct_verbose'] \n", " min= not self.args['max']\n", @@ -556,7 +556,9 @@ " runs = int(1.5*environment_properties['runs']/self.number_of_challenges)\n", " hierarchy, score = PCTHierarchy.run_from_file(filepath, env_props=environment_properties, history=history, hpct_verbose= verbose, \n", " render=self.args['verbosed']['display_env'], runs=runs, experiment=experiment, min=min, plots=plots, plots_dir=self.args['plots_dir'],\n", - " enhanced_environment_properties=enhanced_environment_properties, title_prefix=title_prefix, early_termination=False)\n", + " title_prefix=title_prefix, early_termination=False\n", + " # ,enhanced_environment_properties=enhanced_environment_properties\n", + " )\n", "\n", " score = round(score ** 0.5, 1)\n", " print('Test score',score)\n", diff --git a/pct/_modidx.py b/pct/_modidx.py index a0d9615f..7011cc1f 100644 --- a/pct/_modidx.py +++ b/pct/_modidx.py @@ -106,8 +106,6 @@ 'pct/environment_processing.py'), 'pct.environment_processing.ARCEnvironmentProcessing.Factory': ( 'environment_processing.html#arcenvironmentprocessing.factory', 'pct/environment_processing.py'), - 'pct.environment_processing.ARCEnvironmentProcessing.enhanced_environment_properties': ( 'environment_processing.html#arcenvironmentprocessing.enhanced_environment_properties', - 'pct/environment_processing.py'), 'pct.environment_processing.ARCEnvironmentProcessing.get_experiment_name': ( 'environment_processing.html#arcenvironmentprocessing.get_experiment_name', 'pct/environment_processing.py'), 'pct.environment_processing.ARCEnvironmentProcessing.get_workspace': ( 'environment_processing.html#arcenvironmentprocessing.get_workspace', @@ -120,8 +118,6 @@ 'pct/environment_processing.py'), 'pct.environment_processing.BaseEnvironmentProcessing.add_processing_detail': ( 'environment_processing.html#baseenvironmentprocessing.add_processing_detail', 'pct/environment_processing.py'), - 'pct.environment_processing.BaseEnvironmentProcessing.enhanced_environment_properties': ( 'environment_processing.html#baseenvironmentprocessing.enhanced_environment_properties', - 'pct/environment_processing.py'), 'pct.environment_processing.BaseEnvironmentProcessing.get_experiment': ( 'environment_processing.html#baseenvironmentprocessing.get_experiment', 'pct/environment_processing.py'), 'pct.environment_processing.BaseEnvironmentProcessing.get_experiment_name': ( 'environment_processing.html#baseenvironmentprocessing.get_experiment_name', @@ -967,6 +963,8 @@ 'pct/functions.py'), 'pct.functions.WeightedSum.summary': ('functions.html#weightedsum.summary', 'pct/functions.py')}, 'pct.helpers': { 'pct.helpers.ChallengesDataManager': ('helpers.html#challengesdatamanager', 'pct/helpers.py'), + 'pct.helpers.ChallengesDataManager.__init__': ( 'helpers.html#challengesdatamanager.__init__', + 'pct/helpers.py'), 'pct.helpers.ChallengesDataManager.analyze_arrays': ( 'helpers.html#challengesdatamanager.analyze_arrays', 'pct/helpers.py'), 'pct.helpers.ChallengesDataManager.count_all_keys': ( 'helpers.html#challengesdatamanager.count_all_keys', @@ -987,15 +985,22 @@ 'pct/helpers.py'), 'pct.helpers.ChallengesDataManager.get_largest_array_size': ( 'helpers.html#challengesdatamanager.get_largest_array_size', 'pct/helpers.py'), + 
'pct.helpers.ChallengesDataManager.reload_data': ( 'helpers.html#challengesdatamanager.reload_data', + 'pct/helpers.py'), 'pct.helpers.DataManagerSingleton': ('helpers.html#datamanagersingleton', 'pct/helpers.py'), 'pct.helpers.DataManagerSingleton.__init__': ('helpers.html#datamanagersingleton.__init__', 'pct/helpers.py'), + 'pct.helpers.DataManagerSingleton.get_data_for_code': ( 'helpers.html#datamanagersingleton.get_data_for_code', + 'pct/helpers.py'), 'pct.helpers.DataManagerSingleton.get_instance': ( 'helpers.html#datamanagersingleton.get_instance', 'pct/helpers.py'), - 'pct.helpers.DataManagerSingleton.load_data_for_code': ( 'helpers.html#datamanagersingleton.load_data_for_code', - 'pct/helpers.py'), + 'pct.helpers.DataManagerSingleton.get_solutions_for_code': ( 'helpers.html#datamanagersingleton.get_solutions_for_code', + 'pct/helpers.py'), + 'pct.helpers.DataManagerSingleton.reload_data': ( 'helpers.html#datamanagersingleton.reload_data', + 'pct/helpers.py'), 'pct.helpers.JSONDataManager': ('helpers.html#jsondatamanager', 'pct/helpers.py'), 'pct.helpers.JSONDataManager.__init__': ('helpers.html#jsondatamanager.__init__', 'pct/helpers.py'), 'pct.helpers.JSONDataManager.load_json': ('helpers.html#jsondatamanager.load_json', 'pct/helpers.py'), + 'pct.helpers.JSONDataManager.reload_data': ('helpers.html#jsondatamanager.reload_data', 'pct/helpers.py'), 'pct.helpers.JSONDataManager.timing_decorator': ( 'helpers.html#jsondatamanager.timing_decorator', 'pct/helpers.py'), 'pct.helpers.ListChecker': ('helpers.html#listchecker', 'pct/helpers.py'), @@ -1008,6 +1013,7 @@ 'pct.helpers.ListChecker.check_list_unchanged': ( 'helpers.html#listchecker.check_list_unchanged', 'pct/helpers.py'), 'pct.helpers.SolutionsDataManager': ('helpers.html#solutionsdatamanager', 'pct/helpers.py'), + 'pct.helpers.SolutionsDataManager.__init__': ('helpers.html#solutionsdatamanager.__init__', 'pct/helpers.py'), 'pct.helpers.SolutionsDataManager.count_all_keys': ( 'helpers.html#solutionsdatamanager.count_all_keys', 'pct/helpers.py'), 'pct.helpers.SolutionsDataManager.get_all_keys': ( 'helpers.html#solutionsdatamanager.get_all_keys', @@ -1015,7 +1021,9 @@ 'pct.helpers.SolutionsDataManager.get_arrays_for_key': ( 'helpers.html#solutionsdatamanager.get_arrays_for_key', 'pct/helpers.py'), 'pct.helpers.SolutionsDataManager.get_data_for_key': ( 'helpers.html#solutionsdatamanager.get_data_for_key', - 'pct/helpers.py')}, + 'pct/helpers.py'), + 'pct.helpers.SolutionsDataManager.reload_data': ( 'helpers.html#solutionsdatamanager.reload_data', + 'pct/helpers.py')}, 'pct.hierarchy': { 'pct.hierarchy.FunctionsData': ('hierarchy.html#functionsdata', 'pct/hierarchy.py'), 'pct.hierarchy.FunctionsData.__init__': ('hierarchy.html#functionsdata.__init__', 'pct/hierarchy.py'), 'pct.hierarchy.FunctionsData.add_data': ('hierarchy.html#functionsdata.add_data', 'pct/hierarchy.py'), diff --git a/pct/environment_processing.py b/pct/environment_processing.py index f9096bb2..5cfc23d4 100644 --- a/pct/environment_processing.py +++ b/pct/environment_processing.py @@ -223,8 +223,8 @@ def add_details(self, details): for key, value in details.items(): self.env_processing_details[key]=value - def enhanced_environment_properties(self, environment_properties=None): - pass + # def enhanced_environment_properties(self, environment_properties=None): + # pass def get_experiment(self): @@ -375,35 +375,35 @@ class ARCEnvironmentProcessing(BaseEnvironmentProcessing): def get_workspace(self): return 'arc-challenge' - def enhanced_environment_properties(self, 
environment_properties=None): - enhanced_environment_properties = {} - if 'dir' in environment_properties: - dir = environment_properties['dir'] - else: - dir = 'C:\\packages\\arc-prize-2024' - environment_properties['dir'] = dir - - if 'file_prefix' in environment_properties: - file_prefix = environment_properties['file_prefix'] - else: - file_prefix = 'arc-agi_training_' - environment_properties['file_prefix'] = file_prefix - - file_name = path.join(dir, file_prefix) + 'challenges.json' - challenges_manager = ChallengesDataManager(file_name) - data = challenges_manager.get_data_for_key(environment_properties['code']) + # def enhanced_environment_properties(self, environment_properties=None): + # enhanced_environment_properties = {} + # if 'dir' in environment_properties: + # dir = environment_properties['dir'] + # else: + # dir = 'C:\\packages\\arc-prize-2024' + # environment_properties['dir'] = dir + + # if 'file_prefix' in environment_properties: + # file_prefix = environment_properties['file_prefix'] + # else: + # file_prefix = 'arc-agi_training_' + # environment_properties['file_prefix'] = file_prefix + + # file_name = path.join(dir, file_prefix) + 'challenges.json' + # challenges_manager = ChallengesDataManager(file_name) + # data = challenges_manager.get_data_for_key(environment_properties['code']) - self.number_of_challenges = 1 - if 'index' not in environment_properties: - self.number_of_challenges = len(data['train']) + # self.number_of_challenges = 1 + # if 'index' not in environment_properties: + # self.number_of_challenges = len(data['train']) - enhanced_environment_properties['data']=data - solutions_file = path.join(environment_properties['dir'], environment_properties['file_prefix']) + 'solutions.json' - solutions_manager = SolutionsDataManager(solutions_file) - test_output_array = solutions_manager.get_data_for_key(environment_properties['code']) - enhanced_environment_properties['test_output_array']=test_output_array + # enhanced_environment_properties['data']=data + # solutions_file = path.join(environment_properties['dir'], environment_properties['file_prefix']) + 'solutions.json' + # solutions_manager = SolutionsDataManager(solutions_file) + # test_output_array = solutions_manager.get_data_for_key(environment_properties['code']) + # enhanced_environment_properties['test_output_array']=test_output_array - return enhanced_environment_properties + # return enhanced_environment_properties def results(self, filepath=None, experiment=None): print(filepath) @@ -416,7 +416,7 @@ def results(self, filepath=None, experiment=None): environment_properties['index'] = 0 print(environment_properties) - enhanced_environment_properties = self.enhanced_environment_properties(environment_properties=environment_properties) + # enhanced_environment_properties = self.enhanced_environment_properties(environment_properties=environment_properties) verbose= self.args['verbosed']['hpct_verbose'] min= not self.args['max'] @@ -427,7 +427,9 @@ def results(self, filepath=None, experiment=None): runs = int(1.5*environment_properties['runs']/self.number_of_challenges) hierarchy, score = PCTHierarchy.run_from_file(filepath, env_props=environment_properties, history=history, hpct_verbose= verbose, render=self.args['verbosed']['display_env'], runs=runs, experiment=experiment, min=min, plots=plots, plots_dir=self.args['plots_dir'], - enhanced_environment_properties=enhanced_environment_properties, title_prefix=title_prefix, early_termination=False) + title_prefix=title_prefix, early_termination=False + # 
,enhanced_environment_properties=enhanced_environment_properties + ) score = round(score ** 0.5, 1) print('Test score',score) diff --git a/pct/environments.py b/pct/environments.py index 44390cc4..b7717012 100644 --- a/pct/environments.py +++ b/pct/environments.py @@ -19,7 +19,7 @@ from .webots import WebotsHelper # from pct.yaw_module import YawEnv from .arc import ARCEnv -from .helpers import ListChecker, ChallengesDataManager +from .helpers import ListChecker, DataManagerSingleton # %% ../nbs/05_environments.ipynb 6 class EnvironmentFactory: @@ -1307,9 +1307,12 @@ def __call__(self, verbose: bool = False) -> Any: def set_properties(self, props: dict) -> None: - data = props['data'] + # data = props['data'] + + data_mgr = DataManagerSingleton.get_instance(folder = 'c:/tmp/arc-prize-2024', prefix = 'arc-agi_simple_', show_timing=True) + data = data_mgr.get_data_for_code(props['code']) + props['test_output_array'] = data_mgr.get_solutions_for_code(props['code']) - # props['data']=data self.env.initialise(props, data) self.fitness = self.env.fitness self.history = props.get('history', 5) @@ -1416,7 +1419,6 @@ def add_to_fitness_history(self, fitness): self.done, details = ListChecker.check_list_unchanged(self.boxcar, rel_tol =get_rel_tol('ARC-change'), abs_tol=get_abs_tol('ARC-change'), gradient_abs_tol=get_abs_tol('ARC-gradient')) if self.done: self.env.add_to_gradient_list(details['gradient_range']) - # self.env.fitness_isclose_to_zero = ListChecker.check_float_list_close_to_zero(self.boxcar, rel_tol = 0, abs_tol=get_abs_tol('ARC-zero'), gradient_abs_tol=get_abs_tol('ARC-gradient')) if self.done: self.env.add_to_fitness_list(max(self.boxcar) ) diff --git a/pct/helpers.py b/pct/helpers.py index dacf6efa..7ed6e481 100644 --- a/pct/helpers.py +++ b/pct/helpers.py @@ -14,8 +14,6 @@ # %% ../nbs/14_helpers.ipynb 5 class ListChecker: - - @staticmethod def check_list_unchanged(float_list, rel_tol=1e-9, abs_tol=0.0, gradient_abs_tol=0.0): @@ -110,12 +108,13 @@ def check_integer_list_unchanged(int_list): -# %% ../nbs/14_helpers.ipynb 12 +# %% ../nbs/14_helpers.ipynb 17 class JSONDataManager: def __init__(self, path: str, show_timing: bool = False): self.data = self.load_json(path) self.show_timing = show_timing + def load_json(self, path: str) -> Dict: with open(path, 'r') as file: return json.load(file) @@ -130,11 +129,17 @@ def timed_method(self, *args, **kwargs): return result return timed_method + def reload_data(self, path: str): + self.data = self.load_json(path) - -# %% ../nbs/14_helpers.ipynb 14 +# %% ../nbs/14_helpers.ipynb 21 class ChallengesDataManager(JSONDataManager): + @JSONDataManager.timing_decorator + def __init__(self, path: str, show_timing: bool = False): + super().__init__(path, show_timing) + + @JSONDataManager.timing_decorator def get_all_keys(self) -> List[str]: return list(self.data.keys()) @@ -174,9 +179,11 @@ def get_input_array_histogram(self) -> Dict[int, int]: counts = Counter(len(value['train']) for value in self.data.values()) return dict(counts) - @JSONDataManager.timing_decorator + # @JSONDataManager.timing_decorator def get_data_for_key(self, key: str) -> Dict[str, Any]: - return self.data.get(key, {}) + if key not in self.data: + raise KeyError(f"Key '{key}' not found in data.") + return self.data[key] @JSONDataManager.timing_decorator def get_arrays_for_key(self, key: str, array_type: str) -> List: @@ -223,10 +230,19 @@ def analyze_arrays(self) -> Dict[str, Any]: } + @JSONDataManager.timing_decorator + def reload_data(self, path: str): + 
super().reload_data(path) -# %% ../nbs/14_helpers.ipynb 16 + +# %% ../nbs/14_helpers.ipynb 23 class SolutionsDataManager(JSONDataManager): + + @JSONDataManager.timing_decorator + def __init__(self, path: str, show_timing: bool = False): + super().__init__(path, show_timing) + @JSONDataManager.timing_decorator def get_all_keys(self) -> List[str]: return list(self.data.keys()) @@ -235,39 +251,55 @@ def get_all_keys(self) -> List[str]: def count_all_keys(self) -> int: return len(self.data) - @JSONDataManager.timing_decorator + # @JSONDataManager.timing_decorator def get_data_for_key(self, key: str) -> Dict[str, Any]: - data = self.data.get(key, []) - return data[0] if data else {} - + if key not in self.data: + raise KeyError(f"Key '{key}' not found in data.") + return self.data[key][0] + @JSONDataManager.timing_decorator def get_arrays_for_key(self, key: str, array_type: str) -> List: if key not in self.data or array_type not in self.data[key]: return [] return self.data[key][array_type] + @JSONDataManager.timing_decorator + def reload_data(self, path: str): + super().reload_data(path) - -# %% ../nbs/14_helpers.ipynb 19 +# %% ../nbs/14_helpers.ipynb 26 class DataManagerSingleton: _instance = None @staticmethod - def get_instance(): + def get_instance(folder: str = None, prefix: str = None, show_timing: bool = False): if DataManagerSingleton._instance is None: - DataManagerSingleton._instance = DataManagerSingleton() + if folder is None or prefix is None: + raise ValueError("folder and prefix must be provided for the first instantiation") + DataManagerSingleton._instance = DataManagerSingleton(folder, prefix, show_timing) return DataManagerSingleton._instance - def __init__(self, folder: str, prefix: str): + def __init__(self, folder: str, prefix: str, show_timing: bool = False): + if DataManagerSingleton._instance is not None: + raise Exception("This class is a singleton!") self.folder = folder self.prefix = prefix - self.challenges_manager = ChallengesDataManager(f"{self.folder}/{self.prefix}_challenges.json") - self.solutions_manager = SolutionsDataManager(f"{self.folder}/{self.prefix}_solutions.json") - self.code = None + self.challenges_manager = ChallengesDataManager(f"{self.folder}/{self.prefix}challenges.json", show_timing=show_timing) + self.solutions_manager = SolutionsDataManager(f"{self.folder}/{self.prefix}solutions.json", show_timing=show_timing) - def load_data_for_code(self, code: str): - self.code = code + def get_data_for_code(self, code: str): data = self.challenges_manager.get_data_for_key(code) - # Process the data as needed return data + + def get_solutions_for_code(self, code: str): + solutions = self.solutions_manager.get_data_for_key(code) + return solutions + + + def reload_data(self, folder: str, prefix: str): + self.folder = folder + self.prefix = prefix + self.challenges_manager.reload_data(f"{self.folder}/{self.prefix}challenges.json") + self.solutions_manager.reload_data(f"{self.folder}/{self.prefix}solutions.json") +
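
Usage sketch of the DataManagerSingleton API introduced above. This is a minimal example distilled from the `#| gui` cells in nbs/14_helpers.ipynb; the folder path, file prefixes, and challenge codes are the local assumptions used in those cells, not fixed values required by the API.

# Minimal sketch of the new DataManagerSingleton workflow (pct/helpers.py).
# Assumes the ARC challenge/solution JSON files live under the folder below,
# named <prefix>challenges.json and <prefix>solutions.json.
from pct.helpers import DataManagerSingleton

# The first call must supply folder and prefix; later calls return the same instance.
data_mgr = DataManagerSingleton.get_instance(
    folder='c:/tmp/arc-prize-2024',
    prefix='arc-agi_simple_',
    show_timing=True,
)

code = '00000001'
data = data_mgr.get_data_for_code(code)            # challenge grids for this code
solution = data_mgr.get_solutions_for_code(code)   # expected test output array
# Unknown codes now raise KeyError instead of returning an empty dict/list.

# Switching datasets re-reads both JSON files via the managers' reload_data.
data_mgr.reload_data(folder='c:/tmp/arc-prize-2024', prefix='arc-agi_training_')
training_data = data_mgr.get_data_for_code('007bbfb7')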