diff --git a/WIP/4 - Gated_PixelCNN/gated_pixelCNN.py b/WIP/4 - Gated_PixelCNN/gated_pixelCNN.py
new file mode 100644
index 0000000..5a381fb
--- /dev/null
+++ b/WIP/4 - Gated_PixelCNN/gated_pixelCNN.py
@@ -0,0 +1,8 @@
+import random as rn
+import time
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
\ No newline at end of file
diff --git a/WIP/4 - Gated_PixelCNN/gated_pixelcnn.ipynb b/WIP/4 - Gated_PixelCNN/gated_pixelcnn.ipynb
index 9fe85e6..8fedab5 100644
--- a/WIP/4 - Gated_PixelCNN/gated_pixelcnn.ipynb
+++ b/WIP/4 - Gated_PixelCNN/gated_pixelcnn.ipynb
@@ -1,604 +1,602 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "view-in-github"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "accelerator": "GPU",
"colab": {
- "base_uri": "https://localhost:8080/",
- "height": 402
+ "name": "gated_pixelcnn.ipynb",
+ "provenance": []
},
- "colab_type": "code",
- "id": "VMF0k8kyGjQ7",
- "outputId": "489e26d2-3f1f-4488-92b6-d0f2801e7a2d"
- },
- "outputs": [],
- "source": [
- "! pip3 install tensorflow-gpu==2.0.0-rc1"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "k1uZnxh4Xz9Z"
- },
- "outputs": [],
- "source": [
- "import random as rn\n",
- "import time\n",
- "\n",
- "import matplotlib\n",
- "import matplotlib.pyplot as plt\n",
- "import numpy as np\n",
- "import tensorflow as tf\n",
- "from tensorflow import keras"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'2.0.0'"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
}
- ],
- "source": [
- "tf.__version__"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "NN6vJl7eVnZ4"
- },
- "outputs": [],
- "source": [
- "# --------------------------------------------------------------------------------------------------------------\n",
- "# Defining random seeds\n",
- "random_seed = 42\n",
- "tf.random.set_seed(random_seed)\n",
- "np.random.seed(random_seed)\n",
- "rn.seed(random_seed)"
- ]
},
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "8BnkhgCjVpJu"
- },
- "outputs": [],
- "source": [
- "# --------------------------------------------------------------------------------------------------------------\n",
- "# Loading data\n",
- "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
- "\n",
- "height = 28\n",
- "width = 28\n",
- "n_channel = 1\n",
- "\n",
- "x_train = x_train.astype('float32') / 255.\n",
- "x_test = x_test.astype('float32') / 255.\n",
- "\n",
- "x_train = x_train.reshape(x_train.shape[0], height, width, n_channel)\n",
- "x_test = x_test.reshape(x_test.shape[0], height, width, n_channel)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "3ne-qY7JVZaB"
- },
- "outputs": [],
- "source": [
- "def quantise(images, q_levels):\n",
- " \"\"\"Quantise image into q levels\"\"\"\n",
- " return (np.digitize(images, np.arange(q_levels) / q_levels) - 1).astype('float32')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "3QVhnMymVrzc"
- },
- "outputs": [],
- "source": [
- "# --------------------------------------------------------------------------------------------------------------\n",
- "# Quantise the input data in q levels\n",
- "q_levels = 256\n",
- "x_train_quantised = quantise(x_train, q_levels)\n",
- "x_test_quantised = quantise(x_test, q_levels)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 89
+ "cells": [
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "k1uZnxh4Xz9Z",
+ "colab": {}
+ },
+ "source": [
+ "import random as rn\n",
+ "import time\n",
+ "\n",
+ "import matplotlib\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import tensorflow as tf\n",
+ "from tensorflow.keras.utils import Progbar\n",
+ "from tensorflow import keras\n",
+ "from tensorflow.keras import initializers\n",
+ "from tensorflow import nn"
+ ],
+ "execution_count": 0,
+ "outputs": []
},
- "colab_type": "code",
- "id": "ZObIXqzNGwmo",
- "outputId": "64cdf3d5-2f5e-4e48-9d6b-e8844e1efcb3"
- },
- "outputs": [],
- "source": [
- "# --------------------------------------------------------------------------------------------------------------\n",
- "# Creating input stream using tf.data API\n",
- "batch_size = 128\n",
- "train_buf = 60000\n",
- "\n",
- "train_dataset = tf.data.Dataset.from_tensor_slices((x_train_quantised / (q_levels - 1),\n",
- " x_train_quantised.astype('int32')))\n",
- "train_dataset = train_dataset.shuffle(buffer_size=train_buf)\n",
- "train_dataset = train_dataset.batch(batch_size)\n",
- "\n",
- "test_dataset = tf.data.Dataset.from_tensor_slices((x_test_quantised / (q_levels - 1),\n",
- " x_test_quantised.astype('int32')))\n",
- "test_dataset = test_dataset.batch(batch_size)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "75VTDkK8VZLA"
- },
- "outputs": [],
- "source": [
- "class MaskedConv2D(tf.keras.layers.Layer):\n",
- " \"\"\"Convolutional layers with masks for autoregressive models\n",
- "\n",
- " Convolutional layers with simple implementation to have masks type A and B.\n",
- " \"\"\"\n",
- "\n",
- " def __init__(self,\n",
- " mask_type,\n",
- " filters,\n",
- " kernel_size,\n",
- " strides=1,\n",
- " padding='same',\n",
- " kernel_initializer='glorot_uniform',\n",
- " bias_initializer='zeros'):\n",
- " super(MaskedConv2D, self).__init__()\n",
- "\n",
- " assert mask_type in {'A', 'B', 'V'}\n",
- " self.mask_type = mask_type\n",
- "\n",
- " self.filters = filters\n",
- "\n",
- " if isinstance(kernel_size, int):\n",
- " kernel_size = (kernel_size, kernel_size)\n",
- " self.kernel_size = kernel_size\n",
- "\n",
- " self.strides = strides\n",
- " self.padding = padding.upper()\n",
- " self.kernel_initializer = keras.initializers.get(kernel_initializer)\n",
- " self.bias_initializer = keras.initializers.get(bias_initializer)\n",
- "\n",
- " def build(self, input_shape):\n",
- " kernel_h, kernel_w = self.kernel_size\n",
- "\n",
- " self.kernel = self.add_weight(\"kernel\",\n",
- " shape=(kernel_h,\n",
- " kernel_w,\n",
- " int(input_shape[-1]),\n",
- " self.filters),\n",
- " initializer=self.kernel_initializer,\n",
- " trainable=True)\n",
- "\n",
- " self.bias = self.add_weight(\"bias\",\n",
- " shape=(self.filters,),\n",
- " initializer=self.bias_initializer,\n",
- " trainable=True)\n",
- "\n",
- " mask = np.ones(self.kernel.shape, dtype=np.float32)\n",
- " if self.mask_type == 'V':\n",
- " mask[kernel_h // 2:, :, :, :] = 0.\n",
- " else:\n",
- " mask[kernel_h // 2, kernel_w // 2 + (self.mask_type == 'B'):, :, :] = 0.\n",
- " mask[kernel_h // 2 + 1:, :, :] = 0.\n",
- "\n",
- " self.mask = tf.constant(mask, dtype=tf.float32, name='mask')\n",
- "\n",
- " def call(self, input):\n",
- " masked_kernel = tf.math.multiply(self.mask, self.kernel)\n",
- " x = tf.nn.conv2d(input, masked_kernel, strides=[1, self.strides, self.strides, 1], padding=self.padding)\n",
- " x = tf.nn.bias_add(x, self.bias)\n",
- " return x\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "PTUN4s52Nu3w"
- },
- "outputs": [],
- "source": [
- "class GatedBlock(tf.keras.Model):\n",
- " \"\"\"\"\"\"\n",
- "\n",
- " def __init__(self, mask_type, filters, kernel_size):\n",
- " super(GatedBlock, self).__init__(name='')\n",
- "\n",
- " self.mask_type = mask_type\n",
- " self.vertical_conv = MaskedConv2D(mask_type='V', filters=2 * filters, kernel_size=kernel_size)\n",
- " self.horizontal_conv = MaskedConv2D(mask_type=mask_type, filters=2 * filters, kernel_size=(1, kernel_size))\n",
- " self.v_to_h_conv = keras.layers.Conv2D(filters=2 * filters, kernel_size=1)\n",
- "\n",
- " self.horizontal_output = keras.layers.Conv2D(filters=filters, kernel_size=1)\n",
- "\n",
- " def _gate(self, x):\n",
- " tanh_preactivation, sigmoid_preactivation = tf.split(x, 2, axis=-1)\n",
- " return tf.nn.tanh(tanh_preactivation) * tf.nn.sigmoid(sigmoid_preactivation)\n",
- "\n",
- " def call(self, input_tensor):\n",
- " v = input_tensor[0]\n",
- " h = input_tensor[1]\n",
- "\n",
- " horizontal_preactivation = self.horizontal_conv(h) # 1xN\n",
- " vertical_preactivation = self.vertical_conv(v) # NxN\n",
- " v_to_h = self.v_to_h_conv(vertical_preactivation) # 1x1\n",
- " v_out = self._gate(vertical_preactivation)\n",
- "\n",
- " horizontal_preactivation = horizontal_preactivation + v_to_h\n",
- " h_activated = self._gate(horizontal_preactivation)\n",
- " h_activated = self.horizontal_output(h_activated)\n",
- "\n",
- " if self.mask_type == 'A':\n",
- " h_out = h_activated\n",
- " elif self.mask_type == 'B':\n",
- " h_out = h + h_activated\n",
- "\n",
- " return v_out, h_out"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "WB57YufrVxn2"
- },
- "outputs": [],
- "source": [
- "# --------------------------------------------------------------------------------------------------------------\n",
- "# Create PixelCNN model\n",
- "inputs = keras.layers.Input(shape=(height, width, n_channel))\n",
- "v, h = GatedBlock(mask_type='A', filters=64, kernel_size=3)([inputs, inputs])\n",
- "\n",
- "for i in range(7):\n",
- " v, h = GatedBlock(mask_type='B', filters=64, kernel_size=3)([v, h])\n",
- "\n",
- "x = keras.layers.Activation(activation='relu')(h)\n",
- "x = keras.layers.Conv2D(filters=128, kernel_size=1, strides=1)(x)\n",
- "\n",
- "x = keras.layers.Activation(activation='relu')(x)\n",
- "x = keras.layers.Conv2D(filters=n_channel * q_levels, kernel_size=1, strides=1)(x)\n",
- "\n",
- "pixelcnn = tf.keras.Model(inputs=inputs, outputs=x)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "_LnzHUaqV77d"
- },
- "outputs": [],
- "source": [
- "# --------------------------------------------------------------------------------------------------------------\n",
- "# Prepare optimizer and loss function\n",
- "lr_decay = 0.9999\n",
- "learning_rate = 5e-2\n",
- "optimizer = tf.keras.optimizers.Adam(lr=learning_rate)\n",
- "\n",
- "compute_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "CsAgEKVzLCJD"
- },
- "outputs": [],
- "source": [
- "# --------------------------------------------------------------------------------------------------------------\n",
- "@tf.function\n",
- "def train_step(batch_x, batch_y):\n",
- " with tf.GradientTape() as ae_tape:\n",
- " logits = pixelcnn(batch_x, training=True)\n",
- "\n",
- " logits = tf.reshape(logits, [-1, height, width, q_levels, n_channel])\n",
- " logits = tf.transpose(logits, perm=[0, 1, 2, 4, 3])\n",
- "\n",
- " loss = compute_loss(tf.one_hot(batch_y, q_levels), logits)\n",
- "\n",
- " gradients = ae_tape.gradient(loss, pixelcnn.trainable_variables)\n",
- " gradients, _ = tf.clip_by_global_norm(gradients, 1.0)\n",
- " optimizer.apply_gradients(zip(gradients, pixelcnn.trainable_variables))\n",
- "\n",
- " return loss"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 1000
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "NN6vJl7eVnZ4",
+ "colab": {}
+ },
+ "source": [
+ "# Defining random seeds\n",
+ "random_seed = 42\n",
+ "tf.random.set_seed(random_seed)\n",
+ "np.random.seed(random_seed)\n",
+ "rn.seed(random_seed)"
+ ],
+ "execution_count": 0,
+ "outputs": []
},
- "colab_type": "code",
- "id": "NoEPrfwQNM-s",
- "outputId": "cd806f66-acff-48fe-da47-d75a742acf99"
- },
- "outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "EPOCH 0: ITER 0/ 469 TIME: 2.03 LOSS: 0.6414\n",
- "EPOCH 0: ITER 100/ 469 TIME: 0.13 LOSS: 0.9549\n",
- "EPOCH 0: ITER 200/ 469 TIME: 0.13 LOSS: 0.8991\n",
- "EPOCH 0: ITER 300/ 469 TIME: 0.13 LOSS: 0.8518\n",
- "EPOCH 0: ITER 400/ 469 TIME: 0.13 LOSS: 0.8543\n",
- "EPOCH 0: TIME: 64.81 ETA: 12962.69\n",
- "EPOCH 1: ITER 0/ 469 TIME: 0.02 LOSS: 0.9295\n",
- "EPOCH 1: ITER 100/ 469 TIME: 0.13 LOSS: 0.8955\n",
- "EPOCH 1: ITER 200/ 469 TIME: 0.13 LOSS: 0.9063\n",
- "EPOCH 1: ITER 300/ 469 TIME: 0.13 LOSS: 0.8343\n",
- "EPOCH 1: ITER 400/ 469 TIME: 0.13 LOSS: 0.8433\n",
- "EPOCH 1: TIME: 62.16 ETA: 12369.32\n",
- "EPOCH 2: ITER 0/ 469 TIME: 0.02 LOSS: 0.8575\n",
- "EPOCH 2: ITER 100/ 469 TIME: 0.13 LOSS: 0.8659\n",
- "EPOCH 2: ITER 200/ 469 TIME: 0.13 LOSS: 0.8498\n",
- "EPOCH 2: ITER 300/ 469 TIME: 0.13 LOSS: 0.8268\n"
- ]
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "8BnkhgCjVpJu",
+ "colab": {}
+ },
+ "source": [
+ "# Loading data\n",
+ "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
+ "\n",
+ "height = 28\n",
+ "width = 28\n",
+ "n_channel = 1\n",
+ "\n",
+ "x_train = x_train.astype('float32') / 255.\n",
+ "x_test = x_test.astype('float32') / 255.\n",
+ "\n",
+ "x_train = x_train.reshape(x_train.shape[0], height, width, n_channel)\n",
+ "x_test = x_test.reshape(x_test.shape[0], height, width, n_channel)"
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi_iter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mbatch_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_y\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_dataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mstart\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlr\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mlr_decay\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0miter_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/media/kcl_1/HDD/PycharmProjects/vq-vae/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py\u001b[0m in \u001b[0;36m_run_op\u001b[0;34m(a, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1077\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_run_op\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1078\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1079\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtensor_oper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1080\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1081\u001b[0m \u001b[0mfunctools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_wrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_run_op\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_oper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/media/kcl_1/HDD/PycharmProjects/vq-vae/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/math_ops.py\u001b[0m in \u001b[0;36mbinary_op_wrapper\u001b[0;34m(x, y)\u001b[0m\n\u001b[1;32m 910\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 911\u001b[0m \u001b[0;32mraise\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 912\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 913\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 914\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbinary_op_wrapper_sparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msp_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/media/kcl_1/HDD/PycharmProjects/vq-vae/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/math_ops.py\u001b[0m in \u001b[0;36m_mul_dispatch\u001b[0;34m(x, y, name)\u001b[0m\n\u001b[1;32m 1204\u001b[0m \u001b[0mis_tensor_y\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1205\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_tensor_y\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1206\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mgen_math_ops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1207\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1208\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparse_tensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSparseTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Case: Dense * Sparse.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/media/kcl_1/HDD/PycharmProjects/vq-vae/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_math_ops.py\u001b[0m in \u001b[0;36mmul\u001b[0;34m(x, y, name)\u001b[0m\n\u001b[1;32m 6683\u001b[0m _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n\u001b[1;32m 6684\u001b[0m \u001b[0m_ctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_context_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_ctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_thread_local_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Mul\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 6685\u001b[0;31m name, _ctx._post_execution_callbacks, x, y)\n\u001b[0m\u001b[1;32m 6686\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_result\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6687\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0m_core\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_FallbackException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
- ]
- }
- ],
- "source": [
- "# --------------------------------------------------------------------------------------------------------------\n",
- "# Training loop\n",
- "n_epochs = 200\n",
- "n_iter = int(np.ceil(x_train_quantised.shape[0] / batch_size))\n",
- "for epoch in range(n_epochs):\n",
- " start_epoch = time.time()\n",
- " for i_iter, (batch_x, batch_y) in enumerate(train_dataset):\n",
- " start = time.time()\n",
- " optimizer.lr = optimizer.lr * lr_decay\n",
- " loss = train_step(batch_x, batch_y)\n",
- " iter_time = time.time() - start\n",
- " if i_iter % 100 == 0:\n",
- " print('EPOCH {:3d}: ITER {:4d}/{:4d} TIME: {:.2f} LOSS: {:.4f}'.format(epoch,\n",
- " i_iter, n_iter,\n",
- " iter_time,\n",
- " loss))\n",
- " epoch_time = time.time() - start_epoch\n",
- " print('EPOCH {:3d}: TIME: {:.2f} ETA: {:.2f}'.format(epoch,\n",
- " epoch_time,\n",
- " epoch_time * (n_epochs - epoch)))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 52
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "3ne-qY7JVZaB",
+ "colab": {}
+ },
+ "source": [
+ "def quantise(images, q_levels):\n",
+ " \"\"\"Quantise image into q levels\"\"\"\n",
+ " return (np.digitize(images, np.arange(q_levels) / q_levels) - 1).astype('float32')"
+ ],
+ "execution_count": 0,
+ "outputs": []
},
- "colab_type": "code",
- "id": "ue0vZbitSNmz",
- "outputId": "0eb765cc-3a9a-4354-cd38-6316807e3748"
- },
- "outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "nll : 0.6348000764846802 nats\n",
- "bits/dim : 0.0008096939751080104\n"
- ]
- }
- ],
- "source": [
- "# --------------------------------------------------------------------------------------------------------------\n",
- "# Test\n",
- "test_loss = []\n",
- "for batch_x, batch_y in test_dataset:\n",
- " logits = pixelcnn(batch_x, training=False)\n",
- " logits = tf.reshape(logits, [-1, height, width, q_levels, n_channel])\n",
- " logits = tf.transpose(logits, perm=[0, 1, 2, 4, 3])\n",
- "\n",
- " # Calculate cross-entropy (= negative log-likelihood)\n",
- " loss = compute_loss(tf.one_hot(batch_y, q_levels), logits)\n",
- "\n",
- " test_loss.append(loss)\n",
- "print('nll : {:} nats'.format(np.array(test_loss).mean()))\n",
- "print('bits/dim : {:}'.format(np.array(test_loss).mean() / (height * width)))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 526
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "3QVhnMymVrzc",
+ "colab": {}
+ },
+ "source": [
+ "# Quantise the input data in q levels\n",
+ "q_levels = 2\n",
+ "x_train_quantised = quantise(x_train, q_levels)\n",
+ "x_test_quantised = quantise(x_test, q_levels)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "ZObIXqzNGwmo",
+ "colab": {}
+ },
+ "source": [
+ "# Creating input stream using tf.data API\n",
+ "batch_size = 128\n",
+ "train_buf = 60000\n",
+ "\n",
+ "train_dataset = tf.data.Dataset.from_tensor_slices((x_train_quantised / (q_levels - 1),\n",
+ " x_train_quantised.astype('int32')))\n",
+ "train_dataset = train_dataset.shuffle(buffer_size=train_buf)\n",
+ "train_dataset = train_dataset.batch(batch_size)\n",
+ "\n",
+ "test_dataset = tf.data.Dataset.from_tensor_slices((x_test_quantised / (q_levels - 1),\n",
+ " x_test_quantised.astype('int32')))\n",
+ "test_dataset = test_dataset.batch(batch_size)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "75VTDkK8VZLA",
+ "colab": {}
+ },
+ "source": [
+ "class MaskedConv2D(keras.layers.Layer):\n",
+ " \"\"\"Convolutional layers with masks for Gated PixelCNN.\n",
+ "\n",
+ " Masked convolutional layers used to implement Vertical and Horizontal\n",
+ " stacks of the Gated PixelCNN.\n",
+ "\n",
+ " Note: Unlike the normal PixelCNN masks, `\"A\"`/`\"B\"` keep only centre-row weights and `\"V\"` keeps the centre row and the rows above it.\n",
+ "\n",
+ " Arguments:\n",
+ " mask_type: one of `\"V\"`, `\"A\"` or `\"B\"`.\n",
+ " filters: Integer, the dimensionality of the output space\n",
+ " (i.e. the number of output filters in the convolution).\n",
+ " kernel_size: An integer or tuple/list of 2 integers, specifying the\n",
+ " height and width of the 2D convolution window.\n",
+ " Can be a single integer to specify the same value for\n",
+ " all spatial dimensions.\n",
+ " strides: An integer specifying the stride of the convolution along\n",
+ " both the height and width. Only a single integer is supported,\n",
+ " because `call` expands it to `[1, strides, strides, 1]`;\n",
+ " a tuple/list is not handled by `call`.\n",
+ " Note: unlike `keras.layers.Conv2D`, this layer has no\n",
+ " `dilation_rate` argument.\n",
+ " padding: one of `\"valid\"` or `\"same\"` (case-insensitive).\n",
+ " kernel_initializer: Initializer for the `kernel` weights matrix.\n",
+ " bias_initializer: Initializer for the bias vector.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self,\n",
+ " mask_type,\n",
+ " filters,\n",
+ " kernel_size,\n",
+ " strides=1,\n",
+ " padding='same',\n",
+ " kernel_initializer='glorot_uniform',\n",
+ " bias_initializer='zeros'):\n",
+ " super(MaskedConv2D, self).__init__()\n",
+ "\n",
+ " assert mask_type in {'A', 'B', 'V'}\n",
+ " self.mask_type = mask_type\n",
+ "\n",
+ " self.filters = filters\n",
+ "\n",
+ " if isinstance(kernel_size, int):\n",
+ " kernel_size = (kernel_size, kernel_size)\n",
+ " self.kernel_size = kernel_size\n",
+ "\n",
+ " self.strides = strides\n",
+ " self.padding = padding.upper()\n",
+ " self.kernel_initializer = initializers.get(kernel_initializer)\n",
+ " self.bias_initializer = initializers.get(bias_initializer)\n",
+ "\n",
+ " def build(self, input_shape):\n",
+ " kernel_h, kernel_w = self.kernel_size\n",
+ "\n",
+ " self.kernel = self.add_weight('kernel',\n",
+ " shape=(kernel_h,\n",
+ " kernel_w,\n",
+ " int(input_shape[-1]),\n",
+ " self.filters),\n",
+ " initializer=self.kernel_initializer,\n",
+ " trainable=True)\n",
+ "\n",
+ " self.bias = self.add_weight('bias',\n",
+ " shape=(self.filters,),\n",
+ " initializer=self.bias_initializer,\n",
+ " trainable=True)\n",
+ "\n",
+ " mask = np.ones(self.kernel.shape, dtype=np.float32)\n",
+ "\n",
+ " if kernel_h % 2 != 0: \n",
+ " center_h = kernel_h // 2\n",
+ " else:\n",
+ " center_h = (kernel_h - 1) // 2\n",
+ "\n",
+ " if kernel_w % 2 != 0: \n",
+ " center_w = kernel_w // 2\n",
+ " else:\n",
+ " center_w = (kernel_w - 1) // 2\n",
+ "\n",
+ " if self.mask_type == 'V':\n",
+ " mask[center_h + 1:, :, :, :] = 0.\n",
+ " else:\n",
+ " mask[:center_h, :, :] = 0.\n",
+ " mask[center_h, center_w + (self.mask_type == 'B'):, :, :] = 0.\n",
+ " mask[center_h + 1:, :, :] = 0. \n",
+ "\n",
+ " self.mask = tf.constant(mask, dtype=tf.float32, name='mask')\n",
+ "\n",
+ " def call(self, input):\n",
+ " masked_kernel = tf.math.multiply(self.mask, self.kernel)\n",
+ " x = nn.conv2d(input,\n",
+ " masked_kernel,\n",
+ " strides=[1, self.strides, self.strides, 1],\n",
+ " padding=self.padding)\n",
+ " x = nn.bias_add(x, self.bias)\n",
+ " return x"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "PTUN4s52Nu3w",
+ "colab": {}
+ },
+ "source": [
+ "class GatedBlock(tf.keras.Model):\n",
+ " \"\"\" Gated block of the Gated PixelCNN.\"\"\"\n",
+ "\n",
+ " def __init__(self, mask_type, filters, kernel_size):\n",
+ " super(GatedBlock, self).__init__(name='')\n",
+ "\n",
+ " self.mask_type = mask_type\n",
+ " self.vertical_conv = MaskedConv2D(mask_type='V',\n",
+ " filters=2 * filters,\n",
+ " kernel_size=kernel_size)\n",
+ " \n",
+ " self.horizontal_conv = MaskedConv2D(mask_type=mask_type,\n",
+ " filters=2 * filters,\n",
+ " kernel_size=kernel_size)\n",
+ "\n",
+ " self.padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n",
+ " self.cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n",
+ "\n",
+ " self.v_to_h_conv = keras.layers.Conv2D(filters=2 * filters, kernel_size=1)\n",
+ "\n",
+ " self.horizontal_output = keras.layers.Conv2D(filters=filters, kernel_size=1)\n",
+ "\n",
+ " def _gate(self, x):\n",
+ " tanh_preactivation, sigmoid_preactivation = tf.split(x, 2, axis=-1)\n",
+ " return tf.nn.tanh(tanh_preactivation) * tf.nn.sigmoid(sigmoid_preactivation)\n",
+ "\n",
+ " def call(self, input_tensor):\n",
+ " v = input_tensor[0]\n",
+ " h = input_tensor[1]\n",
+ "\n",
+ " vertical_preactivation = self.vertical_conv(v) # NxN\n",
+ "\n",
+ " # Shifting feature map down to ensure causality\n",
+ " v_to_h = self.padding(vertical_preactivation)\n",
+ " v_to_h = self.cropping(v_to_h)\n",
+ " v_to_h = self.v_to_h_conv(v_to_h) # 1x1\n",
+ "\n",
+ " horizontal_preactivation = self.horizontal_conv(h) # 1xN\n",
+ " \n",
+ " v_out = self._gate(vertical_preactivation)\n",
+ "\n",
+ " horizontal_preactivation = horizontal_preactivation + v_to_h\n",
+ " h_activated = self._gate(horizontal_preactivation)\n",
+ " h_activated = self.horizontal_output(h_activated)\n",
+ "\n",
+ " if self.mask_type == 'A':\n",
+ " h_out = h_activated\n",
+ " elif self.mask_type == 'B':\n",
+ " h_out = h + h_activated\n",
+ "\n",
+ " return v_out, h_out"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "WB57YufrVxn2",
+ "colab": {}
+ },
+ "source": [
+ "# Create Gated PixelCNN model\n",
+ "inputs = keras.layers.Input(shape=(height, width, n_channel))\n",
+ "v, h = GatedBlock(mask_type='A', filters=64, kernel_size=3)([inputs, inputs])\n",
+ "\n",
+ "for i in range(7):\n",
+ " v, h = GatedBlock(mask_type='B', filters=64, kernel_size=3)([v, h])\n",
+ "\n",
+ "x = keras.layers.Activation(activation='relu')(h)\n",
+ "x = keras.layers.Conv2D(filters=128, kernel_size=1, strides=1)(x)\n",
+ "\n",
+ "x = keras.layers.Activation(activation='relu')(x)\n",
+ "x = keras.layers.Conv2D(filters=q_levels, kernel_size=1, strides=1)(x)\n",
+ "\n",
+ "gated_pixelcnn = tf.keras.Model(inputs=inputs, outputs=x)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "_LnzHUaqV77d",
+ "colab": {}
+ },
+ "source": [
+ "# Prepare optimizer and loss function\n",
+ "lr_decay = 0.999995\n",
+ "learning_rate = 1e-3\n",
+ "optimizer = keras.optimizers.Adam(lr=learning_rate)\n",
+ "\n",
+ "compute_loss = keras.losses.CategoricalCrossentropy(from_logits=True)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "CsAgEKVzLCJD",
+ "colab": {}
+ },
+ "source": [
+ "@tf.function\n",
+ "def train_step(batch_x, batch_y):\n",
+ " with tf.GradientTape() as ae_tape:\n",
+ " logits = gated_pixelcnn(batch_x, training=True)\n",
+ "\n",
+ " loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits)\n",
+ "\n",
+ " gradients = ae_tape.gradient(loss, gated_pixelcnn.trainable_variables)\n",
+ " gradients, _ = tf.clip_by_global_norm(gradients, 1.0)\n",
+ " optimizer.apply_gradients(zip(gradients, gated_pixelcnn.trainable_variables))\n",
+ "\n",
+ " return loss"
+ ],
+ "execution_count": 0,
+ "outputs": []
},
- "colab_type": "code",
- "id": "-Ia9VXYySkuW",
- "outputId": "2bce0fd7-d366-4ed6-dc42-42f519b02567"
- },
- "outputs": [
{
- "data": {
- "image/png": "\n",
- "text/plain": [
- "