From 849b79ceed3e20b2fa935daa3685149a8f584452 Mon Sep 17 00:00:00 2001 From: Jesse Wood Date: Mon, 26 Feb 2024 18:06:20 +1300 Subject: [PATCH] Delete code/identification/part/EDIT_R01_S02_Identification_Part_Wrapper_Based_Multi_Tree_GP.ipynb --- ...ion_Part_Wrapper_Based_Multi_Tree_GP.ipynb | 1307 ----------------- 1 file changed, 1307 deletions(-) delete mode 100644 code/identification/part/EDIT_R01_S02_Identification_Part_Wrapper_Based_Multi_Tree_GP.ipynb diff --git a/code/identification/part/EDIT_R01_S02_Identification_Part_Wrapper_Based_Multi_Tree_GP.ipynb b/code/identification/part/EDIT_R01_S02_Identification_Part_Wrapper_Based_Multi_Tree_GP.ipynb deleted file mode 100644 index c0ada96e..00000000 --- a/code/identification/part/EDIT_R01_S02_Identification_Part_Wrapper_Based_Multi_Tree_GP.ipynb +++ /dev/null @@ -1,1307 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SAlW5BfrgjQk" - }, - "source": [ - "# Multi-tree Genetic Program" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TESFcfJMSeT5" - }, - "source": [ - "## Kicked for inactivity?\n", - "\n", - "To stop a colab notebook from disconnecting, open up the console with CTRL + SHIFT + I, and copy and execute the following code.\n", - "\n", - "```javascript\n", - "function ConnectButton(){\n", - " console.log(\"Connect pushed\");\n", - " document.querySelector(\"#top-toolbar > colab-connect-button\").shadowRoot.querySelector(\"#connect\").click()\n", - "}\n", - "setInterval(ConnectButton,60000);\n", - "```\n", - "\n", - "This tricks the browser into thinking the user is active." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "NLL18yL7WWRt", - "outputId": "52714fe5-cfcb-4d53-cf24-515c04995d4d" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting deap\n", - " Downloading deap-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (135 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m135.4/135.4 kB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from deap) (1.25.2)\n", - "Installing collected packages: deap\n", - "Successfully installed deap-1.4.1\n", - "Reading package lists... Done\n", - "Building dependency tree... Done\n", - "Reading state information... Done\n", - "The following additional packages will be installed:\n", - " libgail-common libgail18 libgtk2.0-0 libgtk2.0-bin libgtk2.0-common libgvc6-plugins-gtk\n", - " librsvg2-common libxdot4\n", - "Suggested packages:\n", - " gvfs\n", - "The following NEW packages will be installed:\n", - " libgail-common libgail18 libgraphviz-dev libgtk2.0-0 libgtk2.0-bin libgtk2.0-common\n", - " libgvc6-plugins-gtk librsvg2-common libxdot4\n", - "0 upgraded, 9 newly installed, 0 to remove and 33 not upgraded.\n", - "Need to get 2,433 kB of archives.\n", - "After this operation, 7,694 kB of additional disk space will be used.\n", - "Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 libgtk2.0-common all 2.24.33-2ubuntu2 [125 kB]\n", - "Get:2 http://archive.ubuntu.com/ubuntu jammy/main amd64 libgtk2.0-0 amd64 2.24.33-2ubuntu2 [2,037 kB]\n", - "Get:3 http://archive.ubuntu.com/ubuntu jammy/main amd64 libgail18 amd64 2.24.33-2ubuntu2 [15.9 kB]\n", - "Get:4 http://archive.ubuntu.com/ubuntu jammy/main amd64 libgail-common amd64 2.24.33-2ubuntu2 [132 kB]\n", - 
"Get:5 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libxdot4 amd64 2.42.2-6 [16.4 kB]\n", - "Get:6 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libgvc6-plugins-gtk amd64 2.42.2-6 [22.6 kB]\n", - "Get:7 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libgraphviz-dev amd64 2.42.2-6 [58.5 kB]\n", - "Get:8 http://archive.ubuntu.com/ubuntu jammy/main amd64 libgtk2.0-bin amd64 2.24.33-2ubuntu2 [7,932 B]\n", - "Get:9 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 librsvg2-common amd64 2.52.5+dfsg-3ubuntu0.2 [17.7 kB]\n", - "Fetched 2,433 kB in 1s (1,816 kB/s)\n", - "Selecting previously unselected package libgtk2.0-common.\n", - "(Reading database ... 121749 files and directories currently installed.)\n", - "Preparing to unpack .../0-libgtk2.0-common_2.24.33-2ubuntu2_all.deb ...\n", - "Unpacking libgtk2.0-common (2.24.33-2ubuntu2) ...\n", - "Selecting previously unselected package libgtk2.0-0:amd64.\n", - "Preparing to unpack .../1-libgtk2.0-0_2.24.33-2ubuntu2_amd64.deb ...\n", - "Unpacking libgtk2.0-0:amd64 (2.24.33-2ubuntu2) ...\n", - "Selecting previously unselected package libgail18:amd64.\n", - "Preparing to unpack .../2-libgail18_2.24.33-2ubuntu2_amd64.deb ...\n", - "Unpacking libgail18:amd64 (2.24.33-2ubuntu2) ...\n", - "Selecting previously unselected package libgail-common:amd64.\n", - "Preparing to unpack .../3-libgail-common_2.24.33-2ubuntu2_amd64.deb ...\n", - "Unpacking libgail-common:amd64 (2.24.33-2ubuntu2) ...\n", - "Selecting previously unselected package libxdot4:amd64.\n", - "Preparing to unpack .../4-libxdot4_2.42.2-6_amd64.deb ...\n", - "Unpacking libxdot4:amd64 (2.42.2-6) ...\n", - "Selecting previously unselected package libgvc6-plugins-gtk.\n", - "Preparing to unpack .../5-libgvc6-plugins-gtk_2.42.2-6_amd64.deb ...\n", - "Unpacking libgvc6-plugins-gtk (2.42.2-6) ...\n", - "Selecting previously unselected package libgraphviz-dev:amd64.\n", - "Preparing to unpack .../6-libgraphviz-dev_2.42.2-6_amd64.deb ...\n", - 
"Unpacking libgraphviz-dev:amd64 (2.42.2-6) ...\n", - "Selecting previously unselected package libgtk2.0-bin.\n", - "Preparing to unpack .../7-libgtk2.0-bin_2.24.33-2ubuntu2_amd64.deb ...\n", - "Unpacking libgtk2.0-bin (2.24.33-2ubuntu2) ...\n", - "Selecting previously unselected package librsvg2-common:amd64.\n", - "Preparing to unpack .../8-librsvg2-common_2.52.5+dfsg-3ubuntu0.2_amd64.deb ...\n", - "Unpacking librsvg2-common:amd64 (2.52.5+dfsg-3ubuntu0.2) ...\n", - "Setting up libxdot4:amd64 (2.42.2-6) ...\n", - "Setting up librsvg2-common:amd64 (2.52.5+dfsg-3ubuntu0.2) ...\n", - "Setting up libgtk2.0-common (2.24.33-2ubuntu2) ...\n", - "Setting up libgtk2.0-0:amd64 (2.24.33-2ubuntu2) ...\n", - "Setting up libgvc6-plugins-gtk (2.42.2-6) ...\n", - "Setting up libgail18:amd64 (2.24.33-2ubuntu2) ...\n", - "Setting up libgtk2.0-bin (2.24.33-2ubuntu2) ...\n", - "Setting up libgail-common:amd64 (2.24.33-2ubuntu2) ...\n", - "Setting up libgraphviz-dev:amd64 (2.42.2-6) ...\n", - "Processing triggers for libc-bin (2.35-0ubuntu3.4) ...\n", - "/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link\n", - "\n", - "/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link\n", - "\n", - "/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link\n", - "\n", - "/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link\n", - "\n", - "/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link\n", - "\n", - "/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link\n", - "\n", - "Processing triggers for man-db (2.10.2-1) ...\n", - "Processing triggers for libgdk-pixbuf-2.0-0:amd64 (2.42.8+dfsg-1ubuntu0.2) ...\n", - "Collecting pygraphviz\n", - " Downloading pygraphviz-1.12.tar.gz (104 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m104.9/104.9 kB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - 
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "Building wheels for collected packages: pygraphviz\n", - " Building wheel for pygraphviz (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for pygraphviz: filename=pygraphviz-1.12-cp310-cp310-linux_x86_64.whl size=168131 sha256=4a4b9e9457f15b978b2b682471b0bd77ab7a185eaca0c21b0fca73309703ba92\n", - " Stored in directory: /root/.cache/pip/wheels/1d/ee/b5/a2f54f9e9b3951599c05dcce270ca85e472f8e6cec470e397a\n", - "Successfully built pygraphviz\n", - "Installing collected packages: pygraphviz\n", - "Successfully installed pygraphviz-1.12\n", - "Collecting skfeature-chappers\n", - " Downloading skfeature_chappers-1.1.0-py3-none-any.whl (66 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m66.3/66.3 kB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from skfeature-chappers) (1.2.2)\n", - "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from skfeature-chappers) (1.5.3)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from skfeature-chappers) (1.25.2)\n", - "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->skfeature-chappers) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->skfeature-chappers) (2023.4)\n", - "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->skfeature-chappers) (1.11.4)\n", - "Requirement already satisfied: joblib>=1.1.1 in 
/usr/local/lib/python3.10/dist-packages (from scikit-learn->skfeature-chappers) (1.3.2)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->skfeature-chappers) (3.2.0)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->skfeature-chappers) (1.16.0)\n", - "Installing collected packages: skfeature-chappers\n", - "Successfully installed skfeature-chappers-1.1.0\n" - ] - } - ], - "source": [ - "!pip install deap\n", - "!apt install libgraphviz-dev\n", - "!pip install pygraphviz\n", - "!pip install skfeature-chappers" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "YCcCiUIvL40G", - "outputId": "7ac27fe1-8754-49a0-9d5e-e440fda27e60" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Mounted at /content/drive\n" - ] - } - ], - "source": [ - "from google.colab import drive\n", - "drive.mount('/content/drive')\n", - "data_path_gdrive = '/content/drive/MyDrive/AI/Data'\n", - "dataset = 'Part' #@param [\"Fish\", \"Part\"]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lD4BzYVvq9_j" - }, - "source": [ - "## Data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "r-0dl8EoqK0T", - "outputId": "86e9c5b8-2fea-43ac-a4ef-3db2e7384e92" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n", - "[INFO] Reading the dataset.\n", - "Class Counts: [6 6 3 6 6 3], Class Ratios: [0.2 0.2 0.1 0.2 0.2 0.1]\n", - "Number of features: 1023\n", - "Number of instances: 30\n", - "Number of classes 6.\n" - ] - } - ], - "source": [ - "\"\"\"\n", - "Data - 
data.py\n", - "==============\n", - "\n", - "This is the data module. It contains the functions for loading, preparing, normalizing and encoding the data.\n", - "\"\"\"\n", - "\n", - "import numpy as np\n", - "from sklearn.preprocessing import MinMaxScaler\n", - "from sklearn import preprocessing\n", - "import scipy.io\n", - "import os\n", - "\n", - "def encode_labels(y, y_test=None):\n", - " \"\"\"\n", - " Convert text labels to numbers.\n", - "\n", - " Args:\n", - " y: The labels.\n", - " y_test: The test labels. Defaults to None.\n", - " \"\"\"\n", - " le = preprocessing.LabelEncoder()\n", - " y = le.fit_transform(y)\n", - " if y_test is not None:\n", - " y_test = le.transform(y_test)\n", - " return y, y_test, le\n", - "\n", - "\n", - "def load(filename, folder=''):\n", - " \"\"\"\n", - " Load the data from the mat file.\n", - "\n", - " Args:\n", - " filename: The name of the mat file.\n", - " folder: The folder where the mat file is located.\n", - " \"\"\"\n", - " path = os.path.join(folder, filename)\n", - " mat = scipy.io.loadmat(path)\n", - " return mat\n", - "\n", - "\n", - "def prepare(mat):\n", - " \"\"\"\n", - " Load the data from matlab format into memory.\n", - "\n", - " Args:\n", - " mat: The data in matlab format.\n", - " \"\"\"\n", - " X = mat['X']\n", - " X = X.astype(float)\n", - " y = mat['Y']\n", - " y = y[:, 0]\n", - " return X,y\n", - "\n", - "\n", - "def normalize(X_train, X_test):\n", - " \"\"\"\n", - " Normalize the input features within range [0,1].\n", - "\n", - " Args:\n", - " X_train: The training data.\n", - " X_test: The test data.\n", - " \"\"\"\n", - " scaler = MinMaxScaler(feature_range=(0, 1))\n", - " scaler = scaler.fit(X_train)\n", - " X_train = scaler.transform(X_train)\n", - " X_test = scaler.transform(X_test)\n", - " return X_train, X_test\n", - "\n", - "# Code snippet from elsewhere.\n", - "\n", - "from google.colab import drive\n", - "drive.mount('/content/drive')\n", - "import os\n", - "os.listdir('/content/drive/My 
Drive')\n", - "\n", - "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib.ticker as ticker\n", - "\n", - "path = ['drive', 'MyDrive', 'AI', 'fish', 'REIMS_data.xlsx']\n", - "path = os.path.join(*path)\n", - "\n", - "# Load the dataset\n", - "# data = pd.read_excel(path)\n", - "\n", - "print(\"[INFO] Reading the dataset.\")\n", - "raw = pd.read_excel(path)\n", - "\n", - "data = raw[~raw['m/z'].str.contains('HM')]\n", - "data = data[~data['m/z'].str.contains('QC')]\n", - "data = data[~data['m/z'].str.contains('HM')]\n", - "X = data.drop('m/z', axis=1) # X contains only the features.\n", - "# y = data['m/z'].apply(lambda x:\n", - "# [1,0,0,0,0,0] if 'Fillet' in x\n", - "# else ([0,1,0,0,0,0] if 'Heads' in x\n", - "# else ([0,0,1,0,0,0] if 'Livers' in x\n", - "# else ([0,0,0,1,0,0] if 'Skins' in x\n", - "# else ([0,0,0,0,1,0] if 'Guts' in x\n", - "# else ([0,0,0,0,0,1] if 'Frames' in x\n", - "# else None )))))) # Labels for fish parts\n", - "y = data['m/z'].apply(lambda x:\n", - " 0 if 'Fillet' in x\n", - " else 1 if 'Heads' in x\n", - " else (2 if 'Livers' in x\n", - " else (3 if 'Skins' in x\n", - " else (4 if 'Guts' in x\n", - " else (5 if 'Frames' in x\n", - " else None ))))) # For fish parts\n", - "xs = []\n", - "ys = []\n", - "for (x,y) in zip(X.to_numpy(),y):\n", - " if y is not None and not np.isnan(y):\n", - " xs.append(x)\n", - " ys.append(y)\n", - "X = np.array(xs)\n", - "y = np.array(ys)\n", - "\n", - "# file = load(f'{dataset}.mat', folder=data_path_gdrive)\n", - "# X,y = prepare(file)\n", - "# X,_ = normalize(X,X)\n", - "# y, _, le = encode_labels(y)\n", - "# labels = le.inverse_transform(np.unique(y))\n", - "classes, class_counts = np.unique(y, axis=0, return_counts=True)\n", - "n_features = X.shape[1]\n", - "n_instances = X.shape[0]\n", - "n_classes = len(np.unique(y, axis=0))\n", - "class_ratios = np.array(class_counts) / n_instances\n", - "\n", - "print(f\"Class Counts: {class_counts}, Class Ratios: 
{class_ratios}\")\n", - "print(f\"Number of features: {n_features}\\nNumber of instances: {n_instances}\\nNumber of classes {n_classes}.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eanPjO3KB121" - }, - "source": [ - "## Activation Function" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nP6upHxeyFX_" - }, - "source": [ - "## Operators" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "8AQBTITsyEzE" - }, - "outputs": [], - "source": [ - "import math\n", - "import copy\n", - "import random\n", - "import operator\n", - "from re import I\n", - "from operator import attrgetter\n", - "from functools import wraps, partial\n", - "\n", - "import numpy as np\n", - "\n", - "from deap import algorithms\n", - "from deap.algorithms import varAnd\n", - "from deap import base, creator, tools, gp\n", - "from deap.gp import PrimitiveTree, Primitive, Terminal\n", - "\n", - "from sklearn.metrics import balanced_accuracy_score\n", - "\n", - "pset = gp.PrimitiveSet(\"MAIN\", n_features)\n", - "pset.addPrimitive(operator.add, 2)\n", - "pset.addPrimitive(operator.sub, 2)\n", - "pset.addPrimitive(operator.mul, 2)\n", - "pset.addPrimitive(operator.neg, 1)\n", - "# pset.addEphemeralConstant(\"rand101\", lambda: random.randint(-1,1))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Kejku7hWyMbO" - }, - "source": [ - "## Fitness Function" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "rQE3P3qLyMDr" - }, - "outputs": [], - "source": [ - "toolbox = base.Toolbox()\n", - "\n", - "minimized = False\n", - "if minimized:\n", - " weight = -1.0\n", - "else:\n", - " weight = 1.0\n", - "\n", - "weights = (weight,)\n", - "\n", - "if minimized:\n", - " creator.create(\"FitnessMin\", base.Fitness, weights=weights)\n", - " creator.create(\"Individual\", list, fitness=creator.FitnessMin)\n", - "else:\n", - " creator.create(\"FitnessMax\", base.Fitness, weights=weights)\n", - " 
creator.create(\"Individual\", list, fitness=creator.FitnessMax)\n", - "\n", - "def quick_evaluate(expr: PrimitiveTree, pset, data, prefix='ARG'):\n", - " \"\"\" Quick evaluate offers a 500% speedup for the evaluation of GP trees.\n", - "\n", - " The default implementation of gp.compile provided by the DEAP library is\n", - " horrendously inefficient. (Zhang 2022) has shared his code which leads to a\n", - " 5x speedup in the compilation and evaluation of GP trees when compared to the\n", - " standard library approach.\n", - "\n", - " For multi-tree GP, this speedup factor is invaluable! As each individual consists\n", - " of m trees. For the fish dataset we have 4 classes, each with 3 constructed features,\n", - " which corresponds to 4 classes x 3 features = 12 trees for each individual.\n", - " 12 trees x 500% speedup = 6,000% overall speedup, or 60 times faster.\n", - " The 500% speedup is fundamental, for efficient evaluation of multi-tree GP.\n", - "\n", - " Args:\n", - " expr (PrimitiveTree): The uncompiled (gp.PrimitiveTree) GP tree.\n", - " pset: The primitive set.\n", - " data: The dataset to evaluate the GP tree for.\n", - " prefix: Prefix for variable arguments. 
Defaults to ARG.\n", - "\n", - " Returns:\n", - " The (array-like) result of the GP tree evaluated on the dataset.\n", - " \"\"\"\n", - " result = None\n", - " stack = []\n", - " for node in expr:\n", - " stack.append((node, []))\n", - " while len(stack[-1][1]) == stack[-1][0].arity:\n", - " prim, args = stack.pop()\n", - " if isinstance(prim, Primitive):\n", - " result = pset.context[prim.name](*args)\n", - " elif isinstance(prim, Terminal):\n", - " if prefix in prim.name:\n", - " result = data[:, int(prim.name.replace(prefix, ''))]\n", - " else:\n", - " result = prim.value\n", - " else:\n", - " raise Exception\n", - " if len(stack) == 0:\n", - " break # If stack is empty, all nodes should have been seen\n", - " stack[-1][1].append(result)\n", - " return result\n", - "\n", - "def compileMultiTree(expr, pset):\n", - " \"\"\"Compile the expression represented by a list of trees.\n", - "\n", - " A variation of the gp.compileADF method, that handles Multi-tree GP.\n", - "\n", - " Args:\n", - " expr: Expression to compile. 
It can either be a PrimitiveTree,\n", - " a string of Python code or any object that when\n", - " converted into string produced a valid Python code\n", - " expression.\n", - " pset: Primitive Set\n", - "\n", - " Returns:\n", - " A set of functions that correspond for each tree in the Multi-tree.\n", - " \"\"\"\n", - " funcs = []\n", - " gp_tree = None\n", - " func = None\n", - "\n", - " for subexpr in expr:\n", - " gp_tree = gp.PrimitiveTree(subexpr)\n", - " # 5x speedup by manually parsing GP tree (Zhang 2022) https://mail.google.com/mail/u/0/#inbox/FMfcgzGqQmQthcqPCCNmstgLZlKGXvbc\n", - " func = quick_evaluate(gp_tree, pset, X, prefix='ARG')\n", - " funcs.append(func)\n", - "\n", - " # Hengzhe's method returns the features in the wrong rotation for multi-tree\n", - " features = np.array(funcs).T\n", - " return features\n", - "\n", - "# MCIFC constructs 8 features for a (c=4) multi-class classification problem (Tran 2019).\n", - "# c - number of classes, r - construction ratio, m - total number of constructed features.\n", - "# m = r * c = 2 ratio * 4 classes = 8 features\n", - "\n", - "r = 3\n", - "c = n_classes\n", - "m = r * c\n", - "\n", - "toolbox.register(\"expr\", gp.genHalfAndHalf, pset=pset, min_=1, max_=2)\n", - "toolbox.register(\"individual\", tools.initRepeat, creator.Individual, toolbox.expr, n=m)\n", - "toolbox.register(\"population\", tools.initRepeat, list, toolbox.individual)\n", - "toolbox.register(\"compile\", compileMultiTree)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dFj8hToTJ09i" - }, - "source": [ - "## Infinite speedup!?\n", - "\n", - "For multi-tree GP, this speedup factor is invaluable! As each individual consists of m trees. For the fish dataset we have 4 classes, each with 3 constructed features, which corresponds to 4 classes x 3 features = 12 trees for each individual. 12 trees x 500% speedup = 6,000% overall speedup, or 60 times faster. 
The 500% speedup is fundamental, for efficient evaluation of multi-tree GP.\n", - "\n", - "The evaluation below shows my calculations may still be wrong. And perhaps, it is even faster:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "AcGYavg1D-77", - "outputId": "f3561507-a25c-4788-965b-afd7e2369c6d" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "6.27 ms ± 1.97 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", - "The slowest run took 5.16 times longer than the fastest. This could mean that an intermediate result is being cached.\n", - "5.37 µs ± 4.18 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" - ] - } - ], - "source": [ - "first = toolbox.population(n=1)[0]\n", - "\n", - "subtree = first[0]\n", - "gp_tree = gp.PrimitiveTree(subtree)\n", - "\n", - "%timeit gp.compile(gp_tree, pset)\n", - "%timeit quick_evaluate(gp_tree, pset, X)" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ryKFCvmusAlS", - "outputId": "7abe0e6e-de91-41d5-8d1b-94705ce276b2" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Train accuracy: 1.0, Validation accuracy: 0.4166666666666667, Test accuracy: 0.3333333333333333\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "0.7083333333333334" - ] - }, - "metadata": {}, - "execution_count": 37 - } - ], - "source": [ - "from sklearn.svm import LinearSVC as svm\n", - "from sklearn.model_selection import StratifiedKFold, train_test_split\n", - "\n", - "def wrapper_classification_accuracy(X, k=10, verbose=False):\n", - " \"\"\" Evaluate balanced classification accuracy over stratified k-fold cross validation.\n", - "\n", - " This method is our fitness measure for an individual. 
We measure each individual\n", - " based on its balanced classification accuracy using 10-fold cross-validation on\n", - " the training set.\n", - "\n", - " If verbose, we evaluate performance on the test set as well, and print the results\n", - " to the standard output. By default, only the train set is evaluated, which\n", - " corresponds to a 2x speedup for training, when compared to the verbose method.\n", - "\n", - " Args:\n", - " X: entire dataset, train and test.\n", - " k: Number of folds, for cross validation. Defaults to 10.\n", - " verbose: If true, prints stuff. Defaults to false.\n", - "\n", - " Returns:\n", - " Average balanced classification accuracy with 10-fold CV on training set.\n", - " \"\"\"\n", - "\n", - " X_train, X_temp, y_train, y_temp = train_test_split(X, y, stratify=y, test_size=0.6666, random_state=42)\n", - " X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, stratify=y_temp, test_size=0.5, random_state=42)\n", - "\n", - " # train_accs = []\n", - " # test_accs = []\n", - " # skf = StratifiedKFold(n_splits=3)\n", - " # for train_idx, test_idx in skf.split(X,y):\n", - " # X_train, X_test = X[train_idx], X[test_idx]\n", - " # y_train, y_test = y[train_idx], y[test_idx]\n", - "\n", - " # Convergence errors for the fish part dataset.\n", - " # Need to use a different SVM hyperparameters for this dataset.\n", - " # model = svm(penalty='l2', max_iter=10_000)\n", - " model = svm()\n", - " model.fit(X_train, y_train)\n", - " y_predict = model.predict(X_train)\n", - " train_acc = balanced_accuracy_score(y_train, y_predict)\n", - " # train_accs.append(train_acc)\n", - " # test_accs.append(test_acc)\n", - "\n", - " y_predict = model.predict(X_val)\n", - " val_acc = balanced_accuracy_score(y_val, y_predict)\n", - "\n", - " if verbose:\n", - " # Only evaluate test set if verbose!\n", - " # Results in 2x speedup for training.\n", - " y_predict = model.predict(X_test)\n", - " test_acc = balanced_accuracy_score(y_test, y_predict)\n", - " 
print(f\"Train accuracy: {train_acc}, Validation accuracy: {val_acc}, Test accuracy: {test_acc}\")\n", - "\n", - " # avg_train_acc = np.mean(train_accs)\n", - " # avg_test_acc = np.mean(test_accs)\n", - "\n", - " # if verbose:\n", - " # Must be here, to avoid numpy warnings!\n", - " # print(f\"Average train accuracy: {avg_train_acc}, Average test accuracy: {avg_test_acc}\")\n", - "\n", - " alpha = 0.5\n", - " return alpha * train_acc + (1 - alpha) * val_acc\n", - "\n", - "wrapper_classification_accuracy(X, verbose=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "_vOz-1Oislgd", - "outputId": "5c164f46-180e-4826-a0c8-c7714009b2e6" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - " \t \t fitness \t size \n", - " \t \t------------------------------------------------------------------------\t------------------------------------------------\n", - "gen\tnevals\tavg \tgen\tmax \tmin \tnevals\tstd \tavg \tgen\tmax\tmin\tnevals\tstd \n", - "0 \t1023 \t0.387586\t0 \t0.791667\t0.0833333\t1023 \t0.146112\t6.79081\t0 \t7 \t3 \t1023 \t0.592585\n", - "1 \t766 \t0.550831\t1 \t0.833333\t0.125 \t766 \t0.117603\t6.89052\t1 \t13 \t3 \t766 \t0.735723\n", - "2 \t758 \t0.625163\t2 \t0.833333\t0.0833333\t758 \t0.127214\t7.02151\t2 \t13 \t3 \t758 \t0.808996\n", - "3 \t781 \t0.682796\t3 \t0.875 \t0.125 \t781 \t0.129309\t7.03226\t3 \t13 \t5 \t781 \t0.932699\n", - "4 \t765 \t0.732608\t4 \t0.875 \t0.0833333\t765 \t0.130342\t6.92473\t4 \t15 \t5 \t765 \t0.970334\n", - "5 \t783 \t0.774927\t5 \t0.875 \t0.0833333\t783 \t0.140355\t6.99316\t5 \t13 \t5 \t783 \t0.802581\n", - "6 \t755 \t0.812357\t6 \t0.875 \t0.125 \t755 \t0.132313\t7.09384\t6 \t12 \t6 \t755 \t0.678529\n", - "7 \t791 \t0.804619\t7 \t0.875 \t0.125 \t791 \t0.139209\t7.11046\t7 \t13 \t6 \t791 \t0.730911\n", - "8 \t810 \t0.794925\t8 \t0.875 \t0.166667 \t810 \t0.149431\t7.15249\t8 \t12 \t6 \t810 
\t0.737364\n", - "9 \t751 \t0.808203\t9 \t0.875 \t0.125 \t751 \t0.131484\t7.14761\t9 \t12 \t6 \t751 \t0.788921\n", - "10 \t780 \t0.80637 \t10 \t0.875 \t0.125 \t780 \t0.14206 \t7.20039\t10 \t13 \t6 \t780 \t0.885376\n", - "11 \t776 \t0.807755\t11 \t0.875 \t0.166667 \t776 \t0.14191 \t7.27566\t11 \t13 \t6 \t776 \t0.946395\n", - "12 \t757 \t0.808366\t12 \t0.875 \t0.125 \t757 \t0.141642\t7.26784\t12 \t14 \t6 \t757 \t0.946575\n", - "13 \t757 \t0.816349\t13 \t0.875 \t0.166667 \t757 \t0.129482\t7.26295\t13 \t13 \t6 \t757 \t0.930774\n", - "14 \t789 \t0.81586 \t14 \t0.875 \t0.125 \t789 \t0.129796\t7.30401\t14 \t12 \t6 \t789 \t0.948556\n", - "15 \t817 \t0.812928\t15 \t0.875 \t0.125 \t817 \t0.134941\t7.38123\t15 \t13 \t6 \t817 \t0.980414\n", - "16 \t783 \t0.817041\t16 \t0.875 \t0.166667 \t783 \t0.130802\t7.41153\t16 \t13 \t6 \t783 \t1.03871 \n", - "17 \t766 \t0.816919\t17 \t0.875 \t0.166667 \t766 \t0.126157\t7.47703\t17 \t13 \t6 \t766 \t1.10858 \n", - "18 \t781 \t0.809669\t18 \t0.875 \t0.166667 \t781 \t0.134647\t7.45552\t18 \t17 \t5 \t781 \t1.2035 \n", - "19 \t802 \t0.819282\t19 \t0.875 \t0.0833333\t802 \t0.127511\t7.54936\t19 \t15 \t6 \t802 \t1.28045 \n", - "20 \t768 \t0.80971 \t20 \t0.875 \t0.166667 \t768 \t0.138194\t7.47019\t20 \t15 \t6 \t768 \t1.23837 \n", - "21 \t790 \t0.813783\t21 \t0.875 \t0.166667 \t790 \t0.137514\t7.52493\t21 \t13 \t5 \t790 \t1.22738 \n", - "22 \t758 \t0.810973\t22 \t0.875 \t0.125 \t758 \t0.140973\t7.53177\t22 \t15 \t6 \t758 \t1.23437 \n", - "23 \t775 \t0.821929\t23 \t0.875 \t0.166667 \t775 \t0.125558\t7.46432\t23 \t13 \t5 \t775 \t1.19483 \n", - "24 \t773 \t0.806777\t24 \t0.875 \t0.125 \t773 \t0.143739\t7.46921\t24 \t16 \t5 \t773 \t1.26953 \n", - "25 \t751 \t0.819811\t25 \t0.916667\t0.166667 \t751 \t0.129751\t7.42326\t25 \t14 \t5 \t751 \t1.14438 \n", - "26 \t776 \t0.816634\t26 \t0.916667\t0.125 \t776 \t0.129983\t7.36461\t26 \t14 \t5 \t776 \t1.07627 \n", - "27 \t753 \t0.807103\t27 \t0.916667\t0.125 \t753 \t0.148513\t7.38514\t27 \t14 \t5 \t753 \t1.07956 
\n", - "28 \t788 \t0.820055\t28 \t0.916667\t0.125 \t788 \t0.134931\t7.3871 \t28 \t13 \t5 \t788 \t1.0504 \n", - "29 \t775 \t0.836632\t29 \t0.916667\t0.166667 \t775 \t0.127164\t7.34115\t29 \t14 \t6 \t775 \t1.00727 \n", - "30 \t783 \t0.844819\t30 \t0.916667\t0.166667 \t783 \t0.12808 \t7.25611\t30 \t13 \t6 \t783 \t0.841785\n", - "31 \t789 \t0.841805\t31 \t0.916667\t0.166667 \t789 \t0.141222\t7.35582\t31 \t13 \t7 \t789 \t0.992879\n", - "32 \t760 \t0.85325 \t32 \t0.916667\t0.125 \t760 \t0.132885\t7.37243\t32 \t12 \t7 \t760 \t0.982298\n", - "33 \t777 \t0.848648\t33 \t0.916667\t0.166667 \t777 \t0.132838\t7.39883\t33 \t14 \t7 \t777 \t0.998179\n", - "34 \t761 \t0.851825\t34 \t0.916667\t0.166667 \t761 \t0.125768\t7.46139\t34 \t15 \t7 \t761 \t1.10461 \n", - "35 \t782 \t0.844167\t35 \t0.916667\t0.166667 \t782 \t0.139536\t7.55621\t35 \t17 \t7 \t782 \t1.22635 \n", - "36 \t765 \t0.846856\t36 \t0.916667\t0.166667 \t765 \t0.134605\t7.57283\t36 \t17 \t7 \t765 \t1.30879 \n", - "37 \t770 \t0.849381\t37 \t0.916667\t0.166667 \t770 \t0.131763\t7.56989\t37 \t17 \t6 \t770 \t1.30969 \n", - "38 \t756 \t0.849218\t38 \t0.916667\t0.125 \t756 \t0.139697\t7.60215\t38 \t14 \t7 \t756 \t1.28114 \n", - "39 \t776 \t0.846733\t39 \t0.916667\t0.166667 \t776 \t0.140701\t7.64614\t39 \t14 \t7 \t776 \t1.34255 \n", - "40 \t783 \t0.847467\t40 \t0.916667\t0.0833333\t783 \t0.136352\t7.66373\t40 \t14 \t7 \t783 \t1.3658 \n", - "41 \t778 \t0.851988\t41 \t0.916667\t0.166667 \t778 \t0.133342\t7.65885\t41 \t14 \t6 \t778 \t1.31237 \n", - "42 \t782 \t0.855368\t42 \t0.916667\t0.166667 \t782 \t0.130187\t7.65005\t42 \t14 \t7 \t782 \t1.29391 \n", - "43 \t777 \t0.854554\t43 \t0.916667\t0.166667 \t777 \t0.128089\t7.74487\t43 \t16 \t7 \t777 \t1.44309 \n", - "44 \t765 \t0.853047\t44 \t0.916667\t0.166667 \t765 \t0.132218\t7.7478 \t44 \t15 \t7 \t765 \t1.4358 \n", - "45 \t784 \t0.849014\t45 \t0.916667\t0.125 \t784 \t0.137633\t7.73998\t45 \t15 \t7 \t784 \t1.41795 \n", - "46 \t772 \t0.846041\t46 \t0.916667\t0.166667 \t772 
\t0.139232\t7.80645\t46 \t16 \t7 \t772 \t1.48951 \n", - "47 \t780 \t0.85545 \t47 \t0.916667\t0.166667 \t780 \t0.124822\t7.71065\t47 \t14 \t6 \t780 \t1.3398 \n", - "48 \t775 \t0.85378 \t48 \t0.916667\t0.166667 \t775 \t0.130179\t7.69208\t48 \t15 \t7 \t775 \t1.36784 \n", - "49 \t767 \t0.864329\t49 \t0.916667\t0.166667 \t767 \t0.113784\t7.72825\t49 \t14 \t7 \t767 \t1.36011 \n", - "50 \t786 \t0.853658\t50 \t0.916667\t0.166667 \t786 \t0.130737\t7.68426\t50 \t15 \t7 \t786 \t1.35888 \n", - "51 \t763 \t0.857201\t51 \t0.916667\t0.166667 \t763 \t0.130651\t7.62659\t51 \t14 \t7 \t763 \t1.26048 \n", - "52 \t753 \t0.845349\t52 \t0.916667\t0.166667 \t753 \t0.141511\t7.57869\t52 \t13 \t6 \t753 \t1.21469 \n", - "53 \t755 \t0.84987 \t53 \t0.916667\t0.125 \t755 \t0.149051\t7.64712\t53 \t14 \t6 \t755 \t1.24968 \n", - "54 \t754 \t0.853006\t54 \t0.916667\t0.166667 \t754 \t0.133851\t7.71848\t54 \t16 \t6 \t754 \t1.35452 \n", - "55 \t748 \t0.856264\t55 \t0.916667\t0.166667 \t748 \t0.129953\t7.71457\t55 \t16 \t7 \t748 \t1.37803 \n", - "56 \t764 \t0.849544\t56 \t0.916667\t0.166667 \t764 \t0.133725\t7.84164\t56 \t16 \t7 \t764 \t1.52987 \n", - "57 \t773 \t0.854187\t57 \t0.916667\t0.166667 \t773 \t0.132287\t7.84066\t57 \t16 \t7 \t773 \t1.50237 \n", - "58 \t764 \t0.851662\t58 \t0.916667\t0.125 \t764 \t0.137114\t7.75073\t58 \t15 \t6 \t764 \t1.36477 \n", - "59 \t774 \t0.858382\t59 \t0.916667\t0.166667 \t774 \t0.118535\t7.77615\t59 \t13 \t6 \t774 \t1.33592 \n", - "60 \t764 \t0.857486\t60 \t0.916667\t0.125 \t764 \t0.125124\t7.77615\t60 \t14 \t6 \t764 \t1.36344 \n", - "61 \t787 \t0.849625\t61 \t0.916667\t0.125 \t787 \t0.140293\t7.83284\t61 \t14 \t7 \t787 \t1.42331 \n", - "62 \t753 \t0.853658\t62 \t0.916667\t0.125 \t753 \t0.133741\t7.79863\t62 \t13 \t7 \t753 \t1.40885 \n", - "63 \t778 \t0.858382\t63 \t0.916667\t0.166667 \t778 \t0.125083\t7.78104\t63 \t17 \t6 \t778 \t1.46345 \n", - "64 \t770 \t0.852354\t64 \t0.916667\t0.166667 \t770 \t0.128381\t7.72239\t64 \t17 \t6 \t770 \t1.43247 \n", - "65 \t768 
\t0.851906\t65 \t0.916667\t0.166667 \t768 \t0.133367\t7.69892\t65 \t17 \t6 \t768 \t1.38603 \n", - "66 \t768 \t0.856386\t66 \t0.916667\t0.166667 \t768 \t0.127712\t7.71065\t66 \t15 \t6 \t768 \t1.41779 \n", - "67 \t762 \t0.849381\t67 \t0.916667\t0.166667 \t762 \t0.140236\t7.74976\t67 \t15 \t6 \t762 \t1.4375 \n", - "68 \t762 \t0.845715\t68 \t0.916667\t0.125 \t762 \t0.136616\t7.74976\t68 \t15 \t6 \t762 \t1.42795 \n", - "69 \t787 \t0.853087\t69 \t0.916667\t0.208333 \t787 \t0.126606\t7.76637\t69 \t15 \t6 \t787 \t1.43179 \n", - "70 \t805 \t0.852558\t70 \t0.916667\t0.125 \t805 \t0.136059\t7.7478 \t70 \t16 \t6 \t805 \t1.45608 \n", - "71 \t790 \t0.856672\t71 \t0.916667\t0.125 \t790 \t0.127667\t7.71359\t71 \t15 \t6 \t790 \t1.3979 \n", - "72 \t775 \t0.855531\t72 \t0.916667\t0.166667 \t775 \t0.131379\t7.73216\t72 \t15 \t7 \t775 \t1.36232 \n", - "73 \t776 \t0.856712\t73 \t0.916667\t0.166667 \t776 \t0.116737\t7.81525\t73 \t15 \t7 \t776 \t1.44433 \n", - "74 \t759 \t0.856427\t74 \t0.916667\t0.166667 \t759 \t0.125985\t7.76637\t74 \t14 \t7 \t759 \t1.42426 \n", - "75 \t776 \t0.846815\t75 \t0.916667\t0.0833333\t776 \t0.143878\t7.73998\t75 \t15 \t7 \t776 \t1.41519 \n", - "76 \t750 \t0.853128\t76 \t0.916667\t0.166667 \t750 \t0.13011 \t7.66862\t76 \t15 \t6 \t750 \t1.39706 \n", - "77 \t751 \t0.851499\t77 \t0.916667\t0.166667 \t751 \t0.135079\t7.71163\t77 \t15 \t6 \t751 \t1.48565 \n", - "78 \t797 \t0.851825\t78 \t0.916667\t0.166667 \t797 \t0.128887\t7.70577\t78 \t16 \t7 \t797 \t1.52606 \n", - "79 \t773 \t0.8482 \t79 \t0.916667\t0.166667 \t773 \t0.1359 \t7.67644\t79 \t17 \t7 \t773 \t1.51109 \n", - "80 \t760 \t0.850399\t80 \t0.916667\t0.125 \t760 \t0.1324 \t7.75562\t80 \t15 \t7 \t760 \t1.56167 \n", - "81 \t779 \t0.853087\t81 \t0.916667\t0.125 \t779 \t0.131381\t7.65787\t81 \t15 \t6 \t779 \t1.4212 \n", - "82 \t763 \t0.846978\t82 \t0.916667\t0.166667 \t763 \t0.137998\t7.71457\t82 \t16 \t6 \t763 \t1.52549 \n", - "83 \t784 \t0.846489\t83 \t0.916667\t0.166667 \t784 \t0.141817\t7.71163\t83 \t18 \t7 
\t784 \t1.54816 \n", - "84 \t767 \t0.848851\t84 \t0.916667\t0.166667 \t767 \t0.135063\t7.65005\t84 \t18 \t7 \t767 \t1.46738 \n", - "85 \t785 \t0.851417\t85 \t0.916667\t0.166667 \t785 \t0.133789\t7.67155\t85 \t18 \t7 \t785 \t1.43058 \n", - "86 \t766 \t0.846489\t86 \t0.916667\t0.166667 \t766 \t0.140977\t7.71359\t86 \t15 \t7 \t766 \t1.44332 \n", - "87 \t776 \t0.856753\t87 \t0.916667\t0.0833333\t776 \t0.125534\t7.74878\t87 \t15 \t7 \t776 \t1.47526 \n", - "88 \t771 \t0.857364\t88 \t0.916667\t0.166667 \t771 \t0.124433\t7.73118\t88 \t15 \t6 \t771 \t1.43857 \n", - "89 \t773 \t0.849951\t89 \t0.916667\t0.166667 \t773 \t0.136948\t7.7869 \t89 \t18 \t7 \t773 \t1.51162 \n", - "90 \t772 \t0.853454\t90 \t0.916667\t0.166667 \t772 \t0.128445\t7.81623\t90 \t17 \t7 \t772 \t1.55745 \n", - "91 \t775 \t0.846896\t91 \t0.916667\t0.0833333\t775 \t0.138705\t7.79081\t91 \t16 \t7 \t775 \t1.47751 \n", - "92 \t769 \t0.851662\t92 \t0.916667\t0.125 \t769 \t0.133197\t7.83578\t92 \t17 \t7 \t769 \t1.53818 \n", - "93 \t795 \t0.852313\t93 \t0.916667\t0.166667 \t795 \t0.130009\t7.91105\t93 \t15 \t6 \t795 \t1.61581 \n", - "94 \t776 \t0.848444\t94 \t0.916667\t0.166667 \t776 \t0.133301\t7.85924\t94 \t16 \t6 \t776 \t1.57004 \n", - "95 \t764 \t0.862455\t95 \t0.916667\t0.166667 \t764 \t0.115944\t7.93744\t95 \t14 \t7 \t764 \t1.61251 \n", - "96 \t779 \t0.849707\t96 \t0.916667\t0.125 \t779 \t0.133158\t7.88759\t96 \t14 \t7 \t779 \t1.60097 \n", - "97 \t768 \t0.852599\t97 \t0.916667\t0.125 \t768 \t0.138679\t7.83382\t97 \t17 \t7 \t768 \t1.55189 \n", - "98 \t760 \t0.852639\t98 \t0.916667\t0.125 \t760 \t0.137228\t7.88563\t98 \t17 \t7 \t760 \t1.58365 \n", - "99 \t765 \t0.860011\t99 \t0.916667\t0.166667 \t765 \t0.120624\t7.81036\t99 \t17 \t5 \t765 \t1.45752 \n", - "100\t794 \t0.847181\t100\t0.916667\t0.166667 \t794 \t0.139415\t7.76833\t100\t14 \t7 \t794 \t1.3102 \n" - ] - } - ], - "source": [ - "# DEBUG : REMOVE THIS !!!\n", - "import warnings\n", - "warnings.filterwarnings('ignore')\n", - "\n", - "def xmate(ind1, 
ind2):\n", - " \"\"\" Reproduction operator for multi-tree GP, where trees are represented as a list.\n", - "\n", - " Crossover happens to a subtree that is selected at random.\n", - " Crossover operations are limited to parents from the same tree.\n", - "\n", - " FIXME: Have to compile the trees (manually), which is frustrating.\n", - "\n", - " Args:\n", - " ind1 (Individual): The first parent.\n", - " ind2 (Individual): The second parent\n", - "\n", - " Returns:\n", - " ind1, ind2 (Individual, Individual): The children from the parents reproduction.\n", - " \"\"\"\n", - " n = range(len(ind1))\n", - " selected_tree_idx = random.choice(n)\n", - " for tree_idx in n:\n", - " g1, g2 = gp.PrimitiveTree(ind1[tree_idx]), gp.PrimitiveTree(ind2[tree_idx])\n", - " if tree_idx == selected_tree_idx:\n", - " ind1[tree_idx], ind2[tree_idx] = gp.cxOnePoint(g1, g2)\n", - " else:\n", - " ind1[tree_idx], ind2[tree_idx] = g1, g2\n", - " return ind1, ind2\n", - "\n", - "\n", - "def xmut(ind, expr):\n", - " \"\"\" Mutation operator for multi-tree GP, where trees are represented as a list.\n", - "\n", - " Mutation happens to a tree selected at random, when an individual is selected for crossover.\n", - "\n", - " FIXME: Have to compile the trees (manually), which is frustrating.\n", - "\n", - " Args:\n", - " ind: The individual, a list of GP trees.\n", - " \"\"\"\n", - " n = range(len(ind))\n", - " selected_tree_idx = random.choice(n)\n", - " for tree_idx in n:\n", - " g1 = gp.PrimitiveTree(ind[tree_idx])\n", - " if tree_idx == selected_tree_idx:\n", - " indx = gp.mutUniform(g1, expr, pset)\n", - " ind[tree_idx] = indx[0]\n", - " else:\n", - " ind[tree_idx] = g1\n", - " return ind,\n", - "\n", - "\n", - "def evaluate_classification(individual, alpha = 0.9, verbose=False):\n", - " \"\"\"\n", - " Evalautes the fitness of an individual for multi-tree GP multi-class classification.\n", - "\n", - " We maxmimize the fitness when we evaluate the accuracy + regularization term.\n", - "\n", - " 
Args:\n", - " individual (Individual): A candidate solution to be evaluated.\n", - " alpha (float): A parameter that balances the accuracy and regularization term. Defaults to 0.98.\n", - "\n", - " Returns:\n", - " accuracy (tuple): The fitness of the individual.\n", - " \"\"\"\n", - " features = toolbox.compile(expr=individual, pset=pset)\n", - " fitness = wrapper_classification_accuracy(features, verbose=verbose)\n", - " return fitness,\n", - "\n", - "\n", - "toolbox.register('evaluate', evaluate_classification)\n", - "toolbox.register(\"select\", tools.selTournament, tournsize=7)\n", - "toolbox.register(\"mate\", xmate)\n", - "toolbox.register(\"expr_mut\", gp.genFull, min_=0, max_=2)\n", - "toolbox.register(\"mutate\", xmut, expr=toolbox.expr_mut)\n", - "\n", - "\n", - "def staticLimit(key, max_value):\n", - " \"\"\"\n", - " A variation of gp.staticLimit that works for Multi-tree representation.\n", - " This works for our altered xmut and xmate genetic operators for mutli-tree GP.\n", - " If tree depth limit is exceeded, the genetic operator is reverted.\n", - "\n", - " When an invalid (over the limit) child is generated,\n", - " it is simply replaced by one of its parents, randomly selected.\n", - "\n", - " Args:\n", - " key: The function to use in order the get the wanted value. For\n", - " instance, on a GP tree, ``operator.attrgetter('height')`` may\n", - " be used to set a depth limit, and ``len`` to set a size limit.\n", - " max_value: The maximum value allowed for the given measurement.\n", - " Defaults to 17, the suggested value in (Koza 1992)\n", - "\n", - " Returns:\n", - " A decorator that can be applied to a GP operator using \\\n", - " :func:`~deap.base.Toolbox.decorate`\n", - "\n", - " References:\n", - " 1. Koza, J. R. G. P. (1992). On the programming of computers by means\n", - " of natural selection. 
Genetic programming.\n", - " \"\"\"\n", - "\n", - " def decorator(func):\n", - " @wraps(func)\n", - " def wrapper(*args, **kwargs):\n", - " keep_inds = [[copy.deepcopy(tree) for tree in ind] for ind in args]\n", - " new_inds = list(func(*args, **kwargs))\n", - " for ind_idx, ind in enumerate(new_inds):\n", - " for tree_idx, tree in enumerate(ind):\n", - " if key(tree) > max_value:\n", - " random_parent = random.choice(keep_inds)\n", - " new_inds[ind_idx][tree_idx] = random_parent[tree_idx]\n", - " return new_inds\n", - " return wrapper\n", - " return decorator\n", - "\n", - "# See https://groups.google.com/g/deap-users/c/pWzR_q7mKJ0\n", - "toolbox.decorate(\"mate\", staticLimit(key=operator.attrgetter(\"height\"), max_value=8))\n", - "toolbox.decorate(\"mutate\", staticLimit(key=operator.attrgetter(\"height\"), max_value=8))\n", - "\n", - "\n", - "def SimpleGPWithElitism(population, toolbox, cxpb, mutpb, ngen, stats=None,\n", - " halloffame=None, verbose=__debug__):\n", - " \"\"\"\n", - " Elitism for Multi-Tree GP for Multi-Class classification.\n", - " A variation of the eaSimple method from the DEAP library that supports\n", - "\n", - " Elitism ensures the best individuals (the elite) from each generation are\n", - " carried onto the next without alteration. 
This ensures the quality of the\n", - " best solution monotonically increases over time.\n", - " \"\"\"\n", - " logbook = tools.Logbook()\n", - " logbook.header = ['gen', 'nevals'] + (stats.fields if stats else [])\n", - "\n", - " invalid_ind = [ind for ind in population if not ind.fitness.valid]\n", - " fitnesses = toolbox.map(toolbox.evaluate, invalid_ind)\n", - "\n", - " for ind, fit in zip(invalid_ind, fitnesses):\n", - " ind.fitness.values = fit\n", - "\n", - " if halloffame is None:\n", - " raise ValueError(\"halloffame parameter must not be empty!\")\n", - "\n", - " halloffame.update(population)\n", - " hof_size = len(halloffame.items) if halloffame.items else 0\n", - "\n", - " record = stats.compile(population) if stats else {}\n", - " logbook.record(gen=0, nevals=len(invalid_ind), **record)\n", - "\n", - " if verbose:\n", - " print(logbook.stream)\n", - "\n", - " for gen in range(1, ngen + 1):\n", - " offspring = toolbox.select(population, len(population) - hof_size)\n", - " offspring = algorithms.varAnd(offspring, toolbox, cxpb, mutpb)\n", - " invalid_ind = [ind for ind in offspring if not ind.fitness.valid]\n", - " fitnesses = toolbox.map(toolbox.evaluate, invalid_ind)\n", - "\n", - " for ind, fit in zip(invalid_ind, fitnesses):\n", - " ind.fitness.values = fit\n", - "\n", - " offspring.extend(halloffame.items)\n", - " halloffame.update(offspring)\n", - " population[:] = offspring\n", - "\n", - " record = stats.compile(population) if stats else {}\n", - " logbook.record(gen=gen, nevals=len(invalid_ind), **record)\n", - "\n", - " if verbose:\n", - " print(logbook.stream)\n", - "\n", - " return population, logbook\n", - "\n", - "\n", - "def train(generations=100, population=100, elitism=0.1, crossover_rate=0.5, mutation_rate=0.1):\n", - " \"\"\"\n", - " This is a Multi-tree GP with Elitism for Multi-class classification.\n", - "\n", - " Args:\n", - " generations: The number of generations to evolve the populaiton for.\n", - " elitism: The ratio of elites 
to be kept between generations.\n", - " crossover_rate: The probability of a crossover between two individuals.\n", - " mutation_rate: The probability of a random mutation within an individual.\n", - "\n", - " Returns:\n", - " pop: The final population the algorithm has evolved.\n", - " log: The logbook which can record important statistics.\n", - " hof: The hall of fame contains the best individual solutions.\n", - " \"\"\"\n", - " random.seed(420)\n", - " pop = toolbox.population(n=population)\n", - "\n", - " mu = round(elitism * population)\n", - " if elitism > 0:\n", - " # See https://www.programcreek.com/python/example/107757/deap.tools.HallOfFame\n", - " hof = tools.HallOfFame(mu)\n", - " else:\n", - " hof = None\n", - "\n", - " stats_fit = tools.Statistics(lambda ind: ind.fitness.values)\n", - " length = lambda a: np.max(list(map(len, a)))\n", - " stats_size = tools.Statistics(length)\n", - "\n", - " mstats = tools.MultiStatistics(fitness=stats_fit, size=stats_size)\n", - " mstats.register(\"avg\", np.mean)\n", - " mstats.register(\"std\", np.std)\n", - " mstats.register(\"min\", np.min)\n", - " mstats.register(\"max\", np.max)\n", - "\n", - " pop, log = SimpleGPWithElitism(pop, toolbox, crossover_rate, mutation_rate,\n", - " generations, stats=mstats, halloffame=hof,\n", - " verbose=True)\n", - " return pop, log, hof\n", - "\n", - "\n", - "\"\"\"\n", - "DeJong (1975), p=50-100, m=0.001, c=0.6\n", - "Grefenstette (1986), p=30, m=0.01, c=0.95\n", - "Schaffer et al., (1989), p=20-30, m=0.005-0.01, c=0.75-0.95\n", - "\n", - "References:\n", - " 1. Patil, V. P., & Pawar, D. D. (2015). The optimal crossover or mutation\n", - " rates in genetic algorithm: a review. 
International Journal of Applied\n", - " Engineering and Technology, 5(3), 38-41.\n", - "\"\"\"\n", - "\n", - "beta = 1\n", - "population = n_features * beta\n", - "generations = 100\n", - "elitism = 0.1\n", - "crossover_rate = 0.8\n", - "mutation_rate = 0.2\n", - "\n", - "assert crossover_rate + mutation_rate == 1, \"Crossover and mutation sums to 1 (to please the Gods!)\"\n", - "\n", - "pop, log, hof = train(generations, population, elitism, crossover_rate, mutation_rate)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": { - "id": "VSEgttwuilPy", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "29d1754e-3ad2-44c8-ff9b-1fd4b5f37c78" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.16666666666666666\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.16666666666666666\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.3333333333333333\n", - "Train accuracy: 0.9166666666666666, Validation accuracy: 0.75, Test accuracy: 0.3333333333333333\n", - "Train accuracy: 
1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 0.9166666666666666, Validation accuracy: 0.4166666666666667, Test accuracy: 0.16666666666666666\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.16666666666666666\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.16666666666666666\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, 
Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.16666666666666666\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 0.9166666666666666, Validation accuracy: 0.5, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.16666666666666666\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.16666666666666666\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test 
accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.16666666666666666\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.16666666666666666\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 0.9166666666666666, Validation accuracy: 0.75, Test accuracy: 0.3333333333333333\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test 
accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.16666666666666666\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.16666666666666666\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.8333333333333334, Test accuracy: 0.25\n", - "Train accuracy: 1.0, Validation accuracy: 0.75, Test accuracy: 0.25\n" - ] - } - ], - "source": [ - "for i in range(len(hof)):\n", - " evaluate_classification(hof[i], verbose=True)" - ] - }, - { - "cell_type": "code", - "source": [ - "X_train, X_temp, y_train, y_temp = train_test_split(X, y, stratify=y, test_size=0.6666, random_state=42)\n", - "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, stratify=y_temp, test_size=0.5, 
random_state=42)\n", - "\n", - "# print(f\"np.unique(y_train): {np.unique(y_train)}\\n np.unique(y_val): {np.unique(y_val)} \\n np.unique(y_test): {np.unique(y_test)}\")\n", - "for (name, ds) in [(\"train\", y_train), (\"val\", y_val), (\"test\", y_test)]:\n", - " print(f\"{name}: {np.unique(ds)}\")\n", - "\n", - " classes, class_counts = np.unique(ds, axis=0, return_counts=True)\n", - " n_features = X.shape[1]\n", - " n_instances = X.shape[0]\n", - " n_classes = len(np.unique(ds, axis=0))\n", - " class_ratios = np.array(class_counts) / n_instances\n", - "\n", - " print(f\"Class Counts: {class_counts}, Class Ratios: {class_ratios}\")\n", - " print(f\"Number of features: {n_features}\\nNumber of instances: {n_instances}\\nNumber of classes {n_classes}.\")" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "SZrI3MjywK_K", - "outputId": "64163f7b-145a-4399-e1fb-6588eb6965db" - }, - "execution_count": 45, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "train: [0. 1. 2. 3. 4. 5.]\n", - "Class Counts: [2 2 1 2 2 1], Class Ratios: [0.06666667 0.06666667 0.03333333 0.06666667 0.06666667 0.03333333]\n", - "Number of features: 1023\n", - "Number of instances: 30\n", - "Number of classes 6.\n", - "val: [0. 1. 2. 3. 4. 5.]\n", - "Class Counts: [2 2 1 2 2 1], Class Ratios: [0.06666667 0.06666667 0.03333333 0.06666667 0.06666667 0.03333333]\n", - "Number of features: 1023\n", - "Number of instances: 30\n", - "Number of classes 6.\n", - "test: [0. 1. 2. 3. 4. 
5.]\n", - "Class Counts: [2 2 1 2 2 1], Class Ratios: [0.06666667 0.06666667 0.03333333 0.06666667 0.06666667 0.03333333]\n", - "Number of features: 1023\n", - "Number of instances: 30\n", - "Number of classes 6.\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YkaU6kAGyBl2" - }, - "source": [ - "## Visualization" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": { - "id": "pEqouLDs4xBJ" - }, - "outputs": [], - "source": [ - "from deap import base, creator, gp\n", - "import pygraphviz as pgv\n", - "\n", - "multi_tree = hof[0]\n", - "for t_idx,tree in enumerate(multi_tree):\n", - " nodes, edges, labels = gp.graph(tree)\n", - "\n", - " g = pgv.AGraph()\n", - " g.add_nodes_from(nodes)\n", - " g.add_edges_from(edges)\n", - " g.layout(prog=\"dot\")\n", - "\n", - " for i in nodes:\n", - " n = g.get_node(i)\n", - " n.attr[\"label\"] = labels[i]\n", - "\n", - " g.draw(f\"tree-{t_idx}.pdf\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QX_g4MipuQTU" - }, - "source": [ - "## Changelog\n", - "\n", - "| Date | Title | Description | Update |\n", - "| --- | --- | --- | ---- |\n", - " 2024-02-14 16:56 | Mass spectra | Applications for MT-GP for rapid mass spectrometry dataset. | |\n", - "| 2022-08-21 17:30 | Multi-Objective - Onehot Encoding | Change to multi-objective problem, one-vs-all with a tree classifier for each class.
Y labels are encoded as one-hot vectors; the error is the absolute difference $\\lvert \\hat{y} - y \\rvert$.|\n", - "| 2022-08-22 20:44 | Non-linearity | Introduce $\\mathrm{round} \\circ \\mathrm{sigmoid}$ to the evaluate_classification() method.
Previously, we pushed each tree to predict either a 0 or a 1 value under the one-hot encoding representation.
Now, the non-linearity will map any negative value to a negative class 0, and any positive value to positive class 1.|\n", - "| 2022-08-22 21:06 | ~~Genetic operators for tree with worst fitness~~ | Only apply the genetic operators, crossover and mutation, to the tree with the worst fitness.
This guarantees monotonic improvement for the Multi-tree between generations, since the best-performing trees remain unaltered.| (Update) This was very slow and inefficient,
basically turned the GP into a single objective,
that balances multi-objective fitness functions. |\n", - "| 2022-08-22 21:15 | Halloffame Equality Operator | The default equality function (operator.eq) applied to two NumPy arrays returns the equality element-wise,
which raises an exception in the if similar() check of the hall of fame.
Using a different equality function like numpy.array_equal or numpy.allclose solves this issue.|\n", - "| 2022-08-22 23:22 | Elitism as aggregate best tree | Perform elitism by constructing the best tree, as the tree with best fitness from each class.
The goal is to have monotonic improvement across the multiple objective functions.|\n", - "| 2022-08-22 23:32 | Update fitness for elite | The elitism was not working as intended, as the multi-objectives didn't appear to increase monotonically.
This was because the aggregate fitness was not being assigned to the best individual after it was created.
Therefore the best individual was not passed on to the next generation. |\n", - "| 2022-08-22 02:28 | staticLimit max height | Rewrite the gp.staticLimit decorator function to handle the Multi-tree representation.
Note: size and depth are different values!
depth is the maximum length of a root-to-leaf traversal,
size is the total number of nodes.|\n", - "| 2022-08-24 9:37 | Evaluate Multi-tree train accuracy | Take the classification accuracy as the argmax for the aggregate multi-tree.
74% training accuracy, which is not ideal, but this shall improve with time.|\n", - "| 2022-08-25 13:30 | Single-objective fitness | Change the fitness function to a single objective fitness function.
This forces the multi-tree GP to find the best tree subset for one-vs-rest classification performance.|\n", - "| 2022-08-25 20:01 | Fitness = Balanced accuracy + distance measure | Implement the fitness function for MCIFC, but for multi-class classification from (Tran 2019) |\n", - "| 2022-08-26 21:27 | Sklearn Balanced Accuracy | Changed to the balanced accuracy metric from sklearn.
This is much easier to use for now, probably faster than the previous method as well. |\n", - "| 2022-09-05 17:00 | Reject invalid predictions | Change the fitness function to reject invalid predictions outright -
e.g. multi-label or zero-label predictions
- when computing the balanced accuracy for the fitness function. |\n", - "| 2022-09-13 19:00 | Mutation + Crossover = 100% | Ensure the mutation and crossover rates sum to 100%,
not necessary with deap, but good to avoid conference questions |\n", - "| 2022-09-13 21:00 | Feature Construction | Changed to Wrapper-based Feature Construction with Multi-tree GP.|\n", - "| 2022-09-13 21:34 | $m = r \\times c$ | Add more trees, following example from (Tran 2019).
With 8 trees for a multi-class classification
$m = r \\times c = 8$ trees, where number of classes $c = 4$, and reconstruction ratio $r = 2$. |\n", - "| 2022-09-30 19:49 | Quick Evaluate | Manually parse the GP trees, 5x speedup for DEAP in Python (Zhang 2022). |\n", - "| 2022-10-13 6:02 | Ignore timestamps after 4500 | Ignoring timestamps after 4500 did not improve accuracy for the SVM classifier.
So the bizarre pattern that occurs in the GC-MS image there carries important information.
Should investigate this further, perhaps ask Daniel as he is a domain expert.|\n", - "| 2022-10-13 22:16 | Cross validation | Evaluate the mean balanced classification accuracy over stratified k-fold cross validation.|\n", - "| 2023-01-13 20:58 | 2x speedup | Only evaluate the test set in the verbose variant of the evaluate_classification method.
This results in a 2x speedup in the efficiency of the training regime.|" - ] - } - ], - "metadata": { - "colab": { - "machine_shape": "hm", - "provenance": [], - "mount_file_id": "1Yg4t38NHSYPAlu_099cQeOR-qSwIaaTl", - "authorship_tag": "ABX9TyPs3TrIMW2i9Riw1qhKkO/N", - "include_colab_link": true - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file