diff --git a/WIP/6-gated_pixelcnn_cropped/cropped_gated_pixelcnn.ipynb b/WIP/6-gated_pixelcnn_cropped/cropped_gated_pixelcnn.ipynb index 2b30544..5deefb7 100644 --- a/WIP/6-gated_pixelcnn_cropped/cropped_gated_pixelcnn.ipynb +++ b/WIP/6-gated_pixelcnn_cropped/cropped_gated_pixelcnn.ipynb @@ -1,624 +1,623 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "accelerator": "GPU", - "colab": { - "name": "cropped gated_pixelcnn.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "k1uZnxh4Xz9Z" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "gpu_devices = tf.config.experimental.list_physical_devices('GPU')\n", + "for device in gpu_devices: tf.config.experimental.set_memory_growth(device, True)\n", + "\n", + "import random as rn\n", + "import time\n", + "\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from tensorflow.keras.utils import Progbar\n", + "from tensorflow import keras\n", + "from tensorflow.keras import initializers\n", + "from tensorflow import nn" + ] }, - "cells": [ - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "k1uZnxh4Xz9Z", - "colab": {} - }, - "source": [ - "import tensorflow as tf\n", - "gpu_devices = tf.config.experimental.list_physical_devices('GPU')\n", - "for device in gpu_devices: tf.config.experimental.set_memory_growth(device, True)\n", - "\n", - "import random as rn\n", - "import time\n", - "\n", - "import matplotlib\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import tensorflow as tf\n", - "from tensorflow.keras.utils import Progbar\n", - "from tensorflow import keras\n", - "from tensorflow.keras import initializers\n", - "from tensorflow import nn" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "NN6vJl7eVnZ4", - "colab": {} - }, - "source": [ - "# Defining random seeds\n", - "random_seed = 42\n", - "tf.random.set_seed(random_seed)\n", - "np.random.seed(random_seed)\n", - "rn.seed(random_seed)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "8BnkhgCjVpJu", - "colab": {} - }, - "source": [ - "# Loading data\n", - "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", - "\n", - "height = 28\n", - "width = 28\n", - "n_channel = 1\n", - "\n", - "x_train = x_train.astype('float32') / 255.\n", - "x_test = x_test.astype('float32') / 255.\n", - "\n", - "x_train = x_train.reshape(x_train.shape[0], height, width, n_channel)\n", - "x_test = x_test.reshape(x_test.shape[0], height, width, n_channel)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "3ne-qY7JVZaB", - "colab": {} - }, - "source": [ - "def quantise(images, q_levels):\n", - " \"\"\"Quantise image into q levels\"\"\"\n", - " return (np.digitize(images, np.arange(q_levels) / q_levels) - 1).astype('float32')" - ], - "execution_count": 0, - "outputs": [] 
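A quick sanity check of what `quantise` produces (a sketch, not part of the diff; assumes only numpy): with `q_levels = 2` the bin edges are `[0.0, 0.5]`, so `np.digitize` returns 1 for values in [0.0, 0.5) and 2 for values in [0.5, 1.0], and the `- 1` shifts the result onto the level indices {0, 1}.

    import numpy as np

    x = np.array([0.0, 0.3, 0.5, 0.9, 1.0], dtype='float32')
    levels = (np.digitize(x, np.arange(2) / 2) - 1).astype('float32')
    print(levels)  # [0. 0. 1. 1. 1.]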
- }, - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "3QVhnMymVrzc", - "colab": {} - }, - "source": [ - "# Quantise the input data in q levels\n", - "q_levels = 2\n", - "x_train_quantised = quantise(x_train, q_levels)\n", - "x_test_quantised = quantise(x_test, q_levels)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "ZObIXqzNGwmo", - "colab": {} - }, - "source": [ - "# Creating input stream using tf.data API\n", - "batch_size = 192\n", - "train_buf = 10000\n", - "\n", - "train_dataset = tf.data.Dataset.from_tensor_slices((x_train_quantised / (q_levels - 1),\n", - " x_train_quantised.astype('int32')))\n", - "train_dataset = train_dataset.shuffle(buffer_size=train_buf)\n", - "train_dataset = train_dataset.batch(batch_size)\n", - "\n", - "test_dataset = tf.data.Dataset.from_tensor_slices((x_test_quantised / (q_levels - 1),\n", - " x_test_quantised.astype('int32')))\n", - "test_dataset = test_dataset.batch(batch_size)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "75VTDkK8VZLA", - "colab": {} - }, - "source": [ - "class VerticalConv2D(keras.layers.Conv2D):\n", - " \"\"\"https://github.com/JesseFarebro/PixelCNNPP/blob/master/layers/VerticalConv2D.py\"\"\"\n", - " def __init__(self,\n", - " filters,\n", - " kernel_size,\n", - " **kwargs):\n", - " if not isinstance(kernel_size, tuple):\n", - " kernel_size = (kernel_size // 2 + 1, kernel_size)\n", - "\n", - " super(VerticalConv2D, self).__init__(filters, kernel_size, **kwargs)\n", - "\n", - " self.pad = tf.keras.layers.ZeroPadding2D(\n", - " (\n", - " (kernel_size[0] - 1, 0), # Top, Bottom\n", - " (kernel_size[1] // 2, kernel_size[1] // 2), # Left, Right\n", - " )\n", - " )\n", - "\n", - " def call(self, inputs):\n", - " inputs = self.pad(inputs)\n", - " output = super(VerticalConv2D, self).call(inputs)\n", - "\n", - " return output\n", - "\n", - "\n", - "class HorizontalConv2D(keras.layers.Conv2D):\n", - " def __init__(self,\n", - " filters,\n", - " kernel_size,\n", - " **kwargs):\n", - "\n", - " if not isinstance(kernel_size, tuple):\n", - " kernel_size = (kernel_size // 2 + 1,) * 2\n", - "\n", - " super(HorizontalConv2D, self).__init__(filters, kernel_size, **kwargs)\n", - " self.pad = tf.keras.layers.ZeroPadding2D(\n", - " (\n", - " (kernel_size[0] - 1, 0), # (Top, Bottom)\n", - " (kernel_size[1] - 1, 0), # (Left, Right)\n", - " )\n", - " )\n", - "\n", - " def call(self, inputs):\n", - " inputs = self.pad(inputs)\n", - " outputs = super(HorizontalConv2D, self).call(inputs)\n", - "\n", - " return outputs" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "g625fSNYR9I6", - "colab_type": "code", - "colab": {} - }, - "source": [ - "filters = 1\n", - "kernel_size = 3\n", - "vertical_conv = VerticalConv2D(filters=2 * filters,\n", - " kernel_size=kernel_size)" - ], - "execution_count": 0, - "outputs": [] + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "NN6vJl7eVnZ4" + }, + "outputs": [], + "source": [ + "# Defining random seeds\n", + "random_seed = 42\n", + "tf.random.set_seed(random_seed)\n", + "np.random.seed(random_seed)\n", + "rn.seed(random_seed)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "8BnkhgCjVpJu" + }, + "outputs": [], + "source": [ + "# Loading data\n", + 
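One detail of the dataset cell worth noting (identical on both sides of the diff): the network input is the quantised image rescaled back to floats via `x / (q_levels - 1)`, while the target is the same level map cast to `int32` and used as a per-pixel class index by the cross-entropy loss. With `q_levels = 2` the two coincide numerically. A minimal sketch of the pairing (assumes only numpy):

    import numpy as np

    q = np.array([[0., 1.], [1., 0.]], dtype='float32')  # quantised levels
    net_input = q / (2 - 1)      # float pixels in [0, 1], fed to the model
    target = q.astype('int32')   # per-pixel class indices for the loss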
"(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", + "\n", + "height = 28\n", + "width = 28\n", + "n_channel = 1\n", + "\n", + "x_train = x_train.astype('float32') / 255.\n", + "x_test = x_test.astype('float32') / 255.\n", + "\n", + "x_train = x_train.reshape(x_train.shape[0], height, width, n_channel)\n", + "x_test = x_test.reshape(x_test.shape[0], height, width, n_channel)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3ne-qY7JVZaB" + }, + "outputs": [], + "source": [ + "def quantise(images, q_levels):\n", + " \"\"\"Quantise image into q levels\"\"\"\n", + " return (np.digitize(images, np.arange(q_levels) / q_levels) - 1).astype('float32')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3QVhnMymVrzc" + }, + "outputs": [], + "source": [ + "# Quantise the input data in q levels\n", + "q_levels = 2\n", + "x_train_quantised = quantise(x_train, q_levels)\n", + "x_test_quantised = quantise(x_test, q_levels)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ZObIXqzNGwmo" + }, + "outputs": [], + "source": [ + "# Creating input stream using tf.data API\n", + "batch_size = 192\n", + "train_buf = 10000\n", + "\n", + "train_dataset = tf.data.Dataset.from_tensor_slices((x_train_quantised / (q_levels - 1),\n", + " x_train_quantised.astype('int32')))\n", + "train_dataset = train_dataset.shuffle(buffer_size=train_buf)\n", + "train_dataset = train_dataset.batch(batch_size)\n", + "\n", + "test_dataset = tf.data.Dataset.from_tensor_slices((x_test_quantised / (q_levels - 1),\n", + " x_test_quantised.astype('int32')))\n", + "test_dataset = test_dataset.batch(batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "75VTDkK8VZLA" + }, + "outputs": [], + "source": [ + "class VerticalConv2D(keras.layers.Conv2D):\n", + " \"\"\"https://github.com/JesseFarebro/PixelCNNPP/blob/master/layers/VerticalConv2D.py\"\"\"\n", + " def __init__(self,\n", + " filters,\n", + " kernel_size,\n", + " **kwargs):\n", + " if not isinstance(kernel_size, tuple):\n", + " kernel_size = (kernel_size // 2 + 1, kernel_size)\n", + "\n", + " super(VerticalConv2D, self).__init__(filters, kernel_size, **kwargs)\n", + "\n", + " self.pad = tf.keras.layers.ZeroPadding2D(\n", + " (\n", + " (kernel_size[0] - 1, 0), # Top, Bottom\n", + " (kernel_size[1] // 2, kernel_size[1] // 2), # Left, Right\n", + " )\n", + " )\n", + "\n", + " def call(self, inputs):\n", + " inputs = self.pad(inputs)\n", + " output = super(VerticalConv2D, self).call(inputs)\n", + "\n", + " return output\n", + "\n", + "\n", + "class HorizontalConv2D(keras.layers.Conv2D):\n", + " def __init__(self,\n", + " filters,\n", + " kernel_size,\n", + " **kwargs):\n", + "\n", + " if not isinstance(kernel_size, tuple):\n", + " kernel_size = (kernel_size // 2 + 1,) * 2\n", + "\n", + " super(HorizontalConv2D, self).__init__(filters, kernel_size, **kwargs)\n", + " self.pad = tf.keras.layers.ZeroPadding2D(\n", + " (\n", + " (kernel_size[0] - 1, 0), # (Top, Bottom)\n", + " (kernel_size[1] - 1, 0), # (Left, Right)\n", + " )\n", + " )\n", + "\n", + " def call(self, inputs):\n", + " inputs = self.pad(inputs)\n", + " outputs = super(HorizontalConv2D, self).call(inputs)\n", + "\n", + " return outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": 
{ + "colab": {}, + "colab_type": "code", + "id": "g625fSNYR9I6" + }, + "outputs": [], + "source": [ + "filters = 1\n", + "kernel_size = 3\n", + "vertical_conv = VerticalConv2D(filters=2 * filters,\n", + " kernel_size=kernel_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 }, + "colab_type": "code", + "id": "DqsfC9JnS8C8", + "outputId": "ef53f88a-c437-4aa4-b40f-0184e140b1b3" + }, + "outputs": [ { - "cell_type": "code", - "metadata": { - "id": "DqsfC9JnS8C8", - "colab_type": "code", - "outputId": "ef53f88a-c437-4aa4-b40f-0184e140b1b3", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - } - }, - "source": [ - "vertical_conv.kernel_size\n" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(2, 3)" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 9 - } + "data": { + "text/plain": [ + "(2, 3)" ] + }, + "execution_count": 9, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "vertical_conv.kernel_size\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "PTUN4s52Nu3w" + }, + "outputs": [], + "source": [ + "class GatedBlock(tf.keras.Model):\n", + " \"\"\" Gated block of the Gated PixelCNN.\"\"\"\n", + "\n", + " def __init__(self,\n", + " mask_type,\n", + " filters,\n", + " kernel_size):\n", + " super(GatedBlock, self).__init__(name='')\n", + "\n", + " self.mask_type = mask_type\n", + " self.vertical_conv = VerticalConv2D(filters=2 * filters,\n", + " kernel_size=kernel_size)\n", + " \n", + "\n", + " if mask_type =='A':\n", + " self.horizontal_conv = keras.layers.Conv2D(filters=2 * filters, \n", + " kernel_size=1)\n", + "\n", + " else: \n", + " self.horizontal_conv = HorizontalConv2D(filters=2 * filters,\n", + " kernel_size=kernel_size)\n", + "\n", + " self.padding_A = keras.layers.ZeroPadding2D(padding=(0, (1,0)))\n", + " self.cropping_A = keras.layers.Cropping2D(cropping=(0, (0, 1)))\n", + "\n", + " self.padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n", + " self.cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", + "\n", + " self.v_to_h_conv = keras.layers.Conv2D(filters=2 * filters, kernel_size=1)\n", + "\n", + " self.horizontal_output = keras.layers.Conv2D(filters=filters, kernel_size=1)\n", + "\n", + " def _gate(self, x):\n", + " tanh_preactivation, sigmoid_preactivation = tf.split(x, 2, axis=-1)\n", + " return tf.nn.tanh(tanh_preactivation) * tf.nn.sigmoid(sigmoid_preactivation)\n", + "\n", + " def call(self, input_tensor):\n", + " v = input_tensor[0]\n", + " h = input_tensor[1]\n", + "\n", + " vertical_preactivation = self.vertical_conv(v) # NxN\n", + "\n", + " # Shifting feature map down to ensure causality\n", + " v_to_h = self.padding(vertical_preactivation)\n", + " v_to_h = self.cropping(v_to_h)\n", + " v_to_h = self.v_to_h_conv(v_to_h) # 1x1\n", + "\n", + " horizontal_preactivation = self.horizontal_conv(h) # 1xN\n", + " if self.mask_type == 'A':\n", + " horizontal_preactivation = self.padding_A(horizontal_preactivation)\n", + " horizontal_preactivation = self.cropping_A(horizontal_preactivation)\n", + " \n", + " \n", + " v_out = self._gate(vertical_preactivation)\n", + "\n", + " horizontal_preactivation = horizontal_preactivation + v_to_h\n", + " h_activated = self._gate(horizontal_preactivation)\n", + " h_activated = 
self.horizontal_output(h_activated)\n", + "\n", + " if self.mask_type == 'A':\n", + " h_out = h_activated\n", + " elif self.mask_type == 'B':\n", + " h_out = h + h_activated\n", + "\n", + " return v_out, h_out" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WB57YufrVxn2" + }, + "outputs": [], + "source": [ + "# Create Gated PixelCNN model\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=64, kernel_size=3)([inputs, inputs])\n", + "\n", + "for i in range(7):\n", + " v, h = GatedBlock(mask_type='B', filters=64, kernel_size=3)([v, h])\n", + "\n", + "x = keras.layers.Activation(activation='relu')(h)\n", + "x = keras.layers.Conv2D(filters=128, kernel_size=1, strides=1)(x)\n", + "\n", + "x = keras.layers.Activation(activation='relu')(x)\n", + "x = keras.layers.Conv2D(filters=q_levels, kernel_size=1, strides=1)(x)\n", + "\n", + "gated_pixelcnn = tf.keras.Model(inputs=inputs, outputs=x)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "_LnzHUaqV77d" + }, + "outputs": [], + "source": [ + "# Prepare optimizer and loss function\n", + "lr_decay = 0.999995\n", + "learning_rate = 1e-3\n", + "optimizer = keras.optimizers.Adam(lr=learning_rate)\n", + "\n", + "compute_loss = keras.losses.CategoricalCrossentropy(from_logits=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "CsAgEKVzLCJD" + }, + "outputs": [], + "source": [ + "@tf.function\n", + "def train_step(batch_x, batch_y):\n", + " with tf.GradientTape() as ae_tape:\n", + " logits = gated_pixelcnn(batch_x, training=True)\n", + "\n", + " loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits)\n", + "\n", + " gradients = ae_tape.gradient(loss, gated_pixelcnn.trainable_variables)\n", + " gradients, _ = tf.clip_by_global_norm(gradients, 1.0)\n", + " optimizer.apply_gradients(zip(gradients, gated_pixelcnn.trainable_variables))\n", + "\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 476 }, + "colab_type": "code", + "id": "NoEPrfwQNM-s", + "outputId": "4a9422c4-8ce1-4310-8783-882be0c7c924" + }, + "outputs": [ { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "PTUN4s52Nu3w", - "colab": {} - }, - "source": [ - "class GatedBlock(tf.keras.Model):\n", - " \"\"\" Gated block of the Gated PixelCNN.\"\"\"\n", - "\n", - " def __init__(self,\n", - " mask_type,\n", - " filters,\n", - " kernel_size):\n", - " super(GatedBlock, self).__init__(name='')\n", - "\n", - " self.mask_type = mask_type\n", - " self.vertical_conv = VerticalConv2D(filters=2 * filters,\n", - " kernel_size=kernel_size)\n", - " \n", - "\n", - " if mask_type =='A':\n", - " self.horizontal_conv = keras.layers.Conv2D(filters=2 * filters, \n", - " kernel_size=1)\n", - "\n", - " else: \n", - " self.horizontal_conv = HorizontalConv2D(filters=2 * filters,\n", - " kernel_size=kernel_size)\n", - "\n", - " self.padding_A = keras.layers.ZeroPadding2D(padding=(0, (1,0)))\n", - " self.cropping_A = keras.layers.Cropping2D(cropping=(0, (0, 1)))\n", - "\n", - " self.padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n", - " self.cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", - "\n", - " self.v_to_h_conv = keras.layers.Conv2D(filters=2 * filters, 
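For reference, the `_gate` method in the block above is the gated activation unit of the Gated PixelCNN: the `2 * filters` preactivation maps are split in half along the channel axis and recombined as `tanh(a) * sigmoid(b)`. The nonlinearity in isolation (a sketch with illustrative shapes; assumes only TensorFlow):

    import tensorflow as tf

    x = tf.random.normal((1, 28, 28, 8))  # 2 * filters preactivation channels
    a, b = tf.split(x, 2, axis=-1)        # tanh half and sigmoid half
    y = tf.nn.tanh(a) * tf.nn.sigmoid(b)  # gated output with filters channels
    print(y.shape)                        # (1, 28, 28, 4)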
kernel_size=1)\n", - "\n", - " self.horizontal_output = keras.layers.Conv2D(filters=filters, kernel_size=1)\n", - "\n", - " def _gate(self, x):\n", - " tanh_preactivation, sigmoid_preactivation = tf.split(x, 2, axis=-1)\n", - " return tf.nn.tanh(tanh_preactivation) * tf.nn.sigmoid(sigmoid_preactivation)\n", - "\n", - " def call(self, input_tensor):\n", - " v = input_tensor[0]\n", - " h = input_tensor[1]\n", - "\n", - " vertical_preactivation = self.vertical_conv(v) # NxN\n", - "\n", - " # Shifting feature map down to ensure causality\n", - " v_to_h = self.padding(vertical_preactivation)\n", - " v_to_h = self.cropping(v_to_h)\n", - " v_to_h = self.v_to_h_conv(v_to_h) # 1x1\n", - "\n", - " horizontal_preactivation = self.horizontal_conv(h) # 1xN\n", - " if self.mask_type == 'A':\n", - " horizontal_preactivation = self.padding_A(horizontal_preactivation)\n", - " horizontal_preactivation = self.cropping_A(horizontal_preactivation)\n", - " \n", - " \n", - " v_out = self._gate(vertical_preactivation)\n", - "\n", - " horizontal_preactivation = horizontal_preactivation + v_to_h\n", - " h_activated = self._gate(horizontal_preactivation)\n", - " h_activated = self.horizontal_output(h_activated)\n", - "\n", - " if self.mask_type == 'A':\n", - " h_out = h_activated\n", - " elif self.mask_type == 'B':\n", - " h_out = h + h_activated\n", - "\n", - " return v_out, h_out" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "WB57YufrVxn2", - "colab": {} - }, - "source": [ - "# Create Gated PixelCNN model\n", - "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", - "v, h = GatedBlock(mask_type='A', filters=64, kernel_size=3)([inputs, inputs])\n", - "\n", - "for i in range(7):\n", - " v, h = GatedBlock(mask_type='B', filters=64, kernel_size=3)([v, h])\n", - "\n", - "x = keras.layers.Activation(activation='relu')(h)\n", - "x = keras.layers.Conv2D(filters=128, kernel_size=1, strides=1)(x)\n", - "\n", - "x = keras.layers.Activation(activation='relu')(x)\n", - "x = keras.layers.Conv2D(filters=q_levels, kernel_size=1, strides=1)(x)\n", - "\n", - "gated_pixelcnn = tf.keras.Model(inputs=inputs, outputs=x)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "_LnzHUaqV77d", - "colab": {} - }, - "source": [ - "# Prepare optimizer and loss function\n", - "lr_decay = 0.999995\n", - "learning_rate = 1e-3\n", - "optimizer = keras.optimizers.Adam(lr=learning_rate)\n", - "\n", - "compute_loss = keras.losses.CategoricalCrossentropy(from_logits=True)" - ], - "execution_count": 0, - "outputs": [] + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/50\n", + "313/313 [==============================] - 167s 533ms/step - loss: 0.1132\n", + "Epoch 2/50\n", + "313/313 [==============================] - 160s 512ms/step - loss: 0.0888\n", + "Epoch 3/50\n", + "188/313 [=================>............]" + ] }, { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "CsAgEKVzLCJD", - "colab": {} - }, - "source": [ - "@tf.function\n", - "def train_step(batch_x, batch_y):\n", - " with tf.GradientTape() as ae_tape:\n", - " logits = gated_pixelcnn(batch_x, training=True)\n", - "\n", - " loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits)\n", - "\n", - " gradients = ae_tape.gradient(loss, gated_pixelcnn.trainable_variables)\n", - " gradients, _ = tf.clip_by_global_norm(gradients, 1.0)\n", - " optimizer.apply_gradients(zip(gradients, 
gated_pixelcnn.trainable_variables))\n", - "\n", - " return loss" - ], - "execution_count": 0, - "outputs": [] + "ename": "KeyboardInterrupt", + "evalue": "ignored", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mprogbar\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py\u001b[0m in \u001b[0;36madd\u001b[0;34m(self, n, values)\u001b[0m\n\u001b[1;32m 676\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 677\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 678\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_seen_so_far\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 679\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 680\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, current, values, finalize)\u001b[0m\n\u001b[1;32m 638\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m' - %s:'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 640\u001b[0;31m \u001b[0mavg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 641\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mavg\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1e-3\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 642\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m' %.4f'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mavg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36mmean\u001b[0;34m(*args, **kwargs)\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36mmean\u001b[0;34m(a, axis, dtype, out, keepdims)\u001b[0m\n\u001b[1;32m 3333\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3334\u001b[0m return _methods._mean(a, axis=axis, dtype=dtype,\n\u001b[0;32m-> 3335\u001b[0;31m out=out, **kwargs)\n\u001b[0m\u001b[1;32m 3336\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3337\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/_methods.py\u001b[0m in \u001b[0;36m_mean\u001b[0;34m(a, axis, dtype, out, keepdims)\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_mean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0marr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0masanyarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0mis_float16_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/_asarray.py\u001b[0m in \u001b[0;36masanyarray\u001b[0;34m(a, dtype, order)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \"\"\"\n\u001b[0;32m--> 138\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msubok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 139\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + 
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# Training loop\n", + "n_epochs = 50\n", + "n_iter = int(np.ceil(x_train_quantised.shape[0] / batch_size))\n", + "for epoch in range(n_epochs):\n", + " progbar = Progbar(n_iter)\n", + " print('Epoch {:}/{:}'.format(epoch + 1, n_epochs))\n", + "\n", + " for i_iter, (batch_x, batch_y) in enumerate(train_dataset):\n", + " optimizer.lr = optimizer.lr * lr_decay\n", + " loss = train_step(batch_x, batch_y)\n", + "\n", + " progbar.add(1, values=[('loss', loss)])" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 }, + "colab_type": "code", + "id": "ue0vZbitSNmz", + "outputId": "42de5250-7404-4fb2-daff-e5283be9f53e" + }, + "outputs": [ { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "NoEPrfwQNM-s", - "outputId": "4a9422c4-8ce1-4310-8783-882be0c7c924", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 476 - } - }, - "source": [ - "# Training loop\n", - "n_epochs = 50\n", - "n_iter = int(np.ceil(x_train_quantised.shape[0] / batch_size))\n", - "for epoch in range(n_epochs):\n", - " progbar = Progbar(n_iter)\n", - " print('Epoch {:}/{:}'.format(epoch + 1, n_epochs))\n", - "\n", - " for i_iter, (batch_x, batch_y) in enumerate(train_dataset):\n", - " optimizer.lr = optimizer.lr * lr_decay\n", - " loss = train_step(batch_x, batch_y)\n", - "\n", - " progbar.add(1, values=[('loss', loss)])" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Epoch 1/50\n", - "313/313 [==============================] - 167s 533ms/step - loss: 0.1132\n", - "Epoch 2/50\n", - "313/313 [==============================] - 160s 512ms/step - loss: 0.0888\n", - "Epoch 3/50\n", - "188/313 [=================>............]" - ], - "name": "stdout" - }, - { - "output_type": "error", - "ename": "KeyboardInterrupt", - "evalue": "ignored", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mprogbar\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py\u001b[0m in \u001b[0;36madd\u001b[0;34m(self, n, values)\u001b[0m\n\u001b[1;32m 676\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 677\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mvalues\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 678\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_seen_so_far\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 679\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 680\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, current, values, finalize)\u001b[0m\n\u001b[1;32m 638\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m' - %s:'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 640\u001b[0;31m \u001b[0mavg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 641\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mavg\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1e-3\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 642\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m' %.4f'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mavg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36mmean\u001b[0;34m(*args, **kwargs)\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36mmean\u001b[0;34m(a, axis, dtype, out, keepdims)\u001b[0m\n\u001b[1;32m 3333\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3334\u001b[0m return _methods._mean(a, axis=axis, dtype=dtype,\n\u001b[0;32m-> 3335\u001b[0;31m out=out, **kwargs)\n\u001b[0m\u001b[1;32m 3336\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3337\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/_methods.py\u001b[0m in \u001b[0;36m_mean\u001b[0;34m(a, axis, dtype, out, keepdims)\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0m_mean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0marr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0masanyarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0mis_float16_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/_asarray.py\u001b[0m in \u001b[0;36masanyarray\u001b[0;34m(a, dtype, order)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \"\"\"\n\u001b[0;32m--> 138\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msubok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 139\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "nll : 0.08501224964857101 nats\n", + "bits/dim : 0.12264675098280793\n" + ] + } + ], + "source": [ + "# Test set performance\n", + "test_loss = []\n", + "for batch_x, batch_y in test_dataset:\n", + " logits = gated_pixelcnn(batch_x, training=False)\n", + "\n", + " # Calculate cross-entropy (= negative log-likelihood)\n", + " loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits)\n", + "\n", + " test_loss.append(loss)\n", + "print('nll : {:} nats'.format(np.array(test_loss).mean()))\n", + "print('bits/dim : {:}'.format(np.array(test_loss).mean() / np.log(2)))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 581 }, + "colab_type": "code", + "id": "-Ia9VXYySkuW", + "outputId": "ad1075b1-70cc-4012-851e-925cae32ac2c" + }, + "outputs": [ { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "ue0vZbitSNmz", - "outputId": "42de5250-7404-4fb2-daff-e5283be9f53e", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 51 - } - }, - "source": [ - "# Test set performance\n", - "test_loss = []\n", - "for batch_x, batch_y in test_dataset:\n", - " logits = gated_pixelcnn(batch_x, training=False)\n", - "\n", - " # Calculate cross-entropy (= negative log-likelihood)\n", - " loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits)\n", - "\n", - " test_loss.append(loss)\n", - "print('nll : {:} nats'.format(np.array(test_loss).mean()))\n", - "print('bits/dim : 
{:}'.format(np.array(test_loss).mean() / np.log(2)))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "nll : 0.08501224964857101 nats\n", - "bits/dim : 0.12264675098280793\n" - ], - "name": "stdout" - } + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj4AAAI0CAYAAAAdqSPKAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3dW3LkOLKtYfDYGYLyuWIOrfmPQJqD6rk0B+6HzOimongBQFyWu/+fWZvtnRWS4ATIcDpAcFnXNQEAAETw/2Y3AAAAYBQSHwAAEAaJDwAACIPEBwAAhEHiAwAAwiDxAQAAYZD4AACAMEh8AABAGCQ+AAAgjP9f8uG3t7f18Xh0ako/X19f6fv7e7n6nNX4Ukrp8/Pze13XX1efsxpjbh+mRIzKvJ+L9OFPxKgr8rlYlPg8Ho/08fHRplUDvb+/Z33OanwppbQsy985n7MaY24fpkSMyryfi/ThT8SoK/K5yFQXAAAIg8QHAACEQeIDAADCIPEBAABhFC1ubmVZshbLn1rXtUFLgGN745RxB2AWrkltdE98WiQ5Ob93VufnxOdxYPboV6XjdBTf89+V2opjJeOUPoUqEp62mOoCAABhTJnq8iL3bnJZFhfZea/q3fb3ezhOmKv3OAVGYSz3QcUHAACE0bXi4zVbzZ1vtRb/7PYqVHtmHwMrtsdJod+OKLcNOMO1qJ8uiU+kDsu9sCovih3RX4px41zJVG5K8/s40nUHx47Gwezx2YKHGJ5Kz9eWsTPVBQAAwmBx8yDKmfpZ266ycuW4SuTefXiIN0plxENf4VjpOFapTF7xen7ejavl9DoVHwAAEEbTio/XTLWUp+OgfnfUQoRqT88xqXBcPJ1zOEY/23RUrantz7vVO/NTXbMvupyIdkVIeO54jZuxrqN1XyiPccadfUfjq2TctRwHTHUBAIAwTFZ8Zt+dPP9+SQaaU5rjfSw6vB13b/FEFan6ESnWK5y/bVHxAQAAYZio+Khmu3cWaVnbHA5xnY1VlXG5ruu/2hmhgloTj3Il5U7brq7H3vreu57XHdnEx9ogzX1lRc6JbS12+KL8xXhX9BfhekwIchbhW48Rv7XqR6a6AABAGHIVH0+Zee5UmKeYYZvlR/zv7EDunXoVpMV7m6L3sSe9xysVHwAAEEbTik/NY94RcDygzvsY3au+KlU8cr1eYy1X6FJqM+6opvsxqjJJxQcAAIQht8YnGu5I/LF0F+690nPGU+Xn6L9HZTF+9XVYPY2OvUvi0+IlZJbxyLoNLfrA8viuaTvjVku0/sj5bol2TCyb1YdMdQEAgDC6T3VFe8PzVXxe70aiVA88VPNaPDps1TZ2T3FZlfNADP2E1qj4AACAMFjc3IjnSk+vKp21xaXWFx9aWnTdivU+i6K0Tzz2a7TtYGb2IYnPIK+dbOUkjXISnvFwkY2Q9DBd4l+E61HEsTo6Zqa6AABAGMMqPhEy9RLWpnl6Uj0WVqt0NTzG5jEm/Bv9bNPMfqPiAwAAwuhe8YlS6dlbmJaz2Zbnx2stvkHZw3qeXJ7i8nweRcdGhWita+Kj/iXXw9HJmLNiX3XKJ9fRnk0W4vGe8HiKBTGwWB29MNUFAADCkHmcPUoGb22vhlbvs3r+nrPpv+3nRmKaxIfcvou0aD2HwviPsN0CdFDxAQAAYUhUfCJm8dYqP3eo9q/CnS7aOVtT5n0Nl3cR+4oqWD9TE5/IHeY54bG0qBn+bMcfCc+xWdegyC/JfWp57LmBK8dUFwAACGMpyRCXZfknpfR3v+Z089e6rr+uPmQ4vpT8x5gVX0rEKI5x+gcxSiPGPzzGV5T4AAAAWMZUFwAACIPEBwAAhEHiAwAAwiDxAQAAYZD4AACAMEh8AABAGCQ+AAAgjKJXVry9va2Px6NTU/r5+vpK39/fl3uEW40vpZQ+Pz+/czajshpjbh+mRIzKvJ+L9OFPxKgr8rlYlPg8Ho/08fHRplUDvb+/Z33OanwppbQsS9bOmlZjzO3DlIhRmfdzkT78iRh1RT4XmeoCAABhkPgAAIAwSHxQbFmypr4BAJBD4gMAAMIoWtyM2Kj0AABGOvveWde16nfKJj7bYGuDAwD41uOLcaacG0yvcW09Y9yL9fm7lmWpOhZMdQEAgDBkKz5R7GXBatk8U1yxHPW32rhEPKXXIiszB16vsb3ievblsiz//Rsl/UvFBwAAhCFX8fGa+b46i1O9CqTUFtRTH2dAStffCWdjVvX75G67ate2jNDimOfGtq5r1d+TS3xeqXZuqdoFeHcXcd1lpVSsSDWxsJx0w7fW4+/ulMgoV21STeBSatu20r6p/dtMdQEAgDBkKj6qGe2disedxxJnH4/Zf98yS8due0d89N9gn3r/jmyfcuXnyN55qhDHrGvd3ZkIKj4AACAMiYqPpTvkK7mxlC7Im5nVW7ozUqVwDL1t9Fai9hpj/bjkrOdSjbHldbB2EWwvqse81Ojj2upvTU18rF2IW3bw3kVHLeFBPot73yi37Y7W56nF41RyDBRiPNud9+jfStt8Nq2rTvm7oeVxzf1efP3bpZjqAgAAYUyp+NzN4BTuUGrl3tnMiu/qLislnTsNdUrHyeJdbqkIMeaoOQ6K015Xd/211yWlGD2pOa5HYzV3DPOuLgAAgAvDKz531kLMuKNrMX9ZOxetSqlCNZvFKoPHvrLYD62NvEYp8rjZqvdxPWvd1bDEp3R6a1t6Vej8knaUnHQKsW2VDkQvF5iWlI6J2vhS8No/V8dIcRroKUKyc5TQeN+B3NrDP3ecnZNnfc7iZgAAgAvdKz61lZ6cnx+td5atlMUrtWWEu+P06HPQcNY3lh9zrmFlnNa+nPRqwax6/JEqPUdyr7m1qPgAAIAwulZ8ShYyl2RzHrJejwvxvOl91zGLlTvfFrzG6Hkn6rvV1KvzVnX8U+nJd/d4UPEBAABhDH+cveYu2lO266FikMtyVat0nKrGF239inel/ag6LkdTeTr4iMVX3lgm8ZLSM9473mN8yheYGhYTnlfqF35ci7IcYC9ZV52eusvDtcUiproAAEAYcu/q8sp7zNbje73L9Dj92qLNHh/jV+/rqJuJep6mpdJzz90KIBUfAAAQhsy7ul55yno93iU/ncXm8U7NS7/VsL4xnFWez6eIWMg8n9ziZqudHymh83ji5vSf5fjuyB3b1hIgL4mEleM9k8ITpkxv3dPyGDHVBQAAwuha8cl9k67VbDfK46VPHvswJZ8VrBYsVUSOHtffq0JZims
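The conversion printed by the test cell is a change of logarithm base: the categorical cross-entropy is the negative log-likelihood in nats, and dividing by ln 2 expresses it in bits per dimension. Reproducing the printed numbers from the values in the output above:

    import numpy as np

    nll_nats = 0.08501224964857101   # mean test cross-entropy, in nats
    print(nll_nats / np.log(2))      # 0.12264675098280793 bits/dim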
r2njM7c/cn53F63XTMio+AAAgjGFrfDxltp4XK+8pjff186rHRumuUInH3YHp6zgsVHCV2hKR3OJmayIO4NzXOUQ8Nt4p92mPp5+U4/Uqp/+sJOck3JqY6gIAAGFQ8amwvbO09ghvqas7Fio9vnhYiNliIbOVWD1qWblT7EfFNkVDxQcAAIRBxeeG7SO0VuacW7CweBBlvK5FeB2T6u/lmmFZFhdxe4jhLu8zEK3IJT7WdkAuvbBaw0Jm36LsWO3tvGxhe0wsf2FabHNLd/Y7ioqpLgAAEMZSkhEuy/JPSunvfs3p5q91XX9dfchwfCn5jzErvpSIURzj9A9ilEaMf3iMryjxAQAAsIypLgAAEAaJDwAACIPEBwAAhEHiAwAAwiDxAQAAYZD4AACAMEh8AABAGEWvrHh7e1sfj0enpvTz9fWVvr+/L/estxpfSil9fn5+52xGZTXG3D5MiRiVeT8X6cOfiFFX5HOxKPF5PB7p4+OjTasGen9/z/qc1fhSSmlZlqydNa3GmNuHKRGjMu/nIn34EzHqinwuMtUFAMBAy7Lw4tyJSHwAAEAYRVNdAGJ7vUvlXX9Ame059Py/a84jzsV6VHwAAEAYJD4AYNRzrQjrRWLZ62/GQD6munBL7slGGVbHXp/l9A8XVh1nX3yca355Owdz4ukxnqn4AACAMKj4INuduw3uRuc767/tf8vpI/pxjpxzsLaih/Fy+8VbpSel/Jh6jGcqPgAAIAyTFZ+jTNHCXc1Zlqvafo93G1HU9J3F/o6w1qxFxTUl28fg6epYKMYY5Vy8Uvsd2GobgJSMJD6zFkC1ZLVE3erEmx1HNL0umGr9WBrnneMyK3aPX36lWvSz2thNqb5v13U1t4/P3T7Zi7kWU10AACAM+YqPxZLmq9JMV+EOr7QNFvqhN2vTmNs2WauqKpwjyp595eE4tYrBwwMW27Zb6uPWMwfP37csS1V/UvEBAABhyFd8jljI2mvnNLdzmep3Kart6qX2zmVkP5ZUSa1VenDNalWgRITHwEvPu9rqxyyt2lpzbZVNfKxNG2xZPtlyqB//1nKTA/V+f/0StNyPOV/oLV78qM5yH+a6E6NS4nfVFo992WuR+d1+ZaoLAACEIVvx2WM5I7bc9icPMeSqWVT/+m9H71NSeSy6ZudUtTHQ8o5eoSqw506MqjH1phy32jnUw4ztBEoq2VR8AABAGHIVH+VMPYeVjbOubO8yLbb/rqM1Oy023PKwxkbNnWPpYcuMI0prXPb0OBciraFRNOs70MXi5i2rA7b0yYO9i9Ts2Gf//VksJ7B3FlqXPvFlUUlcVvrcopavIIjI2s7NLd29NjHVBQAAwpCp+Hi9uzxyNgWSUqzsXdFe1aS0f2aO6bMpjjvTmNbvzC1fZ7w8nq/YJvShep2g4gMAAMKQqfgcUc0Yj1wtaM39HZjvrGpi5S3fZ3+rph3Wx2bu+ifrcSrrteDa8qa3paxUzVpXiHk7OwAAQCH5io81e08q5PJ2V+JFyyed6OP5zjaa9N4/6vFF3T4jl+WnTUvlXmtr4pdIfKyU7V61arfXgbvHejnaQht78Rq7x7gsXFP3ph1L+sLTlKWnm+Tc6eQWY7T2ODDVBQAAwpCo+FhTk6kqZ+gjWLgDxW/0FWY5mnaMPiatf3+07r+7x4OKDwAACEO24qOc4V49jqnc9hk8vwvJm0iLJ72y1IcetozoyWIcatWdPc0Sn4hJgOfYWiDhAZCjxd4+lq8nltt+x1G/9z4eTHUBAIAwbld87jyeHH3BmleeHjONwtL0CPZ5uJ5ePQrNmLSr9S7yd1DxAQAAYdyu+Nx50zN8sVzpsdz2O6j02He2PkK9CtRzd17oUOu/ZoubWwWmdoCQR/0Ce+XsNQZ7/+ZhnFrvM9jqQ0ttHY3CwVhMdQEAgDBk9/GBttq7t+fPqd/dXE0XjH78sgXuuP0onaK0dN6dnW/q7b/D4jVlT+67umai4gMAAMKg4gNcyNlcTX2OvuQOLMLdtVWWF6PntjNnN2crMd+hfk0poRYLiQ+6OZsuUjsRclgo4ZbYO/7b+CJ9yaRkM6m4at9rEmHpvMt9nYWVeM4cXSutnoN7facUC1NdAAAgjKUk+1qW5Z+U0t/9mtPNX+u6/rr6kOH4UvIfY1Z8KRGjOMbpH8QojRj/8BhfUeIDAABgGVNdAAAgDBIfAAAQBokPAAAIg8QHAACEQeIDAADCIPEBAABhkPgAAIAwSHwAAEAYRe/qent7Wx+PR6em9PP19ZW+v78vX7JkNb6UUvr8/PzO2YXTaoy5fZgSMSrzfi7Shz8Ro67I52JR4vN4PNLHx0ebVg30/v6e9Tmr8aWU0rIsWVuKW40xtw9TIkZl3s9F+vCn0hhVXj5KP/6Px/h4OzsQ2Oubr3mFDWZQenM3/GONDwAACIOKDxDMa5UHmIWxiBmo+AAAgDDkKj5ndwDM/6I3r2sNuLMGMNqdxeo9r8USiU/uRdnal1Ltl42V+LaWZTltt7WE9ioeD9Z1DZ0Q5cSuNgasnUeIaW+c7v3b0ZjtfV1iqgsAAIQxpeJzN5tTuhvvkZkqxXdlG3/tsbBWybPk7C4rQrXH6rUmQt9s1U6DjPhbPUSaDTgzq4JJxQcAAIQxpeKTc8e5l+2p3QWVtOc1HrVYSllv/6voG/m1iFetcmdxjN5p853jP6PvSmJt2ZezxmmLGErWyeDY1MXNpR22TZjULrJ7ztq2/W8WL9Atp0uU+9CTXsdZ5Vxs8aUw+1yccV3Ym66e3Zc5ats4uo97/z1LfXZlVAxMdQEAgDAkHmf3piRrnX2HOYvS3YnHPhg1dady7DxOAZRO97eIV/GYterbGWN15N+09FDM06z2UvEBAABhUPGpcOeuS+UOeSRLdyGW2lqr9SJLhWO2txljxHPNk15rthTG65OVdUqlah9gGsVU4qPc2bW7Fpf+LjWW94Sx2OYcXuO64u3JyVxWvzxLdu0l4fltdp+VOvt+mDk1x1QXAAAIw1TFZ0ttN9XtI4WeKzwlLMXVsq2zHy/tOe1jab+j3EfDLewZFgXVnn3Wx+PR9+KsayUVHwAAEIZ8xcdapmutva1sM3qlu6tRlPr9zo7irX6vmtw4rcXo5Vzzuqvx3b/f61yeQWk9qGzio7gavLTjckrt6oO1hKdYStVMcfZyNe68TR3Avu219fU6W3NeeRifKtcTj5jqAgAAYchWfPZYyeJzHqu1EgvOqfdt67YoxRZRy5fJtv69LWzbcadNlqfdIz4ckzMmWz7+TsUHAACEYaLiYy2jVVyfhHOW7xB7sfToOrB1th5TYRyzfufa2ePvWzX9ScUHAACEIVvxUcjKW/IWD/yi0uNHTmXBc//uPeE4u7rbutozO57ZauKXTXyUKe1HgHp7j9DmUi2h17IUj9oLUmfjOpRn7zF5xk9MTHUBAIAwqPjcoLRpHepFvuuzPq0VtfrT4roT6XhtzY77zvdG7vvnImFxMwAAwImmFZ+IGy+x3icWT68esVotiVppbfHqhojHTVHO98bVOWl1zVKL78y7cTZLfG
qCsHrh3UMC5J+XvvV03kXVYjE+5ot8/tVM27FzMwAAQKFmFZ+j8nNuqdXzbsfW2x+Rxztk6wuZr3jfz8RrXMBoVHwAAEAYTRc359yR1Mzreb9T9eCqL9X77G6FR31hIet6YlKuXHp6UGAG5b7tpVVVd/g+Pi06i4u4PcpTmVH2RLHQRvhUco5xo3vN0k7rLbG4GQAAoJDszs1Wd6iMuseIN7njz8pd1nZcWmkz6qmM2c/PT/fTyCPx3dIGFR8AABDGsIrPncWv1nYfjX53cha/pb5TuWtuxWKbW1Efd7W87ZZPxfynnGOh3qettIyzS8VnWZZ//e9MlI7zbF3X//7Psqsx6yFG70qvP2e89PWscfuf//znx7Xh6n+R1Xxvejlmo2NhqgsAAISxlGRZy7L8k1L6u19zuvlrXddfVx8yHF9K/mPMii8lYhTHOP2DGKUR4x8e4ytKfAAAACxjqgsAAIRB4gMAAMIg8QEAAGGQ+AAAgDBIfAAAQBgkPgAAIAwSHwAAEEbRu7re3t7Wx+PRqSn9fH19pe/v78t9663Gl1JKn5+f3zmbUVmNMbcPUyJGZd7PRfrwJ2LUFflcLEp8Ho9H+vj4aNOqgd7f37M+ZzW+lFJaliVrZ02rMeb2YUrEqMz7uUgf/kSMuiKfi0x1AQCAMEh8AABAGCQ+AAAgDBIfAAAQRtHi5tmW5d8LtNXeLr/XxpT02gkA0Mb3SR9yic9RR79S6/irdj//u1q7AfzP2XnMuYvWcr/vzn6OcVmOqS4AABDGlIpPSZarns2WZuzeKj8R7jwsTLHiPqq2GKW20nP1uxibeaj4AACAMIZXfDwt1iqJ5fWzFrP0lncpqnJjtLwWpGUFa1kW+XhbiRIn0Eqv74y75+KwxMfiF32No9jOYlb/Es35orScFLVuu+JYz53GiWxd11DHgak7f2b06Yxz5u41lqkuAAAQhtzj7BEpV09yM2uLd401x/koTqU+2+rdLov9fuYZj2p/5rDc9ihqx1nOMoqeWvytFteMbTtqqlxUfAAAQBhUfBrz8uiz4jqVVlrcZe19Zu/3zlpHwV3/PVbG/J1+VozxarG8t+vSNobaa4VilXJk39T8LRKfG/YG7dF/v/O7R3qNw8PFZevq4uAt3j0RYvSqR9KuYG/qouRnUrIT65GW7Z9xwzXib7VK7pjqAgAAYQyr+FjPxq+0XrCFsbyNz73yt7cYo2m9B5rS4+wtFu0qxRNRy+Pf+52dVHwAAEAYrPGZjCrPfBGqddwF+xKhP7dVS+UtPyLpvZB61BpMEp/B1E/Y14Ed6ZUEV9T7bo/XvsvpCy+xe52uzL22eIrZixlPkrUcB0x1AQCAMKj4dHQnG569UG+b0c9uyywqu5Tit9rFriWU+6tX26jqznNVzfP2yP6R0e+rpOIDAADCoOLTSMsdVFXXkvR6XFHxLqZ0Q0rVPvOkZpy0qBKpbCaK8ZWBu3IX624r63d+3+vvVTdrjFPxAQAAYZiq+KhVCbytIejp7FiNvMvee6dW643hWv087pv9NmsLZo7T3CqytT6ree1GiyelLF1zzr7P7759/YpM4uN9l86cdqqf3D2TkZFyLjC5bVLvM/wbU5a/zb52ej7HSpOY7bG487MKjhbLl9zgHt2gtoqVqS4AABDGlIqPxQy+RGlWqrSgsgfV+O60wdoiS1zzOE4jUDo+LTbWVFvSUaPXBqOtZnqo+AAAgDCGVXxqqzx7c30qrGbjOVSP+WxUevxRHevRx5Nqv5zpVUVWM2IBds/XYnRPfCK9U6cVtePR4ukmSyf1EY9Jz17pOMo5qzq1Bajy8sQZU10AACCMKYubWzzarZQ91lCtgPRul/V+e2UpntK9lGp/V0rax6XXvk0AflM/l6j4AACAMExuYKieTe6Jsm7iiOXYVKtzCpTWb6m0A2NZ28zWsr0FxxaPu0Ti4/WCZXXnTfzmdVzWUhynLfpIMa7IOO/0jThneo4DproAAEAY3Ss+OS8fs+zO/kRWRCwl5y7etfRIdI92zTqHvTxWC2D87AgVHwAAEMbwNT4tMraWb2kdzWq7o8i988h5NNxzX8/czNHyo/S4djZLkPN5q/aqyJ6uJXerwy2PwdTFzbmB7B0wTwNCXfRjvBd/zklsOUHfM3tq2uPO2aNYfwrnzOu48BafZaoPHzDVBQAAwlhKsqllWf5JKf3drznd/LWu66+rDxmOLyX/MWbFlxIximOc/kGM0ojxD4/xFSU+AAAAljHVBQAAwiDxAQAAYZD4AACAMEh8AABAGCQ+AAAgDBIfAAAQBokPAAAIo+iVFW9vb+vj8ejUlH6+vr7S9/f35d7ZVuNLKaXPz8/vnM2orMaY24cpEaMy7+ciffgTMeqKfC4WJT6PxyN9fHy0adVA7+/vWZ+zGl9KKS3LkrWzptUYc/swJWJU5v1cpA9/IkZdkc/FqS8pBTCG55dUAkAJ1vgAAIAwqPgAQVDpwbbyt8XYQCRUfAAAQBjDKz6sNfDh2Y/0ITDOUcUGti3LwrV0oKlTXXx52sTFF/Bl9rU495pS276936/wvVN73FXjsYKpLgAAEMawig8Zqn3RKj2v8Vocr63v5D0ck8i2/aVwPp+14dnW7Wdatnnm9JLCsY+Mig8AAAhjyhof7hJtiVit24uZBYj/5uGYeH7g4iyes4rKiOOQe13Z+zcvFZMWa3tQrnviE6U0nlOyrfnZKzOOZ69pkx5/o1RpX8xeFIoyV/3roT9btH1EQnvn99/52dl9TPLyb6O/C5jqAgAAYbBz8w25mbvlDH9025XuwnovrvRAvTJS218Wp7+YNtF2NH3+dNZ/itXxV70qaT3ORSo+AAAgjGEVH5WstIUId0q9FzQrLZi+asvZI8CexrU33s7TFvF4OyYl1GNXb9+eqypWr79397rbNfGx2JEtePsy7J3wzKKUfFlwp1Q/ktIYG6HHcVfpy1YUx+ldM+NofY7VTvPVYqoLAACE0aXi4zG73lLb/bSnkv47OhbrupqprnjvT89q+079fG6xsLX3u7AsmB2b+jgbZXY/pETFBwAABNK04hM5i42qxcaNCncAuSy1tQfV+FtUNKxUJVNqW+lRjbEF1WvM3TUts+PZbvXRqy0984mixOfz8/M00Nd/X5bF/VMwe3u9eHAnLit9fLUVvpU49vQYj5aPh8W23+lDC1+eI3iJUTWO0UkP+/gAAAAUKkp8/vOf/xRlXOu6/vd/T94qI1Ft+/b1f6W/Q0lNHNCV05dHlWnFMdCiXaqxRWdpHLb2jH3vGDy1PA5UfAAAQBjDd27eZnQeMlnvFaxefeSh72Gb8rnba30P5x1UlI7xlgupqfgAAIAwhr+dfbuZnafKD44p31l7xTl1zGpFJHd7iJqfxVylWylcXVMV+7vlu+ak39WVq+deANBCP2Okkout0tjc2+187zpp8Qswujt9Zm27jV43vXcTIKa6AABAGBIVH6ss7fQ6A1NcsMDSORv1nVt3ryWzjkdNu61uPnmnj3Liafl9QsUHAACEQcUHgHl7a2Ksq31tjGI1oEaLCsL2Q
Rql4+J5rVbuuVga095GyLX9SuJTwdsFthev7zHzzvKC4JT8jTelY9xbq6k8lTEQ6Um8bXtHtr1moTNTXQAAIIxhFR9Pme0eDzGMwN5NOq7eibP33qC936HWl3sl8b2Y1NodWcupHgsPnXj/PuylVSWPig8AAAija8XH04ItwLKatRM556f6OfzaPvX24n9K+8pCpSclO+1U0/Kt7V0SH+9lPA8xjJS7C60F1hb+cpGFBTVjUmUBcy5r7VXR47gx1QUAAMLoUvE5W1yIa1arIaWsjI3aO44o/QiMlntOKixm9z4DUqtFJYd3dQEAAFzo/jh75Iy2FHPAWugPAD1E/l7s/U6vHOzcDDSiuHOwYpugRWE6qKW9pRZKPBzjGVoeN6a6AABAGEvhjpj/pJT+7tecbv5a1/XX1YcMx5eS/xiz4kuJGMUxTv8gRmnE+IfH+IoSHwAAAMuY6gIAAGGQ+AAAgDBIfAAAQBgkPgAAIAwSHwAAEAaJDwAACIPEBwAAhFH0yoq3t7f18Xh0ako/X19f6fv7+3LvcqvxpZTS5+fnd85mVFZjzO3DlIhRmfdzkT78iRh1RT4XixKfx+ORPj4+2rRqoPf396zPWY0vpZSWZcnaWdNqjLl9mBIxKvN+LtKHPxGjrsjnIi8pFeftBYIAAMzEGh8AABAGFR9R20pPSlR7AKBUhIr5M0av8fVAxQcAAIQhX/F5rXycIeMFEHkOW9MAABS+SURBVJ3HKkfJ98DV7+CYQC7xKe1ML4MYAO7wlhz0+mJXirGlZVncxdQLU10AACAMmYrPWXYfLYu1VKq+c1emHhti2BvDXsZmbnVDfdqktD/U47nLe3y9UfEBAABhTK34UOX5SbXSw1z7NdW+K+W1+lE6hi31593zU6160KI9zz7zNp7V+sqqKYmPt8F4l+KePa1PsLMLkaUEKOe4ePvStBTPK74oyljrX6sUr/mRMNUFAADCGF7x8V7tKX2kUOluuubueHabR8qdmrVSZcg9F7efs3CnauX4j3DWP0rXnq2z6rAVltseARUfAAAQxrCKz1kGrHrnUcvSmpUcXuIoVbv4fnvHqjgWSs+3s2qWtU3TctvKHTtq9H4QBG3I7OPzpPhFkaM2eVvX9b8/Ozv2bVt6/o2Ujhc5q/e7evvOWJimQn8Wv0RLr41KMeZMNx5d+zy9smlUn+QcB6a6AABAGN0rPjULKLf/v3oW21KkWNXdWYSvdLf55G06Gb9djbWcsag6Hvaqw2fnpcJ5d+dYenpP5cy+yJk5oOIDAADCkFnjo5S1j6acuacUa21Ir40bZ2lZ6fG+FUVKMca6t5gsf2d4/N6zEEvXxCfihdKy3Fg8Tpu0eH2K0pdmy3Fp7TyuXUBv6Vz2NC3y6m4/WNxTa8SDJb3ltr/XWLyaEt1iqgsAAIQhM9XlhaU7qxY8VH+O7g6s7sB9V5SXB3uO02r7W+8eb+k45O6ybdGIfiiZNqTiAwAAwuhS8SldE2BtDcGT9Sz81dUxL310Vn2uvfW6HvUx23LthAVH6yYsV3pK1+Gpx5PLSxy1lONXOZ9er8Xv7++Hn6XiAwAAwmha8Sm9o7Ra6YmqtIJj+ckTy22/w0scTyXz/t5it47+sEu977ovbs49AOoHao+HRxBr3X2vzNXv6qlln6mOW9V2zcLx0Bb1Oor7arYRYaoLAACEwePsjVl4w7gC1WOk2i4gCs5B5KqtFFLxAQAAYUyp+Cht7Y/+LPSvhTYCXnH++TLjwaWS3z888fG2iM3jS+Za4WLmbz8V2MB4g5qW+/3c/b5lqgsAAIQxrOJDRQQqRlXpGPOYgWoPRmpxPb3zszXjnYoPAAAIo3vFR+U9Hpawo/V4rMWBdYxdzNRyU9s7fzNH08SHhb51cl//wIWtrb3x2uKlo4ovLmX8ABgtNyd4vTYdfb7VNYypLgAAEEaXqa6rLM/TnWdtdSv35zwdK1VHL18tmf5SnZ5kCg+wxePykKP3WuZWevY+ewcVHwAAEEbXxc1Ws9MrpVkpa57suFr3U/I7FCi1BX3Qxz7krvVMyV+fj65yyb6ktHQx1Ai5nRM10XlNGiyfpDUJrMU4vfM2bWCxzb1Yvr68ivCdUXpT2bNfmeoCAABhLCVZ1bIs/6SU/u7XnG7+Wtf119WHDMeXkv8Ys+JLiRjFMU7/IEZpxPiHx/iKEh8AAADLmOoCAABhkPgAAIAwSHwAAEAYJD4AACAMEh8AABAGiQ8AAAiDxAcAAIRB4gMAAMIoelfX29vb+ng8OjWln6+vr/T9/X35MhSr8aWU0ufn53fOLpxWY8ztw5SIUZn3c5E+/IkYdUU+F4sSn8fjkT4+Ptq0aqD39/esz1mNL6WUlmXJ2lLcaoy5fZgSMSrzfi7Shz8Ro67I56Ls29mBGbZvC470OpeocQOIhzU+AAAgDCo+QPpZ8UiJqgcAeEXFBwAAhEHFZ5LXCsMrKg7jRF3fcjUGrdmL56w/qfL5wjVVV+m1pndfkfgMUPMF8/wZTtYxIh3ns/Ho/Th4S/YiWJalKIHN/Zz3sT5Li3Os980oU10AACAM+YqPxfJlSca71/7tzytWfnLj27ZZ8W4r2t2/53ivYotc5bKqtk+vrqnbf6Pv22l9fenZN1R8AABAGLIVn9J529mZe8s7ynVd//X7Zt6d1GbyZz83c0GxYvWpt5K7Yy+uYrMcu8p1T8nZsdi7pqqzMtuhtnA5h1zik/ulpJIY9PoSff6e2Sdr6ZMyR7/j9WdmJnIt/m5tv8xO8rYULkA9qE+x3nH3eqD+ZXp3mUDO52dfU/fceQDmaXbf7an5rmj1u84w1QUAAMKQqfiUZK8qGbvn/V9aVHrOWD1eKmPvjpbHXnmBqIe+quUpdtXxVatX3yiei1dTsjnHgsfZAQAAbpCo+LSonIzMdEfeTW3npUcsaOxd6bHM0130XRaOhcUFraVaLCxVP0atKhmz45z992e4M3PT83tnauLjYarIarv3kPT4GJNbrWOwsKByS719vViLO3fPntq4ZiYd3hOevUXj6ktVmOoCAABhSEx15WbxCpniLNuyfespr97H1drd56vSBXheK2feqmGWRTr+udc+xf1kIn1nWRqTVHwAAEAYUys+LSo9Mxc1W8pwz5xt7HVUvfC6c+xV5ebu5y2yVunxOjbV3+HXyl5Mr9co9XNtxqJz9bFQGk/P7/0piU9OAqE4sGcOrBHHg5f7/bS32/TRf8v9HZZYSvQjPMGlfPxbOkvqWuy+PPqp2Nw2WNkNvoTqOclUFwAACENicfOWaoaYy9q0wJWjOxPPpfYa1sftlpdYolQmPah975aSO9WeGorH4En9GkLFBwAAhDGs4rOtELRYqKaQ7fbedOvobyjE7lVpxU5l4X0rVseb+h2mIm/VaYs8jFuLMVDxAQAAYcit8ckx6+4k9/HmFhmwehYd/Q6RSo8+6+2HLblbW6hf2xX03iZkeOJjtdOP2j3y5XmzL+RW+65WhHitJzwR+gj2tHjsvufvUzXq2sNUFwAACMPEVJfCHWhJG2rfTBthF2B1uXdWHu7A
rFd7RohwjLzFo6D1dUH5OlN7LZw57qj4AACAMIZVfEqzQk93IdYXwlp6dUErEWL0qrTvlO+mYUuLsWS1mmzpmjl8qsvzCy5ftZzWUn6SLSLLxyXC9M2rll9IwJGW0z4R3j83C1NdAAAgjCmLmz3fOZU+9p67iFaBUluAI63vkr2Ne6oIOlq+pR35qPgAAIAwTDzOrqomM4+4uHsPb84er/edpKe1Q1bbDQ17a32uxlTJ+RlhfPa8XpH4VDi7wB91lvIC5j2UW2MoGXOKY6LXDrmwQ/namvt3uSHO0yp+proAAEAYVHwqnGWdNYvV1LN49fapUD5ONdNQd6ooo49FyaO/yv00gsX4a/dCe/6c9al1y22v0bu6TMUHAACE0b3i42nBYy3eweWHtf5rsQDfyvmas97HSiyRKa/ZGS1KnE+jzt2miY+1LwWgFbUL1N2pgavPKbPabuy705/bZFjxjQF8Z/42+maFqS4AABDGUvg46z8ppb/7Naebv9Z1/XX1IcPxpeQ/xqz4UiJGcYzTP4hRGjH+4TG+osQHAADAMqa6AABAGCQ+AAAgDBIfAAAQBokPAAAIg8QHAACEQeIDAADCIPEBAABhFL2y4u3tbX08Hp2a0s/X11f6/v6+3BvcanwppfT5+fmdsxmV1Rhz+zAlYlTm/VykD38iRl2Rz8WixOfxeKSPj482rRro/f0963NW40sppWVZsnbWtBpjbh+mRIzKvJ+L9OFPxKgr8rnIVBcAAAiDxAcAAIRB4gMAAMIg8QEAAGEULW5GH8vy74Xn67pOaAm2nv1CXwCAHyYSn73E4EnlS6n0S/IsJmiJlADVjEu142LhegFgHqa6AABAGLIVn9w7z9l346/tXJblX22hugMlkccj08qY7er88z4eFc5BKj4AACAMuYpP6d2oYnbc4o56diULPrQ4n6xUiF7beXTu7FVpzz6v5G5fXPWv0jFQbVet0lmMLevxX8U++hyUSXxKTmjrgyCl8xhmftHkfnl4Z+XL/kjLBb7Pz1s/Jk9H8exNUytoedzVvoDO2uBFbUJe8ztUqBcwmOoCAABhSFR8vD5+yqPtx7yVsZV4LJX34KGSta5rUZVm77xTil+pLa1d9U/ONLPl81ip7VR8AABAGBIVn+1dizW1WaxivIptGq22WnJ07Gbf5cz++7hvrzKz7deSPt5ea5Uqgy0fCLli5Zywej1WGldHJBKfPWoH6i6rg7g11eNQ2y7FeFqdOzmxWT5PrU0j3Gmf5b1jahYDnxm9gLv071hIHHKptpupLgAAEIZcxUc1QxyN4zBf1D6wWOlRrLzNZGV7kNJ+a/1ov8o4Vp0qL2HpgRUqPgAAIAy5io9XJY+OqmfL3tTMqXvrR+sVk5IHJKyt7cmlvmlcrTtjU3Uncq9buFhB4iOEAa+v5KKpVErf8vpKldebi6sXBiu1vVRtH1qO+cheH5fuodbzuHh72uyIpfYz1QUAAMKg4jOYpawY9vUu66tWtV4pTG+MYqE/Umo/hXX132aNASv9EQkVHwAAEIZcxUdx7QDa2VvkPavPe98BWnpMeI/qwtASnjaDu0JcQB4qPgAAIAy5is8TlR89Xp6IqWWl2tFz8zoL/X62dcRZ+y291kGpLb20jNHDBoFoRybxKdnnBhgpZ88NS+M2ysW+9JHm3M9FOX4esF8O9jDVBQAAwpCp+GC+vd1vIy0OLaF4DLi7vWapMuddyW7brf8u7Gm51IKKDwAACEOu4mNxzYQnHP/fchdDRj9O6lr2D5WCse6sqeK89GOvL++ut5NLfKDvzm691qfOvL2c1LM770jii9MeS0/l4drey4RbnZdMdQEAgDCo+GBXyykv7p7niXiXy2PnduRcZ64qzNGrsFEeamBxMwAAQAW5ig/VAS1H86qvd9V3+m3WHXrtDr/Q1Ks6yVjob3uMSzeUzP29lvG9+Fur/pRJfKx3bGn7o52Qyk+LeemLPVGmfVqUwS0mPT3bvLe4dJbSGy31fssRIcYjvb8nmOoCAABhyFR8XlnJZmsfmbV0J15arTmLSbny49mdLQhyfveWhTGdUpxFoTVUz89I01tbHmPaM2rcUfEBAABhyFZ8rCrNzHveibe2rdZYaXMkIxdrK6+HOaumelk30evOeHa/3onLSt9dOVuv5rVKObrCKJf4WO68lNruNaDKa1xe9Oyf2V+MR3JfsHv1O9SNTHhGuhpXXr/wr+S8Ouf5GYs3pLPGHVNdAAAgjKUkQ1yW5Z+U0t/9mtPNX+u6/rr6kOH4UvIfY1Z8KRGjOMbpH8QojRj/8BhfUeIDAABgGVNdAAAgDBIfAAAQBokPAAAIg8QHAACEQeIDAADCIPEBAABhkPgAAIAwil5Z8fb2tj4ej05N6efr6yt9f39f7o1tNb6UUvr8/PzO2YzKaoy5fZgSMSrzfi7Shz8Ro67I52JR4vN4PNLHx0ebVg30/v6e9Tmr8aWU0rIsWTtrWo0xtw9TIkZl3s9F+vAnYtQV+VxkqgsAAIRB4gMAAMIg8QEAAGGQ+AAAgDCKFjfPtiy/F2irv1H+2c49Z23f+zn1WBGTlXMRwDy134W9ySY+ZwdsWRa5C+5Ze2s+50luzGp9Wuo1TqvxlIxRxXPxzDY2S+1+5SWO3lS/eD2r+S4c3RdMdQEAgDDkKj6l2eLMrL1n9Ub5bmQvU29xLNTvYi1X63q23Uuly4K9fqSq8ZvlyrKHZQ6l43D7+dHf51R8AABAGDIVn6NscZsBnt3tjMyOr+4sStpioYpwtd7Kq7uxKax/mdE/CnF7cqcPPVQSrliseOX0qeUK1lNp20Z9n0skPrknZ8tplRoWT7C7PCc2Z0qT26jHae/GxOu5oOTOtTBCP3mO7dXsJQK1yfW6rtOum0x1AQCAMCQqPqVmZoqtqFePWk7nbX/f7KrdHWcxH03Jeru7zj0Gyqy0c0/Ltnsbm0/q8ZRc+yxfL3Mcxdd7upyKDwAACEO24lN6N9I7Q7zKuK/am5Oxq9+pPN1pZ+tKEupdPThw9nnLLFc6ctZTlFYHWIyujzWFbUkkPmfP+M8+KXOnpLbtLWXlolPbTotPJ+xNp85eRFgjt51W4sFPZw+B7OELc57SY8852Q9TXQAAIAyJis+VnIWxI7NjxUfte6pdyNzr949S2qce+94j1fF2pGV7PTwY4p218XlHzv59PVDxAQAAYchVfEp3Ce6VGd7ZlKn291sWYW3T2WPqs+5cAPw0e13okdyHCc6+eyx8b9z5DhzVb8MTn1YdN3pgK55ICkoXLlt+ouZqwTNsiPQEGzT1mEafNWZbxDK67Ux1AQCAMOSmus5wFzZHbVWj9FFbbyLFasVe1S5SP1GhRC+WziMqPgAAIIxhFR9Lm9j1vCvydsdVuvtvJKqLLKOL2Cecm35ZGc8Ka3ueqPgAAIAwJNb4KGesym3roeY9P6+8HrOau2bLT7HBPuWngdDmMXUL1xi174lhiY+1fQhe373VqpPUj0PO7ti5L2x9/TkAY5Dw2MKSgbGY6gIAAGEMn+qy+q6YngtVFUuVZ20pvTvxssj
36lFoi+Ma/lyNQw/nomelswKK/XnWdoX2UvEBAABhSCxutlIRaLlAS32tz5XaRdAW+vlMTiXMap8+WXzv2LbNyu3soWS8RTs2Flmu9JxRau+UxOfspY9nnx+Jhbz/Y/2LXIF64mepj0vOSdXj3YrFJBX7PCewau1lqgsAAIQxdarLQlWlZiHv2e/I+Vml7Lh1JUAxxlYsVU2eLLW5dNHusiwhKrKvIsTozV6V0tK5uZVbcW0dX8m4p+IDAADCkFjcvJepKWa7e3eUORRjuVJzd11yfKwsaM9l8RHiq+qk8rjNOZ57sURa/wNdyudWbwqxSyQ+e84uwCpfmla+JEbY64+rkq2HaS+PCY+qO220Pn3wqkUfWh0HVuVeK9Qe9Gml17lXc0yY6gIAAGHIVnysaTFdZyWbb/3eMossvgvJS7UD9RgDujxWelT3JKLiAwAAwpCv+Fieh859lM9DPF60uCNWPE5sdOfD1QMDJej7sWoe7bbYR7nv45wZm1ziQynWDyulW6/JTkpt+0DloYIaFm80WooWs7UHJ6y0M5f6wwRMdQEAgDCWkkxzWZZ/Ukp/92tON3+t6/rr6kOG40vJf4xZ8aVEjOIYp38QozRi/MNjfEWJDwAAgGVMdQEAgDBIfAAAQBgkPgAAIAwSHwAAEAaJDwAACIPEBwAAhEHiAwAAwiDxAQAAYZD4AACAMP4P6a4zeY5hWWgAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "# Generating new images\n", + "samples = np.zeros((100, height, width, n_channel), dtype='float32')\n", + "for i in range(height):\n", + " for j in range(width):\n", + " logits = gated_pixelcnn(samples)\n", + " next_sample = tf.random.categorical(logits[:, i, j, :], 1)\n", + " samples[:, i, j, 0] = (next_sample.numpy() / (q_levels - 1))[:, 0]\n", + "\n", + "fig = plt.figure(figsize=(10, 10))\n", + "for i in range(100):\n", + " ax = fig.add_subplot(10, 10, i + 1)\n", + " ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary)\n", + " plt.xticks(np.array([]))\n", + " plt.yticks(np.array([]))\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 153 }, + "colab_type": "code", + "id": "KBuWx-FtSouR", + "outputId": "3498bba6-068c-413a-b1a9-74e98192b543" + }, + "outputs": [ { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "-Ia9VXYySkuW", - "outputId": "ad1075b1-70cc-4012-851e-925cae32ac2c", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 581 - } - }, - "source": [ - "# Generating new images\n", - "samples = np.zeros((100, height, width, n_channel), dtype='float32')\n", - "for i in range(height):\n", - " for j in range(width):\n", - " logits = gated_pixelcnn(samples)\n", - " next_sample = tf.random.categorical(logits[:, i, j, :], 1)\n", - " samples[:, i, j, 0] = (next_sample.numpy() / (q_levels - 1))[:, 0]\n", - "\n", - "fig = plt.figure(figsize=(10, 10))\n", - "for i in range(100):\n", - " ax = fig.add_subplot(10, 10, i + 1)\n", - " ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary)\n", - " plt.xticks(np.array([]))\n", - " plt.yticks(np.array([]))\n", - "plt.show()" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAj4AAAI0CAYAAAAdqSPKAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3dW3LkOLKtYfDYGYLyuWIOrfmPQJqD6rk0B+6HzOimongBQFyWu/+fWZvtnRWS4ATIcDpAcFnXNQEAAETw/2Y3AAAAYBQSHwAAEAaJDwAACIPEBwAAhEHiAwAAwiDxAQAAYZD4AACAMEh8AABAGCQ+AAAgjP9f8uG3t7f18Xh0ako/X19f6fv7e7n6nNX4Ukrp8/Pze13XX1efsxpjbh+mRIzKvJ+L9OFPxKgr8rlYlPg8Ho/08fHRplUDvb+/Z33OanwppbQsy985n7MaY24fpkSMyryfi/ThT8SoK/K5yFQXAAAIg8QHAACEQeIDAADCIPEBAABhFC1ubmVZshbLn1rXtUFLgGN745RxB2AWrkltdE98WiQ5Ob93VufnxOdxYPboV6XjdBTf89+V2opjJeOUPoUqEp62mOoCAABhTJnq8iL3bnJZFhfZea/q3fb3ezhOmKv3OAVGYSz3QcUHAACE0bXi4zVbzZ1vtRb/7PYqVHtmHwMrtsdJod+OKLcNOMO1qJ8uiU+kDsu9sCovih3RX4px41zJVG5K8/s40nUHx47Gwezx2YKHGJ5Kz9eWsTPVBQAAwmBx8yDKmfpZ266ycuW4SuTefXiIN0plxENf4VjpOFapTF7xen7ejavl9DoVHwAAEEbTio/XTLWUp+OgfnfUQoRqT88xqXBcPJ1zOEY/23RUrantz7vVO/NTXbMvupyIdkVIeO54jZuxrqN1XyiPccadfUfjq2TctRwHTHUBAIAwTFZ8Zt+dPP9+SQaaU5rjfSw6vB13b/FEFan6ESnWK5y/bVHxAQAAYZio+Khmu3cWaVnbHA5xnY1VlXG5ruu/2hmhgloTj3Il5U7brq7H3vreu57XHdnEx9ogzX1lRc6JbS12+KL8xXhX9BfhekwIchbhW48Rv7XqR6a6AABAGHIVH0+Zee5UmKeYYZvlR/zv7EDunXoVpMV7m6L3sSe9xysVHwAAEEbTik/NY94RcDygzvsY3au+KlU8cr1eYy1X6FJqM+6opvsxqjJJxQcAAIQht8YnGu5I/LF0F+690nPGU+Xn6L9HZTF+9XVYPY2OvUvi0+IlZJbxyLoNLfrA8viuaTvjVku0/sj5bol2TCyb1YdMdQEAgDC6T3VFe8PzVXxe70aiVA88VPNaPDps1TZ2T3FZlfNADP2E1qj4AACAMFjc3IjnSk+vKp21xaXWFx9aWnTdivU+i6K0Tzz2a7TtYGb2IYnPIK+dbOUkjXISnvFwkY2Q9DBd4l+E61HEsTo6Zqa6AABAGMMqPhEy9RLWpnl6Uj0WVqt0NTzG5jEm/Bv9bNPMfqPiAwAAwuhe8YlS6dlbmJaz2Zbnx2stvkHZw3qeXJ7i8nweRcdGhWita+Kj/iXXw9HJmLNiX3XKJ9fRnk0W4vGe8HiKBTGwWB29MNUFAADCkHmcPUoGb22vhlbvs3r+nrPpv+3nRmKaxIfcvou0aD2HwviPsN0CdFDxAQAAYUhUfCJm8dYqP3eo9q/CnS7aOVtT5n0Nl3cR+4oqWD9TE5/IHeY54bG0qBn+bMcfCc+xWdegyC/JfWp57LmBK8dUFwAACGMpyRCXZfknpfR3v+Z089e6rr+uPmQ4vpT8x5gVX0rEKI5x+gcxSiPGPzzGV5T4AAAAWMZUFwAACIPEBwAAhEHiAwAAwiDxAQAAYZD4AACAMEh8AABAGCQ+AAAgjKJXVry9va2Px6NTU/r5+vpK39/fl3uEW40vpZQ+Pz+/czajshpjbh+mRIzKvJ+L9OFPxKgr8rlYlPg8Ho/08fHRplUDvb+/Z33OanwppbQsS9bOmlZjzO3DlIhRmfdzkT78iRh1RT4XmeoCAABhkPgAAIAwSHxQbFmypr4BAJBD4gMAAMIoWtyM2Kj0AABGOvveWde16nfKJj7bYGuDAwD41uOLcaacG0yvcW09Y9yL9fm7lmWpOhZMdQEAgDBkKz5R7GXBatk8U1yxHPW32rhEPKXXIiszB16vsb3ievblsiz//Rsl/UvFBwAAhCFX8fGa+b46i1O9CqTUFtRTH2dAStffCWdjVvX75G67ate2jNDimOfGtq5r1d+TS3xeqXZuqdoFeHcXcd1lpVSsSDWxsJx0w7fW4+/ulMgoV21STeBSatu20r6p/dtMdQEAgDBkKj6qGe2disedxxJnH4/Zf98yS8due0d89N9gn3r/jmyfcuXnyN55qhDHrGvd3ZkIKj4AACAMiYqPpTvkK7mxlC7Im5nVW7ozUqVwDL1t9Fai9hpj/bjkrOdSjbHldbB2EWwvqse81Ojj2upvTU18rF2IW3bw3kVHLeFBPot73yi37Y7W56nF41RyDBRiPNud9+jfStt8Nq2rTvm7oeVxzf1efP3bpZjqAgAAYUyp+NzN4BTuUGrl3tnMiu/qLislnTsNdUrHyeJdbqkIMeaoOQ6K015Xd/211yWlGD2pOa5HYzV3DPOuLgAAgAvDKz531kLMuKNrMX9ZOxetSqlCNZvFKoPHvrLYD62NvEYp8rjZqvdxPWvd1bDEp3R6a1t6Vej8knaUnHQKsW2VDkQvF5iWlI6J2vhS8No/V8dIcRroKUKyc5TQeN+B3NrDP3ecnZNnfc7iZgAAgAvdKz61lZ6cnx+td5atlMUrtWWEu+P06HPQcNY3lh9zrmFlnNa+nPRqwax6/JEqPUdyr7m1qPgAAIAwulZ8ShYyl2RzHrJejwvxvOl91zGLlTvfFrzG6Hkn6rvV1KvzVnX8U+nJd/d4UPEBAABhDH+cveYu2lO266FikMtyVat0nKrGF239inel/ag6LkdTeTr4iMVX3lgm8ZLSM9473mN8yheYGhYTnlfqF35ci7IcYC9ZV52eusvDtcUiproAAEAYcu/q8sp7zNbje73L9Dj92qLNHh/jV+/rqJuJep6mpdJzz90KIBUfAAAQhsy7ul55yno93iU/ncXm8U7NS7/VsL4xnFWez6eIWMg8n9ziZqudHymh83ji5vSf5fjuyB3b1hIgL4mEleM9k8ITpkxv3dPyGDHVBQAAwuha8cl9k67VbDfK46VPHvswJZ8VrBYsVUSOHtffq0JZimsr2njM7c/cn53F63XTMio+AAAgjGFrfDxltp4XK+8pjff186rHRumuUInH3YHp6zgsVHCV2hKR3OJmayIO4NzXOUQ8Nt4p92mPp5+U4/Uqp/+sJOck3JqY6gIAAGFQ8amwvbO09ghvqas7Fio9vnhYiNliIbOVWD1qWblT7EfFNkVDxQcAAIRBxeeG7SO0VuacW7CweBBlvK5FeB2T6u/lmmFZFhdxe4jhLu8zEK3IJT7WdkAuvbBaw0Jm36LsWO3tvGxhe0wsf2FabH
NLd/Y7ioqpLgAAEMZSkhEuy/JPSunvfs3p5q91XX9dfchwfCn5jzErvpSIURzj9A9ilEaMf3iMryjxAQAAsIypLgAAEAaJDwAACIPEBwAAhEHiAwAAwiDxAQAAYZD4AACAMEh8AABAGEWvrHh7e1sfj0enpvTz9fWVvr+/L/estxpfSil9fn5+52xGZTXG3D5MiRiVeT8X6cOfiFFX5HOxKPF5PB7p4+OjTasGen9/z/qc1fhSSmlZlqydNa3GmNuHKRGjMu/nIn34EzHqinwuMtUFAMBAy7Lw4tyJSHwAAEAYRVNdAGJ7vUvlXX9Ame059Py/a84jzsV6VHwAAEAYJD4AYNRzrQjrRWLZ62/GQD6munBL7slGGVbHXp/l9A8XVh1nX3yca355Owdz4ukxnqn4AACAMKj4INuduw3uRuc767/tf8vpI/pxjpxzsLaih/Fy+8VbpSel/Jh6jGcqPgAAIAyTFZ+jTNHCXc1Zlqvafo93G1HU9J3F/o6w1qxFxTUl28fg6epYKMYY5Vy8Uvsd2GobgJSMJD6zFkC1ZLVE3erEmx1HNL0umGr9WBrnneMyK3aPX36lWvSz2thNqb5v13U1t4/P3T7Zi7kWU10AACAM+YqPxZLmq9JMV+EOr7QNFvqhN2vTmNs2WauqKpwjyp595eE4tYrBwwMW27Zb6uPWMwfP37csS1V/UvEBAABhyFd8jljI2mvnNLdzmep3Kart6qX2zmVkP5ZUSa1VenDNalWgRITHwEvPu9rqxyyt2lpzbZVNfKxNG2xZPtlyqB//1nKTA/V+f/0StNyPOV/oLV78qM5yH+a6E6NS4nfVFo992WuR+d1+ZaoLAACEIVvx2WM5I7bc9icPMeSqWVT/+m9H71NSeSy6ZudUtTHQ8o5eoSqw506MqjH1phy32jnUw4ztBEoq2VR8AABAGHIVH+VMPYeVjbOubO8yLbb/rqM1Oy023PKwxkbNnWPpYcuMI0prXPb0OBciraFRNOs70MXi5i2rA7b0yYO9i9Ts2Gf//VksJ7B3FlqXPvFlUUlcVvrcopavIIjI2s7NLd29NjHVBQAAwpCp+Hi9uzxyNgWSUqzsXdFe1aS0f2aO6bMpjjvTmNbvzC1fZ7w8nq/YJvShep2g4gMAAMKQqfgcUc0Yj1wtaM39HZjvrGpi5S3fZ3+rph3Wx2bu+ifrcSrrteDa8qa3paxUzVpXiHk7OwAAQCH5io81e08q5PJ2V+JFyyed6OP5zjaa9N4/6vFF3T4jl+WnTUvlXmtr4pdIfKyU7V61arfXgbvHejnaQht78Rq7x7gsXFP3ph1L+sLTlKWnm+Tc6eQWY7T2ODDVBQAAwpCo+FhTk6kqZ+gjWLgDxW/0FWY5mnaMPiatf3+07r+7x4OKDwAACEO24qOc4V49jqnc9hk8vwvJm0iLJ72y1IcetozoyWIcatWdPc0Sn4hJgOfYWiDhAZCjxd4+lq8nltt+x1G/9z4eTHUBAIAwbld87jyeHH3BmleeHjONwtL0CPZ5uJ5ePQrNmLSr9S7yd1DxAQAAYdyu+Nx50zN8sVzpsdz2O6j02He2PkK9CtRzd17oUOu/ZoubWwWmdoCQR/0Ce+XsNQZ7/+ZhnFrvM9jqQ0ttHY3CwVhMdQEAgDBk9/GBttq7t+fPqd/dXE0XjH78sgXuuP0onaK0dN6dnW/q7b/D4jVlT+67umai4gMAAMKg4gNcyNlcTX2OvuQOLMLdtVWWF6PntjNnN2crMd+hfk0poRYLiQ+6OZsuUjsRclgo4ZbYO/7b+CJ9yaRkM6m4at9rEmHpvMt9nYWVeM4cXSutnoN7facUC1NdAAAgjKUk+1qW5Z+U0t/9mtPNX+u6/rr6kOH4UvIfY1Z8KRGjOMbpH8QojRj/8BhfUeIDAABgGVNdAAAgDBIfAAAQBokPAAAIg8QHAACEQeIDAADCIPEBAABhkPgAAIAwSHwAAEAYRe/qent7Wx+PR6em9PP19ZW+v78vX7JkNb6UUvr8/PzO2YXTaoy5fZgSMSrzfi7Shz8Ro67I52JR4vN4PNLHx0ebVg30/v6e9Tmr8aWU0rIsWVuKW40xtw9TIkZl3s9F+vCn0hhVXj5KP/6Px/h4OzsQ2Oubr3mFDWZQenM3/GONDwAACIOKDxDMa5UHmIWxiBmo+AAAgDDkKj5ndwDM/6I3r2sNuLMGMNqdxeo9r8USiU/uRdnal1Ltl42V+LaWZTltt7WE9ioeD9Z1DZ0Q5cSuNgasnUeIaW+c7v3b0ZjtfV1iqgsAAIQxpeJzN5tTuhvvkZkqxXdlG3/tsbBWybPk7C4rQrXH6rUmQt9s1U6DjPhbPUSaDTgzq4JJxQcAAIQxpeKTc8e5l+2p3QWVtOc1HrVYSllv/6voG/m1iFetcmdxjN5p853jP6PvSmJt2ZezxmmLGErWyeDY1MXNpR22TZjULrJ7ztq2/W8WL9Atp0uU+9CTXsdZ5Vxs8aUw+1yccV3Ym66e3Zc5ats4uo97/z1LfXZlVAxMdQEAgDAkHmf3piRrnX2HOYvS3YnHPhg1dady7DxOAZRO97eIV/GYterbGWN15N+09FDM06z2UvEBAABhUPGpcOeuS+UOeSRLdyGW2lqr9SJLhWO2txljxHPNk15rthTG65OVdUqlah9gGsVU4qPc2bW7Fpf+LjWW94Sx2OYcXuO64u3JyVxWvzxLdu0l4fltdp+VOvt+mDk1x1QXAAAIw1TFZ0ttN9XtI4WeKzwlLMXVsq2zHy/tOe1jab+j3EfDLewZFgXVnn3Wx+PR9+KsayUVHwAAEIZ8xcdapmutva1sM3qlu6tRlPr9zo7irX6vmtw4rcXo5Vzzuqvx3b/f61yeQWk9qGzio7gavLTjckrt6oO1hKdYStVMcfZyNe68TR3Avu219fU6W3NeeRifKtcTj5jqAgAAYchWfPZYyeJzHqu1EgvOqfdt67YoxRZRy5fJtv69LWzbcadNlqfdIz4ckzMmWz7+TsUHAACEYaLiYy2jVVyfhHOW7xB7sfToOrB1th5TYRyzfufa2ePvWzX9ScUHAACEIVvxUcjKW/IWD/yi0uNHTmXBc//uPeE4u7rbutozO57ZauKXTXyUKe1HgHp7j9DmUi2h17IUj9oLUmfjOpRn7zF5xk9MTHUBAIAwqPjcoLRpHepFvuuzPq0VtfrT4roT6XhtzY77zvdG7vvnImFxMwAAwImmFZ+IGy+x3icWT68esVotiVppbfHqhojHTVHO98bVOWl1zVKL78y7cTZLfGqCsHrh3UMC5J+XvvV03kXVYjE+5ot8/tVM27FzMwAAQKFmFZ+j8nNuqdXzbsfW2x+Rxztk6wuZr3jfz8RrXMBoVHwAAEAYTRc359yR1Mzreb9T9eCqL9X77G6FR31hIet6YlKuXHp6UGAG5b7tpVVVd/g+Pi06i4u4PcpTmVH2RLHQRvhUco5xo3vN0k7rLbG4GQAAoJDszs1Wd6iMuseIN7njz8pd1nZcWmkz6qmM2c/PT/fTyCPx3dIGFR8AABDGsIrPncWv1nYfj
X53cha/pb5TuWtuxWKbW1Efd7W87ZZPxfynnGOh3qettIyzS8VnWZZ//e9MlI7zbF3X//7Psqsx6yFG70qvP2e89PWscfuf//znx7Xh6n+R1Xxvejlmo2NhqgsAAISxlGRZy7L8k1L6u19zuvlrXddfVx8yHF9K/mPMii8lYhTHOP2DGKUR4x8e4ytKfAAAACxjqgsAAIRB4gMAAMIg8QEAAGGQ+AAAgDBIfAAAQBgkPgAAIAwSHwAAEEbRu7re3t7Wx+PRqSn9fH19pe/v78t9663Gl1JKn5+f3zmbUVmNMbcPUyJGZd7PRfrwJ2LUFflcLEp8Ho9H+vj4aNOqgd7f37M+ZzW+lFJaliVrZ02rMeb2YUrEqMz7uUgf/kSMuiKfi0x1AQCAMEh8AABAGCQ+AAAgDBIfAAAQRtHi5tmW5d8LtNXeLr/XxpT02gkA0Mb3SR9yic9RR79S6/irdj//u1q7AfzP2XnMuYvWcr/vzn6OcVmOqS4AABDGlIpPSZarns2WZuzeKj8R7jwsTLHiPqq2GKW20nP1uxibeaj4AACAMIZXfDwt1iqJ5fWzFrP0lncpqnJjtLwWpGUFa1kW+XhbiRIn0Eqv74y75+KwxMfiF32No9jOYlb/Es35orScFLVuu+JYz53GiWxd11DHgak7f2b06Yxz5u41lqkuAAAQhtzj7BEpV09yM2uLd401x/koTqU+2+rdLov9fuYZj2p/5rDc9ihqx1nOMoqeWvytFteMbTtqqlxUfAAAQBhUfBrz8uiz4jqVVlrcZe19Zu/3zlpHwV3/PVbG/J1+VozxarG8t+vSNobaa4VilXJk39T8LRKfG/YG7dF/v/O7R3qNw8PFZevq4uAt3j0RYvSqR9KuYG/qouRnUrIT65GW7Z9xwzXib7VK7pjqAgAAYQyr+FjPxq+0XrCFsbyNz73yt7cYo2m9B5rS4+wtFu0qxRNRy+Pf+52dVHwAAEAYrPGZjCrPfBGqddwF+xKhP7dVS+UtPyLpvZB61BpMEp/B1E/Y14Ed6ZUEV9T7bo/XvsvpCy+xe52uzL22eIrZixlPkrUcB0x1AQCAMKj4dHQnG569UG+b0c9uyywqu5Tit9rFriWU+6tX26jqznNVzfP2yP6R0e+rpOIDAADCoOLTSMsdVFXXkvR6XFHxLqZ0Q0rVPvOkZpy0qBKpbCaK8ZWBu3IX624r63d+3+vvVTdrjFPxAQAAYZiq+KhVCbytIejp7FiNvMvee6dW643hWv087pv9NmsLZo7T3CqytT6ree1GiyelLF1zzr7P7759/YpM4uN9l86cdqqf3D2TkZFyLjC5bVLvM/wbU5a/zb52ej7HSpOY7bG487MKjhbLl9zgHt2gtoqVqS4AABDGlIqPxQy+RGlWqrSgsgfV+O60wdoiS1zzOE4jUDo+LTbWVFvSUaPXBqOtZnqo+AAAgDCGVXxqqzx7c30qrGbjOVSP+WxUevxRHevRx5Nqv5zpVUVWM2IBds/XYnRPfCK9U6cVtePR4ukmSyf1EY9Jz17pOMo5qzq1Bajy8sQZU10AACCMKYubWzzarZQ91lCtgPRul/V+e2UpntK9lGp/V0rax6XXvk0AflM/l6j4AACAMExuYKieTe6Jsm7iiOXYVKtzCpTWb6m0A2NZ28zWsr0FxxaPu0Ti4/WCZXXnTfzmdVzWUhynLfpIMa7IOO/0jThneo4DproAAEAY3Ss+OS8fs+zO/kRWRCwl5y7etfRIdI92zTqHvTxWC2D87AgVHwAAEMbwNT4tMraWb2kdzWq7o8i988h5NNxzX8/czNHyo/S4djZLkPN5q/aqyJ6uJXerwy2PwdTFzbmB7B0wTwNCXfRjvBd/zklsOUHfM3tq2uPO2aNYfwrnzOu48BafZaoPHzDVBQAAwlhKsqllWf5JKf3drznd/LWu66+rDxmOLyX/MWbFlxIximOc/kGM0ojxD4/xFSU+AAAAljHVBQAAwiDxAQAAYZD4AACAMEh8AABAGCQ+AAAgDBIfAAAQBokPAAAIo+iVFW9vb+vj8ejUlH6+vr7S9/f35d7ZVuNLKaXPz8/vnM2orMaY24cpEaMy7+ciffgTMeqKfC4WJT6PxyN9fHy0adVA7+/vWZ+zGl9KKS3LkrWzptUYc/swJWJU5v1cpA9/IkZdkc/FqS8pBTCG55dUAkAJ1vgAAIAwqPgAQVDpwbbyt8XYQCRUfAAAQBjDKz6sNfDh2Y/0ITDOUcUGti3LwrV0oKlTXXx52sTFF/Bl9rU495pS276936/wvVN73FXjsYKpLgAAEMawig8Zqn3RKj2v8Vocr63v5D0ck8i2/aVwPp+14dnW7Wdatnnm9JLCsY+Mig8AAAhjyhof7hJtiVit24uZBYj/5uGYeH7g4iyes4rKiOOQe13Z+zcvFZMWa3tQrnviE6U0nlOyrfnZKzOOZ69pkx5/o1RpX8xeFIoyV/3roT9btH1EQnvn99/52dl9TPLyb6O/C5jqAgAAYbBz8w25mbvlDH9025XuwnovrvRAvTJS218Wp7+YNtF2NH3+dNZ/itXxV70qaT3ORSo+AAAgjGEVH5WstIUId0q9FzQrLZi+asvZI8CexrU33s7TFvF4OyYl1GNXb9+eqypWr79397rbNfGx2JEtePsy7J3wzKKUfFlwp1Q/ktIYG6HHcVfpy1YUx+ldM+NofY7VTvPVYqoLAACE0aXi4zG73lLb/bSnkv47OhbrupqprnjvT89q+079fG6xsLX3u7AsmB2b+jgbZXY/pETFBwAABNK04hM5i42qxcaNCncAuSy1tQfV+FtUNKxUJVNqW+lRjbEF1WvM3TUts+PZbvXRqy0984mixOfz8/M00Nd/X5bF/VMwe3u9eHAnLit9fLUVvpU49vQYj5aPh8W23+lDC1+eI3iJUTWO0UkP+/gAAAAUKkp8/vOf/xRlXOu6/vd/T94qI1Ft+/b1f6W/Q0lNHNCV05dHlWnFMdCiXaqxRWdpHLb2jH3vGDy1PA5UfAAAQBjDd27eZnQeMlnvFaxefeSh72Gb8rnba30P5x1UlI7xlgupqfgAAIAwhr+dfbuZnafKD44p31l7xTl1zGpFJHd7iJqfxVylWylcXVMV+7vlu+ak39WVq+deANBCP2Okkout0tjc2+187zpp8Qswujt9Zm27jV43vXcTIKa6AABAGBIVH6ss7fQ6A1NcsMDSORv1nVt3ryWzjkdNu61uPnmnj3Liafl9QsUHAACEQcUHgHl7a2Ksq31tjGI1oEaLCsL2QRql4+J5rVbuuVga095GyLX9SuJTwdsFthev7zHzzvKC4JT8jTelY9xbq6k8lTEQ6Um8bXtHtr1moTNTXQAAIIxhFR9Pme0eDzGMwN5NOq7eibP33qC936HWl3sl8b2Y1NodWcupHgsPnXj/PuylVSWPig8AAAija8XH04ItwLKatRM556f6OfzaPvX24n9K+8pCpSclO+1U0/Kt7V0SH+9lPA8xjJS7C60F1hb+cpGFBTVjUmUBcy5r7VXR47gx1QUAAMLoUvE5W1yI
a1arIaWsjI3aO44o/QiMlntOKixm9z4DUqtFJYd3dQEAAFzo/jh75Iy2FHPAWugPAD1E/l7s/U6vHOzcDDSiuHOwYpugRWE6qKW9pRZKPBzjGVoeN6a6AABAGEvhjpj/pJT+7tecbv5a1/XX1YcMx5eS/xiz4kuJGMUxTv8gRmnE+IfH+IoSHwAAAMuY6gIAAGGQ+AAAgDBIfAAAQBgkPgAAIAwSHwAAEAaJDwAACIPEBwAAhFH0yoq3t7f18Xh0ako/X19f6fv7+3LvcqvxpZTS5+fnd85mVFZjzO3DlIhRmfdzkT78iRh1RT4XixKfx+ORPj4+2rRqoPf396zPWY0vpZSWZcnaWdNqjLl9mBIxKvN+LtKHPxGjrsjnIi8pFeftBYIAAMzEGh8AABAGFR9R20pPSlR7AKBUhIr5M0av8fVAxQcAAIQhX/F5rXycIeMFEHkOW9MAABS+SURBVJ3HKkfJ98DV7+CYQC7xKe1ML4MYAO7wlhz0+mJXirGlZVncxdQLU10AACAMmYrPWXYfLYu1VKq+c1emHhti2BvDXsZmbnVDfdqktD/U47nLe3y9UfEBAABhTK34UOX5SbXSw1z7NdW+K+W1+lE6hi31593zU6160KI9zz7zNp7V+sqqKYmPt8F4l+KePa1PsLMLkaUEKOe4ePvStBTPK74oyljrX6sUr/mRMNUFAADCGF7x8V7tKX2kUOluuubueHabR8qdmrVSZcg9F7efs3CnauX4j3DWP0rXnq2z6rAVltseARUfAAAQxrCKz1kGrHrnUcvSmpUcXuIoVbv4fnvHqjgWSs+3s2qWtU3TctvKHTtq9H4QBG3I7OPzpPhFkaM2eVvX9b8/Ozv2bVt6/o2Ujhc5q/e7evvOWJimQn8Wv0RLr41KMeZMNx5d+zy9smlUn+QcB6a6AABAGN0rPjULKLf/v3oW21KkWNXdWYSvdLf55G06Gb9djbWcsag6Hvaqw2fnpcJ5d+dYenpP5cy+yJk5oOIDAADCkFnjo5S1j6acuacUa21Ir40bZ2lZ6fG+FUVKMca6t5gsf2d4/N6zEEvXxCfihdKy3Fg8Tpu0eH2K0pdmy3Fp7TyuXUBv6Vz2NC3y6m4/WNxTa8SDJb3ltr/XWLyaEt1iqgsAAIQhM9XlhaU7qxY8VH+O7g6s7sB9V5SXB3uO02r7W+8eb+k45O6ybdGIfiiZNqTiAwAAwuhS8SldE2BtDcGT9Sz81dUxL310Vn2uvfW6HvUx23LthAVH6yYsV3pK1+Gpx5PLSxy1lONXOZ9er8Xv7++Hn6XiAwAAwmha8Sm9o7Ra6YmqtIJj+ckTy22/w0scTyXz/t5it47+sEu977ovbs49AOoHao+HRxBr3X2vzNXv6qlln6mOW9V2zcLx0Bb1Oor7arYRYaoLAACEwePsjVl4w7gC1WOk2i4gCs5B5KqtFFLxAQAAYUyp+Cht7Y/+LPSvhTYCXnH++TLjwaWS3z888fG2iM3jS+Za4WLmbz8V2MB4g5qW+/3c/b5lqgsAAIQxrOJDRQQqRlXpGPOYgWoPRmpxPb3zszXjnYoPAAAIo3vFR+U9Hpawo/V4rMWBdYxdzNRyU9s7fzNH08SHhb51cl//wIWtrb3x2uKlo4ovLmX8ABgtNyd4vTYdfb7VNYypLgAAEEaXqa6rLM/TnWdtdSv35zwdK1VHL18tmf5SnZ5kCg+wxePykKP3WuZWevY+ewcVHwAAEEbXxc1Ws9MrpVkpa57suFr3U/I7FCi1BX3Qxz7krvVMyV+fj65yyb6ktHQx1Ai5nRM10XlNGiyfpDUJrMU4vfM2bWCxzb1Yvr68ivCdUXpT2bNfmeoCAABhLCVZ1bIs/6SU/u7XnG7+Wtf119WHDMeXkv8Ys+JLiRjFMU7/IEZpxPiHx/iKEh8AAADLmOoCAABhkPgAAIAwSHwAAEAYJD4AACAMEh8AABAGiQ8AAAiDxAcAAIRB4gMAAMIoelfX29vb+ng8OjWln6+vr/T9/X35MhSr8aWU0ufn53fOLpxWY8ztw5SIUZn3c5E+/IkYdUU+F4sSn8fjkT4+Ptq0aqD39/esz1mNL6WUlmXJ2lLcaoy5fZgSMSrzfi7Shz8Ro67I56Ls29mBGbZvC470OpeocQOIhzU+AAAgDCo+QPpZ8UiJqgcAeEXFBwAAhEHFZ5LXCsMrKg7jRF3fcjUGrdmL56w/qfL5wjVVV+m1pndfkfgMUPMF8/wZTtYxIh3ns/Ho/Th4S/YiWJalKIHN/Zz3sT5Li3Os980oU10AACAM+YqPxfJlSca71/7tzytWfnLj27ZZ8W4r2t2/53ivYotc5bKqtk+vrqnbf6Pv22l9fenZN1R8AABAGLIVn9J529mZe8s7ynVd//X7Zt6d1GbyZz83c0GxYvWpt5K7Yy+uYrMcu8p1T8nZsdi7pqqzMtuhtnA5h1zik/ulpJIY9PoSff6e2Sdr6ZMyR7/j9WdmJnIt/m5tv8xO8rYULkA9qE+x3nH3eqD+ZXp3mUDO52dfU/fceQDmaXbf7an5rmj1u84w1QUAAMKQqfiUZK8qGbvn/V9aVHrOWD1eKmPvjpbHXnmBqIe+quUpdtXxVatX3yiei1dTsjnHgsfZAQAAbpCo+LSonIzMdEfeTW3npUcsaOxd6bHM0130XRaOhcUFraVaLCxVP0atKhmz45z992e4M3PT83tnauLjYarIarv3kPT4GJNbrWOwsKByS719vViLO3fPntq4ZiYd3hOevUXj6ktVmOoCAABhSEx15WbxCpniLNuyfespr97H1drd56vSBXheK2feqmGWRTr+udc+xf1kIn1nWRqTVHwAAEAYUys+LSo9Mxc1W8pwz5xt7HVUvfC6c+xV5ebu5y2yVunxOjbV3+HXyl5Mr9co9XNtxqJz9bFQGk/P7/0piU9OAqE4sGcOrBHHg5f7/bS32/TRf8v9HZZYSvQjPMGlfPxbOkvqWuy+PPqp2Nw2WNkNvoTqOclUFwAACENicfOWaoaYy9q0wJWjOxPPpfYa1sftlpdYolQmPah975aSO9WeGorH4En9GkLFBwAAhDGs4rOtELRYqKaQ7fbedOvobyjE7lVpxU5l4X0rVseb+h2mIm/VaYs8jFuLMVDxAQAAYcit8ckx6+4k9/HmFhmwehYd/Q6RSo8+6+2HLblbW6hf2xX03iZkeOJjtdOP2j3y5XmzL+RW+65WhHitJzwR+gj2tHjsvufvUzXq2sNUFwAACMPEVJfCHWhJG2rfTBthF2B1uXdWHu7ArFd7RohwjLzFo6D1dUH5OlN7LZw57qj4AACAMIZVfEqzQk93IdYXwlp6dUErEWL0qrTvlO+mYUuLsWS1mmzpmjl8qsvzCy5ftZzWUn6SLSLLxyXC9M2rll9IwJGW0z4R3j83C1NdAAAgjCmLmz3fOZU+9p67iFaBUluAI63vkr2Ne6oIOlq+pR35qPgAAIAwTDzOrqomM4+4uHsPb84er/edpKe1Q1bbDQ17a32uxlTJ+RlhfPa8XpH4VDi7wB91lvIC5j2UW2MoGXO
KY6LXDrmwQ/namvt3uSHO0yp+proAAEAYVHwqnGWdNYvV1LN49fapUD5ONdNQd6ooo49FyaO/yv00gsX4a/dCe/6c9al1y22v0bu6TMUHAACE0b3i42nBYy3eweWHtf5rsQDfyvmas97HSiyRKa/ZGS1KnE+jzt2miY+1LwWgFbUL1N2pgavPKbPabuy705/bZFjxjQF8Z/42+maFqS4AABDGUvg46z8ppb/7Naebv9Z1/XX1IcPxpeQ/xqz4UiJGcYzTP4hRGjH+4TG+osQHAADAMqa6AABAGCQ+AAAgDBIfAAAQBokPAAAIg8QHAACEQeIDAADCIPEBAABhFL2y4u3tbX08Hp2a0s/X11f6/v6+3BvcanwppfT5+fmdsxmV1Rhz+zAlYlTm/VykD38iRl2Rz8WixOfxeKSPj482rRro/f0963NW40sppWVZsnbWtBpjbh+mRIzKvJ+L9OFPxKgr8rnIVBcAAAiDxAcAAIRB4gMAAMIg8QEAAGEULW5GH8vy74Xn67pOaAm2nv1CXwCAHyYSn73E4EnlS6n0S/IsJmiJlADVjEu142LhegFgHqa6AABAGLIVn9w7z9l346/tXJblX22hugMlkccj08qY7er88z4eFc5BKj4AACAMuYpP6d2oYnbc4o56diULPrQ4n6xUiF7beXTu7FVpzz6v5G5fXPWv0jFQbVet0lmMLevxX8U++hyUSXxKTmjrgyCl8xhmftHkfnl4Z+XL/kjLBb7Pz1s/Jk9H8exNUytoedzVvoDO2uBFbUJe8ztUqBcwmOoCAABhSFR8vD5+yqPtx7yVsZV4LJX34KGSta5rUZVm77xTil+pLa1d9U/ONLPl81ip7VR8AABAGBIVn+1dizW1WaxivIptGq22WnJ07Gbf5cz++7hvrzKz7deSPt5ea5Uqgy0fCLli5Zywej1WGldHJBKfPWoH6i6rg7g11eNQ2y7FeFqdOzmxWT5PrU0j3Gmf5b1jahYDnxm9gLv071hIHHKptpupLgAAEIZcxUc1QxyN4zBf1D6wWOlRrLzNZGV7kNJ+a/1ov8o4Vp0qL2HpgRUqPgAAIAy5io9XJY+OqmfL3tTMqXvrR+sVk5IHJKyt7cmlvmlcrTtjU3Uncq9buFhB4iOEAa+v5KKpVErf8vpKldebi6sXBiu1vVRtH1qO+cheH5fuodbzuHh72uyIpfYz1QUAAMKg4jOYpawY9vUu66tWtV4pTG+MYqE/Umo/hXX132aNASv9EQkVHwAAEIZcxUdx7QDa2VvkPavPe98BWnpMeI/qwtASnjaDu0JcQB4qPgAAIAy5is8TlR89Xp6IqWWl2tFz8zoL/X62dcRZ+y291kGpLb20jNHDBoFoRybxKdnnBhgpZ88NS+M2ysW+9JHm3M9FOX4esF8O9jDVBQAAwpCp+GC+vd1vIy0OLaF4DLi7vWapMuddyW7brf8u7Gm51IKKDwAACEOu4mNxzYQnHP/fchdDRj9O6lr2D5WCse6sqeK89GOvL++ut5NLfKDvzm691qfOvL2c1LM770jii9MeS0/l4drey4RbnZdMdQEAgDCo+GBXyykv7p7niXiXy2PnduRcZ64qzNGrsFEeamBxMwAAQAW5ig/VAS1H86qvd9V3+m3WHXrtDr/Q1Ks6yVjob3uMSzeUzP29lvG9+Fur/pRJfKx3bGn7o52Qyk+LeemLPVGmfVqUwS0mPT3bvLe4dJbSGy31fssRIcYjvb8nmOoCAABhyFR8XlnJZmsfmbV0J15arTmLSbny49mdLQhyfveWhTGdUpxFoTVUz89I01tbHmPaM2rcUfEBAABhyFZ8rCrNzHveibe2rdZYaXMkIxdrK6+HOaumelk30evOeHa/3onLSt9dOVuv5rVKObrCKJf4WO68lNruNaDKa1xe9Oyf2V+MR3JfsHv1O9SNTHhGuhpXXr/wr+S8Ouf5GYs3pLPGHVNdAAAgjKUkQ1yW5Z+U0t/9mtPNX+u6/rr6kOH4UvIfY1Z8KRGjOMbpH8QojRj/8BhfUeIDAABgGVNdAAAgDBIfAAAQBokPAAAIg8QHAACEQeIDAADCIPEBAABhkPgAAIAwil5Z8fb2tj4ej05N6efr6yt9f39f7o1tNb6UUvr8/PzO2YzKaoy5fZgSMSrzfi7Shz8Ro67I52JR4vN4PNLHx0ebVg30/v6e9Tmr8aWU0rIsWTtrWo0xtw9TIkZl3s9F+vAnYtQV+VxkqgsAAIRB4gMAAMIg8QEAAGGQ+AAAgDCKFjfPtiy/F2irv1H+2c49Z23f+zn1WBGTlXMRwDy134W9ySY+ZwdsWRa5C+5Ze2s+50luzGp9Wuo1TqvxlIxRxXPxzDY2S+1+5SWO3lS/eD2r+S4c3RdMdQEAgDDkKj6l2eLMrL1n9Ub5bmQvU29xLNTvYi1X63q23Uuly4K9fqSq8ZvlyrKHZQ6l43D7+dHf51R8AABAGDIVn6NscZsBnt3tjMyOr+4sStpioYpwtd7Kq7uxKax/mdE/CnF7cqcPPVQSrliseOX0qeUK1lNp20Z9n0skPrknZ8tplRoWT7C7PCc2Z0qT26jHae/GxOu5oOTOtTBCP3mO7dXsJQK1yfW6rtOum0x1AQCAMCQqPqVmZoqtqFePWk7nbX/f7KrdHWcxH03Jeru7zj0Gyqy0c0/Ltnsbm0/q8ZRc+yxfL3Mcxdd7upyKDwAACEO24lN6N9I7Q7zKuK/am5Oxq9+pPN1pZ+tKEupdPThw9nnLLFc6ctZTlFYHWIyujzWFbUkkPmfP+M8+KXOnpLbtLWXlolPbTotPJ+xNp85eRFgjt51W4sFPZw+B7OELc57SY8852Q9TXQAAIAyJis+VnIWxI7NjxUfte6pdyNzr949S2qce+94j1fF2pGV7PTwY4p218XlHzv59PVDxAQAAYchVfEp3Ce6VGd7ZlKn291sWYW3T2WPqs+5cAPw0e13okdyHCc6+eyx8b9z5DhzVb8MTn1YdN3pgK55ICkoXLlt+ouZqwTNsiPQEGzT1mEafNWZbxDK67Ux1AQCAMOSmus5wFzZHbVWj9FFbbyLFasVe1S5SP1GhRC+WziMqPgAAIIxhFR9Lm9j1vCvydsdVuvtvJKqLLKOL2Cecm35ZGc8Ka3ueqPgAAIAwJNb4KGesym3roeY9P6+8HrOau2bLT7HBPuWngdDmMXUL1xi174lhiY+1fQhe373VqpPUj0PO7ti5L2x9/TkAY5Dw2MKSgbGY6gIAAGEMn+qy+q6YngtVFUuVZ20pvTvxssj36lFoi+Ma/lyNQw/nomelswKK/XnWdoX2UvEBAABhSCxutlIRaLlAS32tz5XaRdAW+vlMTiXMap8+WXzv2LbNyu3soWS8RTs2Flmu9JxRau+UxOfspY9nnx+Jhbz/Y/2LXIF64mepj0vOSdXj3YrFJBX7PCewau1lqgsAAIQxdarLQlWlZiHv2e/I+Vml7Lh1JUAxxlYsVU2eLLW5dNHusiwhKrKvIsTozV6V0tK5uZVbcW0dX8m4p+IDAADCkFjcvJepKWa7e3eUOR
RjuVJzd11yfKwsaM9l8RHiq+qk8rjNOZ57sURa/wNdyudWbwqxSyQ+e84uwCpfmla+JEbY64+rkq2HaS+PCY+qO220Pn3wqkUfWh0HVuVeK9Qe9Gml17lXc0yY6gIAAGHIVnysaTFdZyWbb/3eMossvgvJS7UD9RgDujxWelT3JKLiAwAAwpCv+Fieh859lM9DPF60uCNWPE5sdOfD1QMDJej7sWoe7bbYR7nv45wZm1ziQynWDyulW6/JTkpt+0DloYIaFm80WooWs7UHJ6y0M5f6wwRMdQEAgDCWkkxzWZZ/Ukp/92tON3+t6/rr6kOG40vJf4xZ8aVEjOIYp38QozRi/MNjfEWJDwAAgGVMdQEAgDBIfAAAQBgkPgAAIAwSHwAAEAaJDwAACIPEBwAAhEHiAwAAwiDxAQAAYZD4AACAMP4P6a4zeY5hWWgAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - } - } + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAABECAYAAABu1lQcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAADy0lEQVR4nO3cQVLkOBAF0PTEHIFeTx2C+58A7kCvhztoFkwHBOFy2WW7UCrfW3Y4OpSWbL5Shqm1FgAAVfz10wMAAHgk4QcAKEX4AQBKEX4AgFKEHwCgFOEHAChF+AEAShF+AIBShB8AoJS/t1z89PTULpfLSUM519vbW7y/v09L12SuLyLi9fX1vbX2a+mazDWumcMINfbOs/ghc43W6afRa8xcX8T1Z3FT+LlcLvHy8nLcqB7o+fn55jWZ64uImKbp961rMte4Zg4j1Ng7z+KHzDVap59GrzFzfRHXn0XHXgBAKcIPAFDKpmMvtpum+ePU1tqDRwIAROj8AADF6Pwc5FqH59b1FTpAo9Q6N8fZawLGNE2T99MCnR8AoBSdnwNs7fpAT0bpzO2V8T58f/dkGjvb3XvC8If18Un4OcnSIvu6IDO+cNcaKRSOVMtRvt6TzOv3+9xmeCbPWI9Z5jPLOPfyzjmXYy8AoJRTOz8VPxDdWt/o9yO7W7uvDF2Ca/buLFtrdqc/4Oj3apajEWttbI/OCzo/AEApp3R+lhL6kem9lx3K2nFU3bn0Mk9b+LCwhuzd6UxjvdeanydZ78O9489a7x9r369L1+29B6eEn7Uf++6VfeGPLGvQO2rcPf+NjaOCWtY5zuqMD317/nh4y/rq+Xm75lZ92epZ0uO7wrEXAFDKw3/VfW+a7Xmnck321jrzvs5hjzubORmfnzNkma+IXGP9SVlOAqrM59o65+br+1zO/V97u306PwBAKf7I4cmqpPwR7NmpzO1QsuxEtxp1TY82T3Myd6GXOgRz/9ZbXZnv/dHu+cD76PeOzg8AUEqazk+2bxXO/BW9nmWbp7XW1tJrB+iseRlpjjPZ8ht72Tt1o66xUes6osO19K3PUboPP9kf3K9GXewjM2efRnoWM1nzg+Deuel1ffe2edjLs/Ohp7/z59gLACil+85P1rSfddx7Za575OOgI8bTW017ZKwl45jvtfWYuXdZxnmEpVp7ug86PwBAKcIPAFCK8AMAlCL8AAClCD8AQCnCDwBQivADAJQi/AAApQg/AEApwg8AUIrwAwCUIvwAAKUIPwBAKcIPAFCK8AMAlCL8AAClCD8AQCnCDwBQivADAJQi/AAApQg/AEApwg8AUIrwAwCUMrXW1l88Tf9GxO/zhnOqf1prv5YuSF5fxPg13qwvQo0JjL5OI8av0Tr93+g1Jq8v4kqNm8IPAEB2jr0AgFKEHwCgFOEHAChF+AEAShF+AIBShB8AoBThBwAoRfgBAEoRfgCAUv4Dzge+SeXbn58AAAAASUVORK5CYII=\n", + "text/plain": [ + "
" ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" }, { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "KBuWx-FtSouR", - "outputId": "3498bba6-068c-413a-b1a9-74e98192b543", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 153 - } - }, - "source": [ - "# Filling occluded images\n", - "occlude_start_row = 14\n", - "num_generated_images = 10\n", - "samples = np.copy(x_test_quantised[0:num_generated_images, :, :, :])\n", - "samples = samples / (q_levels - 1)\n", - "samples[:, occlude_start_row:, :, :] = 0\n", - "\n", - "fig = plt.figure(figsize=(10, 10))\n", - "\n", - "for i in range(10):\n", - " ax = fig.add_subplot(1, 10, i + 1)\n", - " ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary)\n", - " plt.xticks(np.array([]))\n", - " plt.yticks(np.array([]))\n", - "\n", - "for i in range(occlude_start_row, height):\n", - " for j in range(width):\n", - " logits = gated_pixelcnn(samples)\n", - " next_sample = tf.random.categorical(logits[:, i, j, :], 1)\n", - " samples[:, i, j, 0] = (next_sample.numpy() / (q_levels - 1))[:, 0]\n", - "\n", - "fig = plt.figure(figsize=(10, 10))\n", - "\n", - "for i in range(10):\n", - " ax = fig.add_subplot(1, 10, i + 1)\n", - " ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary)\n", - " plt.xticks(np.array([]))\n", - " plt.yticks(np.array([]))\n", - "plt.show()" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAABECAYAAABu1lQcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAADy0lEQVR4nO3cQVLkOBAF0PTEHIFeTx2C+58A7kCvhztoFkwHBOFy2WW7UCrfW3Y4OpSWbL5Shqm1FgAAVfz10wMAAHgk4QcAKEX4AQBKEX4AgFKEHwCgFOEHAChF+AEAShF+AIBShB8AoJS/t1z89PTULpfLSUM519vbW7y/v09L12SuLyLi9fX1vbX2a+mazDWumcMINfbOs/ghc43W6afRa8xcX8T1Z3FT+LlcLvHy8nLcqB7o+fn55jWZ64uImKbp961rMte4Zg4j1Ng7z+KHzDVap59GrzFzfRHXn0XHXgBAKcIPAFDKpmMvtpum+ePU1tqDRwIAROj8AADF6Pwc5FqH59b1FTpAo9Q6N8fZawLGNE2T99MCnR8AoBSdnwNs7fpAT0bpzO2V8T58f/dkGjvb3XvC8If18Un4OcnSIvu6IDO+cNcaKRSOVMtRvt6TzOv3+9xmeCbPWI9Z5jPLOPfyzjmXYy8AoJRTOz8VPxDdWt/o9yO7W7uvDF2Ca/buLFtrdqc/4Oj3apajEWttbI/OCzo/AEApp3R+lhL6kem9lx3K2nFU3bn0Mk9b+LCwhuzd6UxjvdeanydZ78O9489a7x9r369L1+29B6eEn7Uf++6VfeGPLGvQO2rcPf+NjaOCWtY5zuqMD317/nh4y/rq+Xm75lZ92epZ0uO7wrEXAFDKw3/VfW+a7Xmnck321jrzvs5hjzubORmfnzNkma+IXGP9SVlOAqrM59o65+br+1zO/V97u306PwBAKf7I4cmqpPwR7NmpzO1QsuxEtxp1TY82T3Myd6GXOgRz/9ZbXZnv/dHu+cD76PeOzg8AUEqazk+2bxXO/BW9nmWbp7XW1tJrB+iseRlpjjPZ8ht72Tt1o66xUes6osO19K3PUboPP9kf3K9GXewjM2efRnoWM1nzg+Deuel1ffe2edjLs/Ohp7/z59gLACil+85P1rSfddx7Za575OOgI8bTW017ZKwl45jvtfWYuXdZxnmEpVp7ug86PwBAKcIPAFCK8AMAlCL8AAClCD8AQCnCDwBQivADAJQi/AAApQg/AEApwg8AUIrwAwCUIvwAAKUIPwBAKcIPAFCK8AMAlCL8AAClCD8AQCnCDwBQivADAJQi/AAApQg/AEApwg8AUIrwAwCUMrXW1l88Tf9GxO/zhnOqf1prv5YuSF5fxPg13qwvQo0JjL5OI8av0Tr93+g1Jq8v4kqNm8IPAEB2jr0AgFKEHwCgFOEHAChF+AEAShF+AIBShB8AoBThBwAoRfgBAEoRfgCAUv4Dzge+SeXbn58AAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - } - }, - { - "output_type": "display_data", - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAABECAYAAABu1lQcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAFtklEQVR4nO3dS5LbNhAAUDCVI9jrzCF8/xPYd7DXmTswC8U18hRFkRQ/3ej3li55Ck0AVKMBUsM4jg0AoIq/rm4AAMCZJD8AQCmSHwCgFMkPAFCK5AcAKEXyAwCUIvkBAEqR/AAApUh+AIBS/l7z4S9fvoxvb28HNeVYP3/+bO/v78PcZzLH11prP378eB/H8evcZzLHuKQPWxNjdObiTeYYjdMPvceYOb7WHs/FVcnP29tb+/79+36tOtG3b9+efiZzfK21NgzDr2efyRzjkj5sTYzRmYs3mWM0Tj/0HmPm+Fp7PBdtewEApUh+AIBSVm17sd4wTG+njuN4cksAgNZUfgCAYlR+dvKowvPs8xUqQL3EOtXH2WMC+jQMg/vTDJUfAKAUlZ8drK36QCS9VOZelfE6fL73ZGo7623dYfjN+Pgg+TnI3CC7H5AZb7hL9ZQU9hTLXu6vSebx+7lvM8zJI8Zjlv7M0s5Xueccy7YXAFDKoZWfigdE18bX+/XI7tnqK0OV4JFXV5bjOFqdXmDv+2qWrRFjrW9n5wsqPwBAKYdUfuYy9D2z9ygrlKXtqLpyidJPazhYWEP26nSmtm615Psk63XY2v6s8f629P4697lXr8Ehyc/Sw76vyj7we5Y10dur3ZHfsbFXopa1j7M64qBv5MPDa8ZX5Pn2yLP4ssUzJ+K9wrYXAFDK6Y+6v5rNRl6pPJK9tM60+z6MuLKZknH+HCFLf7WWq61XyrITUKU/l8Y51V+f+3Lqb71a7VP5AQBK8ZLDg1XJ8nvwykplaoWSZSW6Vq9jurd+mpK5Cj1XIZj6t2hxZb72e9tywHvv+47KDwBQSprKT7azCkc+ohdZtn5aamksUStAR/VLT32cyZon9rJX6nodY73GtUeFa+6sz17CJz/ZJ+69Xgd7z/TZh57mYiZLvgi29k3U8R1t8fAqc+cm0nv+bHsBAKWErfxsXeVkXh30Tr/FUPlt1JlX4EsP/G79W5yj4rWPOO9UfgCAUkJUfvbICiNn05Hbtpe5R06X/v8K1ymbCtW6zHGsrQZliXXq0eb7sz8RKwncHHVYee+xe1ny02u5tuqkXDPgH70ZuYeDjVE9e1Pqlr9FTGc8KXOljHFVnDN7vgH/iOtn2wsAKOXQyk+kx9q4zlzfRV+lrn1nRdQ4ftuyGutl7kXvm71Fn1vPbG1/L+O1J0v65Ow3YKv8AAClhDjw3JPqq461q7WMq9JeVqJR28U+en3b+iMVYmQ/Kj8AQCkqP+xqSVXk0Wd6W7n1Fg85ZKymftZLdZXnrvq1e8kPu5r7Yc8l/y+7XuKAK1Q7hM+fzuxX214AQClhKz8y+9yWvoU1Qj/3+sJNbnrYBnpV9BeIrumjqDHwp2dvGr96Xqr8AAClHFr5WfvYs4y+hqz9nLXdVWWpPO6p54cJeoihgq0PvZzdv6dse0UocXG9aDevil+OW2T6QrXQ6o++unk2tq+8TmvvEWvzgSOSJdteAEApoQ48Rz+Ut6dMq2nq6L1Ca37lU+l74d7W6siZ1+no7asj/77KDwBQymmVnzWHn4dh6CbL35K9R4z9ld8Jih7PlIhtvsLUdYhYHXr2WG3vIvbJGku+HyIckj3aHv0YuVIWqQ9VfgCAUk4/8/M5y3uU6UbOXqeseXLo2WezxT4l60o08zXfw5L478fx1deresVniUzXYe3rUXr55fqs98tnts7PMypElx94fvYYfNZEYK69WbYSWptu11yfRI1jqc/tzzbuzhC9j/XZTebrcN/2tYnQkrijJE2vJu9Xz8W5ZHVLbGc+CGTbCwAoZViTUQ3D8G9r7ddxzTnUP+M4fp37QPL4Wus/xqfxtSbGBHofp631H6Nx+r/eY0weX2sPYlyV/AAAZGfbCwAoRfIDAJQi+QEASpH8AAClSH4AgFIkPwBAKZIfAKAUyQ8AUIrkBwAo5T903wuZqmcetwAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - } - } + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAABECAYAAABu1lQcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAFtklEQVR4nO3dS5LbNhAAUDCVI9jrzCF8/xPYd7DXmTswC8U18hRFkRQ/3ej3li55Ck0AVKMBUsM4jg0AoIq/rm4AAMCZJD8AQCmSHwCgFMkPAFCK5AcAKEXyAwCUIvkBAEqR/AAApUh+AIBS/l7z4S9fvoxvb28HNeVYP3/+bO/v78PcZzLH11prP378eB/H8evcZzLHuKQPWxNjdObiTeYYjdMPvceYOb7WHs/FVcnP29tb+/79+36tOtG3b9+efiZzfK21NgzDr2efyRzjkj5sTYzRmYs3mWM0Tj/0HmPm+Fp7PBdtewEApUh+AIBSVm17sd4wTG+njuN4cksAgNZUfgCAYlR+dvKowvPs8xUqQL3EOtXH2WMC+jQMg/vTDJUfAKAUlZ8drK36QCS9VOZelfE6fL73ZGo7623dYfjN+Pgg+TnI3CC7H5AZb7hL9ZQU9hTLXu6vSebx+7lvM8zJI8Zjlv7M0s5Xueccy7YXAFDKoZWfigdE18bX+/XI7tnqK0OV4JFXV5bjOFqdXmDv+2qWrRFjrW9n5wsqPwBAKYdUfuYy9D2z9ygrlKXtqLpyidJPazhYWEP26nSmtm615Psk63XY2v6s8f629P4697lXr8Ehyc/Sw76vyj7we5Y10dur3ZHfsbFXopa1j7M64qBv5MPDa8ZX5Pn2yLP4ssUzJ+K9wrYXAFDK6Y+6v5rNRl6pPJK9tM60+z6MuLKZknH+HCFLf7WWq61XyrITUKU/l8Y51V+f+3Lqb71a7VP5AQBK8ZLDg1XJ8nvwykplaoWSZSW6Vq9jurd+mpK5Cj1XIZj6t2hxZb72e9tywHvv+47KDwBQSprKT7azCkc+ohdZtn5aamksUStAR/VLT32cyZon9rJX6nodY73GtUeFa+6sz17CJz/ZJ+69Xgd7z/TZh57mYiZLvgi29k3U8R1t8fAqc+cm0nv+bHsBAKWErfxsXeVkXh30Tr/FUPlt1JlX4EsP/G79W5yj4rWPOO9UfgCAUkJUfvbICiNn05Hbtpe5R06X/v8K1ymbCtW6zHGsrQZliXXq0eb7sz8RKwncHHVYee+xe1ny02u5tuqkXDPgH70ZuYeDjVE9e1Pqlr9FTGc8KXOljHFVnDN7vgH/iOtn2wsAKOXQyk+kx9q4zlzfRV+lrn1nRdQ4ftuyGutl7kXvm71Fn1vPbG1/L+O1J0v65Ow3YKv8AAClhDjw3JPqq461q7WMq9JeVqJR28U+en3b+iMVYmQ/Kj8AQCkqP+xqSVXk0Wd6W7n1Fg85ZKymftZLdZXnrvq1e8kPu5r7Yc8l/y+7XuKAK1Q7hM+fzuxX214AQClhKz8y+9yWvoU1Qj/3+sJNbnrYBnpV9BeIrumjqDHwp2dvGr96Xqr8AAClHFr5WfvYs4y+hqz9nLXdVWWpPO6p54cJeoihgq0PvZzdv6dse0UocXG9aDevil+OW2T6QrXQ6o++unk2tq+8TmvvEWvzgSOSJdteAEApoQ48Rz+Ut6dMq2nq6L1Ca37lU+l74d7W6siZ1+no7asj/77KDwBQymmVnzWHn4dh6CbL35K9R4z9ld8Jih7PlIhtvsLUdYhYHXr2WG3vIvbJGku+HyIckj3aHv0YuVIWqQ9VfgCAUk4/8/M5y3uU6UbOXqeseXLo2WezxT4l60o08zXfw5L478fx1deresVniUzXYe3rUXr55fqs98tnts7PMypElx94fvYYfNZEYK69WbYSWptu11yfRI1jqc/tzzbuzhC9j/XZTebrcN/2tYnQkrijJE2vJu9Xz8W5ZHVLbGc+CGTbCwAoZViTUQ3D8G9r7ddxzTnUP+M4fp37QPL4Wus/xqfxtSbGBHofp631H6Nx+r/eY0weX2sPYlyV/AAAZGfbCwAoRfIDAJQi+QEASpH8AAClSH4AgFIkPwBAKZIfAKAUyQ8AUIrkBwAo5T903wuZqmcetwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" ] - }, - { - "cell_type": "code", - "metadata": { - "id": "W7bzeLMjPNcI", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": 0, - "outputs": [] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" } - ] -} \ No newline at end of file + ], + "source": [ + "# Filling occluded images\n", + "occlude_start_row = 14\n", + "num_generated_images = 10\n", + "samples = np.copy(x_test_quantised[0:num_generated_images, :, :, :])\n", + "samples = samples / (q_levels - 1)\n", + "samples[:, occlude_start_row:, :, :] = 0\n", + "\n", + "fig = plt.figure(figsize=(10, 10))\n", + "\n", + "for i in range(10):\n", + " ax = fig.add_subplot(1, 10, i + 1)\n", + " ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary)\n", + " plt.xticks(np.array([]))\n", + " plt.yticks(np.array([]))\n", + "\n", + "for i in range(occlude_start_row, height):\n", + " for j in range(width):\n", + " logits = gated_pixelcnn(samples)\n", + " next_sample = tf.random.categorical(logits[:, i, j, :], 1)\n", + " samples[:, i, j, 0] = (next_sample.numpy() / (q_levels - 1))[:, 0]\n", + "\n", + "fig = plt.figure(figsize=(10, 10))\n", + "\n", + "for i in range(10):\n", + " ax = fig.add_subplot(1, 10, i + 1)\n", + " ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary)\n", + " plt.xticks(np.array([]))\n", + " plt.yticks(np.array([]))\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "W7bzeLMjPNcI" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "cropped gated_pixelcnn.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/WIP/6-gated_pixelcnn_cropped/gated_pixelcnn_vs_cropped.ipynb b/WIP/6-gated_pixelcnn_cropped/gated_pixelcnn_vs_cropped.ipynb index aede3e3..c5d69fa 100644 --- a/WIP/6-gated_pixelcnn_cropped/gated_pixelcnn_vs_cropped.ipynb +++ b/WIP/6-gated_pixelcnn_cropped/gated_pixelcnn_vs_cropped.ipynb @@ -1,1135 +1,1255 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "gated_pixelcnn_vs_cropped.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "HgFGN07idT26" + }, + "source": [ + "# Masked vs cropped implementation for Gated PixelCNN" + ] }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "HgFGN07idT26", - "colab_type": "text" - }, - "source": [ - "# Converting masked-based implementation to cropping-based implementation" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ObS7YqtCbC33", - "colab_type": "code", - "colab": {} - }, - "source": [ - "import tensorflow as tf\n", - "import tensorflow.keras as keras\n", - "import numpy as np\n", - "import math" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "iCMR2mKLbt_l", - "colab_type": "code", - "colab": {} - }, - "source": [ - "test_ones_2d = 
np.ones([1, 5, 5, 1], dtype='float32')\n", - "test_ones_3d = np.ones([1, 5, 5, 5, 1], dtype='float32')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "NycU0IQZb1X1", - "colab_type": "code", - "colab": {} - }, - "source": [ - "def print_3d(matrix_3d):\n", - " for i in range(matrix_3d.shape[0]):\n", - " print(f'Depth {i}')\n", - " print(matrix_3d[i,...])" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "AeH21Zkzcrt5", - "colab_type": "code", - "outputId": "d47da908-4659-45aa-f540-e01d3e510bdc", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 527 - } - }, - "source": [ - "print_3d(test_ones_3d.squeeze())" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Depth 0\n", - "[[1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]]\n", - "Depth 1\n", - "[[1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]]\n", - "Depth 2\n", - "[[1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]]\n", - "Depth 3\n", - "[[1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]]\n", - "Depth 4\n", - "[[1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "xqrqiDnbfoqW", - "colab_type": "code", - "outputId": "12aef41a-4992-464c-8d38-8fd732286c57", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 102 - } - }, - "source": [ - "print(test_ones_2d[0,:,:,0].squeeze())" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "[[1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 
1.]]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1UcUkEj0d7wh", - "colab_type": "text" - }, - "source": [ - "## Creating 2D masked solution to check results with cropped solution later" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "83mZFyondaAT", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class MaskedConv2D(tf.keras.layers.Layer):\n", - " def __init__(self,\n", - " mask_type,\n", - " filters,\n", - " kernel_size,\n", - " strides=1,\n", - " padding='same',\n", - " kernel_initializer='glorot_uniform',\n", - " bias_initializer='zeros'):\n", - " super(MaskedConv2D, self).__init__()\n", - "\n", - " assert mask_type in {'A', 'B', 'V'}\n", - " self.mask_type = mask_type\n", - "\n", - " self.filters = filters\n", - "\n", - " if isinstance(kernel_size, int):\n", - " kernel_size = (kernel_size, kernel_size)\n", - " self.kernel_size = kernel_size\n", - "\n", - " self.strides = strides\n", - " self.padding = padding.upper()\n", - " self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)\n", - " self.bias_initializer = tf.keras.initializers.get(bias_initializer)\n", - "\n", - " def build(self, input_shape):\n", - " kernel_h, kernel_w = self.kernel_size\n", - "\n", - " self.kernel = self.add_weight(\"kernel\",\n", - " shape=(kernel_h,\n", - " kernel_w,\n", - " int(input_shape[-1]),\n", - " self.filters),\n", - " initializer=self.kernel_initializer,\n", - " trainable=True)\n", - "\n", - " self.bias = self.add_weight(\"bias\",\n", - " shape=(self.filters,),\n", - " initializer=self.bias_initializer,\n", - " trainable=True)\n", - "\n", - " mask = np.ones(self.kernel.shape, dtype=np.float32)\n", - "\n", - " if kernel_h % 2 != 0: \n", - " center_h = kernel_h // 2\n", - " else:\n", - " center_h = (kernel_h - 1) // 2\n", - "\n", - " if kernel_w % 2 != 0: \n", - " center_w = kernel_w // 2\n", - " else:\n", - " center_w = (kernel_w - 1) // 2\n", - "\n", - " if self.mask_type == 'V':\n", - " mask[center_h + 1:, :, :, :] = 0.\n", - " else:\n", - " mask[:center_h, :, :] = 0.\n", - " mask[center_h, center_w + (self.mask_type == 'B'):, :, :] = 0.\n", - " mask[center_h + 1:, :, :] = 0.\n", - "\n", - " self.mask = tf.constant(mask, dtype=tf.float32, name='mask')\n", - "\n", - " def call(self, input):\n", - " masked_kernel = tf.math.multiply(self.mask, self.kernel)\n", - " x = tf.nn.conv2d(input, masked_kernel, strides=[1, self.strides, self.strides, 1], padding=self.padding)\n", - " x = tf.nn.bias_add(x, self.bias)\n", - " return x" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nMXqIeSUkdcV", - "colab_type": "text" - }, - "source": [ - "### Tests with kernel_size 3" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "M62GZQe8ixvy", - "colab_type": "text" - }, - "source": [ - "#### Vertical stack" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "kjUrpEtIg7p9", - "colab_type": "code", - "outputId": "8df4b5ee-1d0e-4f7d-9009-beaf8c8133ee", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 204 - } - }, - "source": [ - "mask_type = 'V'\n", - "kernel_size=(3, 3)\n", - "\n", - "conv = MaskedConv2D(mask_type=mask_type,\n", - " filters=1,\n", - " kernel_size=kernel_size, \n", - " padding='same',\n", - " kernel_initializer='ones', \n", - " bias_initializer='zeros')\n", - "\n", - "result_v = conv(test_ones_2d)\n", - "\n", - "print('MASK')\n", - "print(conv.mask.numpy().squeeze())\n", - "print('')\n", - "print('OUTPUT')\n", - 
"print(result_v.numpy().squeeze())" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "MASK\n", - "[[1. 1. 1.]\n", - " [1. 1. 1.]\n", - " [0. 0. 0.]]\n", - "\n", - "OUTPUT\n", - "[[2. 3. 3. 3. 2.]\n", - " [4. 6. 6. 6. 4.]\n", - " [4. 6. 6. 6. 4.]\n", - " [4. 6. 6. 6. 4.]\n", - " [4. 6. 6. 6. 4.]]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PFqvGa439Z2o", - "colab_type": "text" - }, - "source": [ - "#### Feeding horizontal" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "oq_JTwdE9lr6", - "colab_type": "code", - "outputId": "a5eecd56-e068-4d9f-ff4b-01bd048f0bbb", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 119 - } - }, - "source": [ - "padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n", - "cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", - "\n", - "x = padding(result_v)\n", - "result = cropping(x)\n", - "\n", - "print('OUTPUT')\n", - "print(result.numpy().squeeze())" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "OUTPUT\n", - "[[0. 0. 0. 0. 0.]\n", - " [2. 3. 3. 3. 2.]\n", - " [4. 6. 6. 6. 4.]\n", - " [4. 6. 6. 6. 4.]\n", - " [4. 6. 6. 6. 4.]]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iHY6UE2_p5oc", - "colab_type": "text" - }, - "source": [ - "#### Horizontal stack A" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "4_q_IunZkFmj", - "colab_type": "code", - "outputId": "1bfa1a88-7ae3-4a7e-d43e-46e7e683cce4", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 204 - } - }, - "source": [ - "mask_type = 'A'\n", - "kernel_size=(3, 3)\n", - "\n", - "conv = MaskedConv2D(mask_type=mask_type,\n", - " filters=1,\n", - " kernel_size=kernel_size, \n", - " padding='same',\n", - " kernel_initializer='ones', \n", - " bias_initializer='zeros')\n", - "\n", - "result = conv(test_ones_2d)\n", - "\n", - "print('MASK')\n", - "print(conv.mask.numpy().squeeze())\n", - "print('')\n", - "print('OUTPUT')\n", - "print(result.numpy().squeeze())" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "MASK\n", - "[[0. 0. 0.]\n", - " [1. 0. 0.]\n", - " [0. 0. 0.]]\n", - "\n", - "OUTPUT\n", - "[[0. 1. 1. 1. 1.]\n", - " [0. 1. 1. 1. 1.]\n", - " [0. 1. 1. 1. 1.]\n", - " [0. 1. 1. 1. 1.]\n", - " [0. 1. 1. 1. 1.]]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jMuS-vgWqAWK", - "colab_type": "text" - }, - "source": [ - "#### Horizontal stack B" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "5yeB5h2tkSs_", - "colab_type": "code", - "outputId": "9e7346b9-d360-42dd-85b4-a9a22ba2bacb", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 204 - } - }, - "source": [ - "mask_type = 'B'\n", - "kernel_size=(3, 3)\n", - "\n", - "conv = MaskedConv2D(mask_type=mask_type,\n", - " filters=1,\n", - " kernel_size=kernel_size, \n", - " padding='same',\n", - " kernel_initializer='ones', \n", - " bias_initializer='zeros')\n", - "\n", - "result = conv(test_ones_2d)\n", - "\n", - "print('MASK')\n", - "print(conv.mask.numpy().squeeze())\n", - "print('')\n", - "print('OUTPUT')\n", - "print(result.numpy().squeeze())" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "MASK\n", - "[[0. 0. 0.]\n", - " [1. 1. 0.]\n", - " [0. 0. 0.]]\n", - "\n", - "OUTPUT\n", - "[[1. 2. 2. 2. 2.]\n", - " [1. 2. 2. 2. 2.]\n", - " [1. 2. 2. 2. 2.]\n", - " [1. 2. 
2. 2. 2.]\n", - " [1. 2. 2. 2. 2.]]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1NxkQ3U1knbE", - "colab_type": "text" - }, - "source": [ - "### Tests with kernel_size 4" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WNykK-WpqMlu", - "colab_type": "text" - }, - "source": [ - "#### Vertical stack" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "p3DTiFYCk57Y", - "colab_type": "code", - "outputId": "01d4c588-acde-43ba-8b52-7b5b2fddc52a", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 221 - } - }, - "source": [ - "mask_type = 'V'\n", - "kernel_size=(4, 4)\n", - "\n", - "padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n", - "\n", - "conv = MaskedConv2D(mask_type=mask_type,\n", - " filters=1,\n", - " kernel_size=kernel_size, \n", - " padding='same',\n", - " kernel_initializer='ones', \n", - " bias_initializer='zeros')\n", - "\n", - "cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", - "\n", - "x = padding(test_ones_2d)\n", - "x = conv(x)\n", - "result = cropping(x)\n", - "\n", - "print('MASK')\n", - "print(conv.mask.numpy().squeeze())\n", - "print('')\n", - "print('OUTPUT')\n", - "print(result.numpy().squeeze())" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "MASK\n", - "[[1. 1. 1. 1.]\n", - " [1. 1. 1. 1.]\n", - " [0. 0. 0. 0.]\n", - " [0. 0. 0. 0.]]\n", - "\n", - "OUTPUT\n", - "[[0. 0. 0. 0. 0.]\n", - " [3. 4. 4. 3. 2.]\n", - " [6. 8. 8. 6. 4.]\n", - " [6. 8. 8. 6. 4.]\n", - " [6. 8. 8. 6. 4.]]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "E5jUGK3_qbT8", - "colab_type": "text" - }, - "source": [ - "#### Horizontal stack A" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "J5HSUo7Rk5xO", - "colab_type": "code", - "outputId": "d2eabc6f-1b49-43f9-f7d8-1cf13021c47b", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 221 - } - }, - "source": [ - "mask_type = 'A'\n", - "kernel_size=(4, 4)\n", - "\n", - "conv = MaskedConv2D(mask_type=mask_type,\n", - " filters=1,\n", - " kernel_size=kernel_size, \n", - " padding='same',\n", - " kernel_initializer='ones', \n", - " bias_initializer='zeros')\n", - "\n", - "result = conv(test_ones_2d)\n", - "\n", - "print('MASK')\n", - "print(conv.mask.numpy().squeeze())\n", - "print('')\n", - "print('OUTPUT')\n", - "print(result.numpy().squeeze())" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "MASK\n", - "[[0. 0. 0. 0.]\n", - " [1. 0. 0. 0.]\n", - " [0. 0. 0. 0.]\n", - " [0. 0. 0. 0.]]\n", - "\n", - "OUTPUT\n", - "[[0. 1. 1. 1. 1.]\n", - " [0. 1. 1. 1. 1.]\n", - " [0. 1. 1. 1. 1.]\n", - " [0. 1. 1. 1. 1.]\n", - " [0. 1. 1. 1. 
1.]]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KqORK7mLqvPP", - "colab_type": "text" - }, - "source": [ - "#### Horizontal stack B" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "_2V51aerk5l1", - "colab_type": "code", - "outputId": "b54f7147-0080-48a2-a84c-d3ca557c31d9", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 221 - } - }, - "source": [ - "mask_type = 'B'\n", - "kernel_size=(4, 4)\n", - "\n", - "conv = MaskedConv2D(mask_type=mask_type,\n", - " filters=1,\n", - " kernel_size=kernel_size, \n", - " padding='same',\n", - " kernel_initializer='ones', \n", - " bias_initializer='zeros')\n", - "\n", - "result = conv(test_ones_2d)\n", - "\n", - "print('MASK')\n", - "print(conv.mask.numpy().squeeze())\n", - "print('')\n", - "print('OUTPUT')\n", - "print(result.numpy().squeeze())" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "MASK\n", - "[[0. 0. 0. 0.]\n", - " [1. 1. 0. 0.]\n", - " [0. 0. 0. 0.]\n", - " [0. 0. 0. 0.]]\n", - "\n", - "OUTPUT\n", - "[[1. 2. 2. 2. 2.]\n", - " [1. 2. 2. 2. 2.]\n", - " [1. 2. 2. 2. 2.]\n", - " [1. 2. 2. 2. 2.]\n", - " [1. 2. 2. 2. 2.]]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kZmui789Br2B", - "colab_type": "text" - }, - "source": [ - "## Creating 2D cropped solution" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "PxBNsvzhB1ec", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class VerticalCroppedConv2d(tf.keras.Model):\n", - " def __init__(self,\n", - " filters,\n", - " kernel_size,\n", - " kernel_initializer, \n", - " bias_initializer):\n", - " super(VerticalCroppedConv2d, self).__init__(name='')\n", - "\n", - " if isinstance(kernel_size, int):\n", - " kernel_size = (kernel_size, kernel_size)\n", - "\n", - " kernel_h, kernel_w = kernel_size\n", - "\n", - " self.padding = keras.layers.ZeroPadding2D(padding=((kernel_h-1, 0),(int((kernel_w-1)/2),int((kernel_w-1)/2))))\n", - "\n", - " self.conv = keras.layers.Conv2D(filters=filters,\n", - " kernel_size=kernel_size,\n", - " strides=1,\n", - " padding='valid',\n", - " kernel_initializer=kernel_initializer, \n", - " bias_initializer=bias_initializer)\n", - "\n", - " def call(self, input_value):\n", - "\n", - " x = self.padding(input_value)\n", - " x = self.conv(x)\n", - " out = self.cropping(x)\n", - "\n", - " return out\n", - "\n" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RfFuwKlrP2JU", - "colab_type": "text" - }, - "source": [ - "Example step by step" - ] + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Hi all, in this notebook we will compare the masked implemntation of the convolutions from the Gated PixelCNN versus the alternative sugexted in the paper, the use of convolutions operaritions with appropriate croppings and padding to achieve the same result.\n", + "Let's check out!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we willcheck if both implementation create the same result. For this we will create a 5x5 matrix filled with ones as our input example." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "ObS7YqtCbC33"
+ },
+ "outputs": [],
+ "source": [
+ "import math\n",
+ "\n",
+ "import numpy as np\n",
+ "import tensorflow as tf\n",
+ "from tensorflow import keras\n",
+ "from tensorflow import nn\n",
+ "from tensorflow.keras import initializers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "iCMR2mKLbt_l"
+ },
+ "outputs": [],
+ "source": [
+ "test_ones_2d = np.ones([1, 5, 5, 1], dtype='float32')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 102
 },
+ "colab_type": "code",
+ "id": "xqrqiDnbfoqW",
+ "outputId": "12aef41a-4992-464c-8d38-8fd732286c57"
+ },
+ "outputs": [
 {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[1. 1. 1. 1. 1.]\n",
+ " [1. 1. 1. 1. 1.]\n",
+ " [1. 1. 1. 1. 1.]\n",
+ " [1. 1. 1. 1. 1.]\n",
+ " [1. 1. 1. 1. 1.]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(test_ones_2d[0,:,:,0].squeeze())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "1UcUkEj0d7wh"
+ },
+ "source": [
+ "Now, let's copy the masked implementation that we have been using for our Gated PixelCNN models."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Masked convolutions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "83mZFyondaAT"
+ },
+ "outputs": [],
+ "source": [
+ "class MaskedConv2D(keras.layers.Layer):\n",
+ "    \"\"\"Convolutional layers with masks extended to work with Gated PixelCNN.\n",
+ "\n",
+ "    Convolutional layers with simple implementation of masks type A and B for\n",
+ "    autoregressive models. Extended version to work with the vertical and horizontal\n",
+ "    stacks from the Gated PixelCNN model.\n",
+ "\n",
+ "    Arguments:\n",
+ "    mask_type: one of `\"V\"`, `\"A\"` or `\"B\"`.\n",
+ "    filters: Integer, the dimensionality of the output space (i.e. the number of output\n",
+ "        filters in the convolution).\n",
+ "    kernel_size: An integer or tuple/list of 2 integers, specifying the height and width\n",
+ "        of the 2D convolution window.\n",
+ "        Can be a single integer to specify the same value for all spatial dimensions.\n",
+ "    strides: An integer or tuple/list of 2 integers, specifying the strides of the\n",
+ "        convolution along the height and width.\n",
+ "        Can be a single integer to specify the same value for all spatial dimensions.\n",
+ "        Specifying any stride value != 1 is incompatible with specifying any\n",
+ "        `dilation_rate` value != 1.\n",
+ "    padding: one of `\"valid\"` or `\"same\"` (case-insensitive).\n",
+ "    kernel_initializer: Initializer for the `kernel` weights matrix.\n",
+ "    bias_initializer: Initializer for the bias vector.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self,\n",
+ "                 mask_type,\n",
+ "                 filters,\n",
+ "                 kernel_size,\n",
+ "                 strides=1,\n",
+ "                 padding='same',\n",
+ "                 kernel_initializer='glorot_uniform',\n",
+ "                 bias_initializer='zeros'):\n",
+ "        super(MaskedConv2D, self).__init__()\n",
+ "\n",
+ "        assert mask_type in {'A', 'B', 'V'}\n",
+ "        self.mask_type = mask_type\n",
+ "\n",
+ "        self.filters = filters\n",
+ "\n",
+ "        if isinstance(kernel_size, int):\n",
+ "            kernel_size = (kernel_size, kernel_size)\n",
+ "        self.kernel_size = kernel_size\n",
+ "\n",
+ "        self.strides = strides\n",
+ "        self.padding = padding.upper()\n",
+ "        self.kernel_initializer = initializers.get(kernel_initializer)\n",
+ "        self.bias_initializer = initializers.get(bias_initializer)\n",
+ "\n",
+ "    def build(self, input_shape):\n",
+ "        kernel_h, kernel_w = self.kernel_size\n",
+ "\n",
+ "        self.kernel = self.add_weight('kernel',\n",
+ "                                      shape=(kernel_h,\n",
+ "                                             kernel_w,\n",
+ "                                             int(input_shape[-1]),\n",
+ "                                             self.filters),\n",
+ "                                      initializer=self.kernel_initializer,\n",
+ "                                      trainable=True)\n",
+ "\n",
+ "        self.bias = self.add_weight('bias',\n",
+ "                                    shape=(self.filters,),\n",
+ "                                    initializer=self.bias_initializer,\n",
+ "                                    trainable=True)\n",
+ "\n",
+ "        mask = np.ones(self.kernel.shape, dtype=np.float32)\n",
+ "\n",
+ "        # Get centre of the filter for even or odd dimensions\n",
+ "        if kernel_h % 2 != 0:\n",
+ "            center_h = kernel_h // 2\n",
+ "        else:\n",
+ "            center_h = (kernel_h - 1) // 2\n",
+ "\n",
+ "        if kernel_w % 2 != 0:\n",
+ "            center_w = kernel_w // 2\n",
+ "        else:\n",
+ "            center_w = (kernel_w - 1) // 2\n",
+ "\n",
+ "        if self.mask_type == 'V':\n",
+ "            mask[center_h + 1:, :, :, :] = 0.\n",
+ "        else:\n",
+ "            mask[:center_h, :, :] = 0.\n",
+ "            mask[center_h, center_w + (self.mask_type == 'B'):, :, :] = 0.\n",
+ "            mask[center_h + 1:, :, :] = 0.\n",
+ "\n",
+ "        self.mask = tf.constant(mask, dtype=tf.float32, name='mask')\n",
+ "\n",
+ "    def call(self, input):\n",
+ "        masked_kernel = tf.math.multiply(self.mask, self.kernel)\n",
+ "        x = nn.conv2d(input,\n",
+ "                      masked_kernel,\n",
+ "                      strides=[1, self.strides, self.strides, 1],\n",
+ "                      padding=self.padding)\n",
+ "        x = nn.bias_add(x, self.bias)\n",
+ "        return x"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "With this implementation, we will recreate all the convolutional operations that occur inside the Gated Block. 
These operations are:\n", + "\n", + "- Vertical stack\n", + "- Vertical to horizontal stack\n", + "- Horizontal stack - convolution layer with mask type \"A\"\n", + "- Horizontal stack - convolution layer with mask type \"B\"\n", + "\n", + "\n", + "\n", + " IMAGE GATED BLOCK\n", + " \n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "M62GZQe8ixvy" + }, + "source": [ + "## Vertical stack" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 204 }, + "colab_type": "code", + "id": "kjUrpEtIg7p9", + "outputId": "8df4b5ee-1d0e-4f7d-9009-beaf8c8133ee" + }, + "outputs": [ { - "cell_type": "code", - "metadata": { - "id": "TFexYspQWEpo", - "colab_type": "code", - "outputId": "d0533bb8-38e8-482c-8e65-4f6522519087", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 51 - } - }, - "source": [ - "conv.weights[0].numpy().squeeze()" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "array([[1., 1., 1.],\n", - " [1., 1., 1.]], dtype=float32)" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 16 - } - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "MASK\n", + "[[1. 1. 1.]\n", + " [1. 1. 1.]\n", + " [0. 0. 0.]]\n", + "\n", + "OUTPUT\n", + "[[2. 3. 3. 3. 2.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]]\n" + ] + } + ], + "source": [ + "mask_type = 'V'\n", + "kernel_size = (3, 3)\n", + "\n", + "conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=1,\n", + " kernel_size=kernel_size,\n", + " padding='same',\n", + " kernel_initializer='ones',\n", + " bias_initializer='zeros')\n", + "\n", + "result_v = conv(test_ones_2d)\n", + "\n", + "print('MASK')\n", + "print(conv.mask.numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result_v.numpy().squeeze())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PFqvGa439Z2o" + }, + "source": [ + "## Vertical to horizontal stack" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 119 }, + "colab_type": "code", + "id": "oq_JTwdE9lr6", + "outputId": "a5eecd56-e068-4d9f-ff4b-01bd048f0bbb" + }, + "outputs": [ { - "cell_type": "code", - "metadata": { - "id": "AfyRyUmTNYZ8", - "colab_type": "code", - "colab": {} - }, - "source": [ - "def build_test_croppedv_stack_2d(input_shape=(5, 5, 1), kernel_size=3):\n", - " inputs = tf.keras.layers.Input(shape=input_shape)\n", - " \n", - " x = VerticalCroppedConv2d(\n", - " filters=1,\n", - " kernel_size=kernel_size, \n", - " kernel_initializer='ones', \n", - " bias_initializer='zeros')(inputs)\n", - "\n", - " stack = tf.keras.Model(inputs=inputs, outputs=x)\n", - " stack.compile(optimizer='adam', loss='mse')\n", - " return stack" - ], - "execution_count": 0, - "outputs": [] + "name": "stdout", + "output_type": "stream", + "text": [ + "INPUT\n", + "[[2. 3. 3. 3. 2.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]]\n", + "\n", + "OUTPUT\n", + "[[0. 0. 0. 0. 0.]\n", + " [2. 3. 3. 3. 2.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 
+     ]
+    }
+   ],
+   "source": [
+    "padding = keras.layers.ZeroPadding2D(padding=((1, 0), 0))\n",
+    "cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n",
+    "\n",
+    "x = padding(result_v)\n",
+    "result = cropping(x)\n",
+    "\n",
+    "print('INPUT')\n",
+    "print(result_v.numpy().squeeze())\n",
+    "print('')\n",
+    "print('OUTPUT')\n",
+    "print(result.numpy().squeeze())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "iHY6UE2_p5oc"
+   },
+   "source": [
+    "## Horizontal stack - convolution layer with mask type \"A\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 204
+    },
+    "colab_type": "code",
+    "id": "4_q_IunZkFmj",
+    "outputId": "1bfa1a88-7ae3-4a7e-d43e-46e7e683cce4"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "MASK\n",
+      "[1. 0. 0.]\n",
+      "\n",
+      "OUTPUT\n",
+      "[[0. 1. 1. 1. 1.]\n",
+      " [0. 1. 1. 1. 1.]\n",
+      " [0. 1. 1. 1. 1.]\n",
+      " [0. 1. 1. 1. 1.]\n",
+      " [0. 1. 1. 1. 1.]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "mask_type = 'A'\n",
+    "kernel_size = (1, 3)\n",
+    "\n",
+    "conv = MaskedConv2D(mask_type=mask_type,\n",
+    "                    filters=1,\n",
+    "                    kernel_size=kernel_size,\n",
+    "                    padding='same',\n",
+    "                    kernel_initializer='ones',\n",
+    "                    bias_initializer='zeros')\n",
+    "\n",
+    "result = conv(test_ones_2d)\n",
+    "\n",
+    "print('MASK')\n",
+    "print(conv.mask.numpy().squeeze())\n",
+    "print('')\n",
+    "print('OUTPUT')\n",
+    "print(result.numpy().squeeze())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "jMuS-vgWqAWK"
+   },
+   "source": [
+    "## Horizontal stack - convolution layer with mask type \"B\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 204
+    },
+    "colab_type": "code",
+    "id": "5yeB5h2tkSs_",
+    "outputId": "9e7346b9-d360-42dd-85b4-a9a22ba2bacb"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "MASK\n",
+      "[1. 1. 0.]\n",
+      "\n",
+      "OUTPUT\n",
+      "[[1. 2. 2. 2. 2.]\n",
+      " [1. 2. 2. 2. 2.]\n",
+      " [1. 2. 2. 2. 2.]\n",
+      " [1. 2. 2. 2. 2.]\n",
+      " [1. 2. 2. 2. 2.]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "mask_type = 'B'\n",
+    "kernel_size = (1, 3)\n",
+    "\n",
+    "conv = MaskedConv2D(mask_type=mask_type,\n",
+    "                    filters=1,\n",
+    "                    kernel_size=kernel_size,\n",
+    "                    padding='same',\n",
+    "                    kernel_initializer='ones',\n",
+    "                    bias_initializer='zeros')\n",
+    "\n",
+    "result = conv(test_ones_2d)\n",
+    "\n",
+    "print('MASK')\n",
+    "print(conv.mask.numpy().squeeze())\n",
+    "print('')\n",
+    "print('OUTPUT')\n",
+    "print(result.numpy().squeeze())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Using the results of the masked approach as a reference, let's check whether the cropped method reproduces them.\n",
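+    "\n",
+    "Before switching over, here is the gated-block sketch promised earlier. It is a hypothetical wiring based on the Gated PixelCNN paper, not a cell executed in this notebook: `gated_block`, `v_in` and `h_in` are illustrative names, and the residual connection and 1x1 convolutions of the full block are omitted for brevity.\n",
+    "\n",
+    "```python\n",
+    "# Sketch: combining the four operations above into one gated block.\n",
+    "def gated_block(v_in, h_in, filters, kernel_size=3):\n",
+    "    # Vertical stack: sees every row strictly above the current pixel.\n",
+    "    v = MaskedConv2D('V', 2 * filters, kernel_size)(v_in)\n",
+    "    # Vertical to horizontal: shift the vertical features down one row.\n",
+    "    v_to_h = keras.layers.Cropping2D(((0, 1), 0))(\n",
+    "        keras.layers.ZeroPadding2D(((1, 0), 0))(v))\n",
+    "    # Horizontal stack: mask 'A' in the first layer, 'B' afterwards.\n",
+    "    h = MaskedConv2D('B', 2 * filters, (1, kernel_size))(h_in) + v_to_h\n",
+    "    # Gated activation unit: tanh(W_f * x) * sigmoid(W_g * x).\n",
+    "    v_tanh, v_sig = tf.split(v, 2, axis=-1)\n",
+    "    h_tanh, h_sig = tf.split(h, 2, axis=-1)\n",
+    "    return (tf.math.tanh(v_tanh) * tf.math.sigmoid(v_sig),\n",
+    "            tf.math.tanh(h_tanh) * tf.math.sigmoid(h_sig))\n",
+    "```"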
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "kZmui789Br2B"
+   },
+   "source": [
+    "# Cropped and padded convolutions"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Vertical stack\n",
+    "\n",
+    "First, let's check out this operation: with some strategic padding and a convolution in \"valid\" mode, we can achieve the same result as the masked version. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 204
+    },
+    "colab_type": "code",
+    "id": "fH3I0lfoPdcH",
+    "outputId": "f142569c-f99d-4244-ab5d-1887243d0d29"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "INPUT\n",
+      "[[1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]]\n",
+      "\n",
+      "PADDED INPUT\n",
+      "[[0. 0. 0. 0. 0. 0. 0.]\n",
+      " [0. 1. 1. 1. 1. 1. 0.]\n",
+      " [0. 1. 1. 1. 1. 1. 0.]\n",
+      " [0. 1. 1. 1. 1. 1. 0.]\n",
+      " [0. 1. 1. 1. 1. 1. 0.]\n",
+      " [0. 1. 1. 1. 1. 1. 0.]]\n",
+      "\n",
+      "CONV FILTER\n",
+      "[[1. 1. 1.]\n",
+      " [1. 1. 1.]]\n",
+      "\n",
+      "OUTPUT\n",
+      "[[2. 3. 3. 3. 2.]\n",
+      " [4. 6. 6. 6. 4.]\n",
+      " [4. 6. 6. 6. 4.]\n",
+      " [4. 6. 6. 6. 4.]\n",
+      " [4. 6. 6. 6. 4.]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "kernel_h = 2\n",
+    "kernel_w = 3\n",
+    "\n",
+    "kernel_size = (kernel_h, kernel_w)\n",
+    "\n",
+    "padding = keras.layers.ZeroPadding2D(padding=((kernel_h - 1, 0), (int((kernel_w - 1) / 2), int((kernel_w - 1) / 2))))\n",
+    "\n",
+    "res = padding(test_ones_2d)\n",
+    "\n",
+    "conv = keras.layers.Conv2D(filters=1,\n",
+    "                           kernel_size=kernel_size,\n",
+    "                           strides=1,\n",
+    "                           padding='valid',\n",
+    "                           kernel_initializer='ones',\n",
+    "                           bias_initializer='zeros')\n",
+    "\n",
+    "result_v = conv(res)\n",
+    "\n",
+    "print('INPUT')\n",
+    "print(test_ones_2d.squeeze())\n",
+    "print('')\n",
+    "print('PADDED INPUT')\n",
+    "print(res.numpy().squeeze())\n",
+    "print('')\n",
+    "print('CONV FILTER')\n",
+    "print(conv.weights[0].numpy().squeeze())\n",
+    "print('')\n",
+    "print('OUTPUT')\n",
+    "print(result_v.numpy().squeeze())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "RfFuwKlrP2JU"
+   },
+   "source": [
+    "Now, let's implement a layer that encapsulates all the previous operations.\n",
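+    "\n",
+    "The only arithmetic the layer has to get right is the padding; for the 2x3 kernel used above, a quick, hypothetical check:\n",
+    "\n",
+    "```python\n",
+    "kernel_h, kernel_w = 2, 3\n",
+    "pad_top = kernel_h - 1          # 1 row: the pixels strictly above\n",
+    "pad_side = (kernel_w - 1) // 2  # 1 column each side keeps the width\n",
+    "print(pad_top, pad_side)  # 1 1, matching the ZeroPadding2D call above\n",
+    "```"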
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "PxBNsvzhB1ec"
+   },
+   "outputs": [],
+   "source": [
+    "class VerticalConv2D(keras.layers.Conv2D):\n",
+    "    \"\"\"https://github.com/JesseFarebro/PixelCNNPP/blob/master/layers/VerticalConv2D.py\"\"\"\n",
+    "\n",
+    "    def __init__(self,\n",
+    "                 filters,\n",
+    "                 kernel_size,\n",
+    "                 **kwargs):\n",
+    "        if not isinstance(kernel_size, tuple):\n",
+    "            # e.g. 3 -> (2, 3): keep the rows above, drop the rows below\n",
+    "            kernel_size = (kernel_size // 2 + 1, kernel_size)\n",
+    "\n",
+    "        super(VerticalConv2D, self).__init__(filters, kernel_size, **kwargs)\n",
+    "\n",
+    "        self.pad = tf.keras.layers.ZeroPadding2D(\n",
+    "            (\n",
+    "                (kernel_size[0] - 1, 0),  # Top, Bottom\n",
+    "                (kernel_size[1] // 2, kernel_size[1] // 2),  # Left, Right\n",
+    "            )\n",
+    "        )\n",
+    "\n",
+    "    def call(self, inputs):\n",
+    "        inputs = self.pad(inputs)\n",
+    "        output = super(VerticalConv2D, self).call(inputs)\n",
+    "\n",
+    "        return output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "AfyRyUmTNYZ8"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "INPUT\n",
+      "[[1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]]\n",
+      "\n",
+      "CONV FILTER\n",
+      "[[1. 1. 1.]\n",
+      " [1. 1. 1.]]\n",
+      "\n",
+      "OUTPUT\n",
+      "[[2. 3. 3. 3. 2.]\n",
+      " [4. 6. 6. 6. 4.]\n",
+      " [4. 6. 6. 6. 4.]\n",
+      " [4. 6. 6. 6. 4.]\n",
+      " [4. 6. 6. 6. 4.]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "kernel_h = 2\n",
+    "kernel_w = 3\n",
+    "\n",
+    "kernel_size = (kernel_h, kernel_w)\n",
+    "\n",
+    "conv = VerticalConv2D(filters=1,\n",
+    "                      kernel_size=kernel_size,\n",
+    "                      strides=1,\n",
+    "                      padding='valid',\n",
+    "                      kernel_initializer='ones',\n",
+    "                      bias_initializer='zeros')\n",
+    "\n",
+    "result_v = conv(test_ones_2d)\n",
+    "\n",
+    "print('INPUT')\n",
+    "print(test_ones_2d.squeeze())\n",
+    "print('')\n",
+    "print('CONV FILTER')\n",
+    "print(conv.weights[0].numpy().squeeze())\n",
+    "print('')\n",
+    "print('OUTPUT')\n",
+    "print(result_v.numpy().squeeze())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "SvWpzQFGEGGm"
+   },
+   "source": [
+    "## Vertical to horizontal stack\n",
+    "For this operation, the implementation stays the same."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 119
+    },
+    "colab_type": "code",
+    "id": "5jLEZhYtOZgi",
+    "outputId": "d293ab6f-7fa6-4aec-cc68-9ab3ae7185dd"
+   },
+   "outputs": [
    {
-    "cell_type": "code",
-    "metadata": {
-     "id": "5jLEZhYtOZgi",
-     "colab_type": "code",
-     "outputId": "d293ab6f-7fa6-4aec-cc68-9ab3ae7185dd",
-     "colab": {
-      "base_uri": "https://localhost:8080/",
-      "height": 119
-     }
-    },
-    "source": [
-     "padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n",
-     "cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n",
-     "\n",
-     "x = padding(result_v)\n",
-     "result = cropping(x)\n",
-     "\n",
-     "print('OUTPUT')\n",
-     "print(result.numpy().squeeze())"
-    ],
-    "execution_count": 0,
-    "outputs": [
-     {
-      "output_type": "stream",
-      "text": [
-       "OUTPUT\n",
-       "[[0. 0. 0. 0. 0.]\n",
-       " [2. 3. 3. 3. 2.]\n",
-       " [4. 6. 6. 6. 4.]\n",
-       " [4. 6. 6. 6. 4.]\n",
-       " [4. 6. 6. 6. 4.]]\n"
-      ],
-      "name": "stdout"
-     }
-    ]
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "INPUT\n",
+      "[[2. 3. 3. 3. 2.]\n",
+      " [4. 6. 6. 6. 4.]\n",
+      " [4. 6. 6. 6. 4.]\n",
+      " [4. 6. 6. 6. 4.]\n",
+      " [4. 6. 6. 6. 4.]]\n",
+      "\n",
+      "OUTPUT\n",
+      "[[0. 0. 0. 0. 0.]\n",
+      " [2. 3. 3. 3. 2.]\n",
+      " [4. 6. 6. 6. 4.]\n",
+      " [4. 6. 6. 6. 4.]\n",
+      " [4. 6. 6. 6. 4.]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "padding = keras.layers.ZeroPadding2D(padding=((1, 0), 0))\n",
+    "cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n",
+    "\n",
+    "x = padding(result_v)\n",
+    "result = cropping(x)\n",
+    "\n",
+    "print('INPUT')\n",
+    "print(result_v.numpy().squeeze())\n",
+    "print('')\n",
+    "print('OUTPUT')\n",
+    "print(result.numpy().squeeze())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "MQLekDEaEUUT"
+   },
+   "source": [
+    "## Horizontal stack - convolution layer with mask type \"A\"\n",
+    "Again, let's check each operation step by step."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 170
+    },
+    "colab_type": "code",
+    "id": "bHiwKZniEk5A",
+    "outputId": "ebd659c5-d899-4d6f-9c81-5c821cf4ea61"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "INPUT\n",
+      "[[1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]]\n",
+      "\n",
+      "CONV FILTER\n",
+      "1.0\n",
+      "\n",
+      "CONVOLUTION RESULT\n",
+      "[[1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]]\n",
+      "\n",
+      "PADDED RESULT\n",
+      "[[0. 1. 1. 1. 1. 1.]\n",
+      " [0. 1. 1. 1. 1. 1.]\n",
+      " [0. 1. 1. 1. 1. 1.]\n",
+      " [0. 1. 1. 1. 1. 1.]\n",
+      " [0. 1. 1. 1. 1. 1.]]\n",
+      "\n",
+      "CROPPED RESULT\n",
+      "[[0. 1. 1. 1. 1.]\n",
+      " [0. 1. 1. 1. 1.]\n",
+      " [0. 1. 1. 1. 1.]\n",
+      " [0. 1. 1. 1. 1.]\n",
+      " [0. 1. 1. 1. 1.]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "kernel_size = (1, 1)\n",
+    "conv = keras.layers.Conv2D(filters=1,\n",
+    "                           kernel_size=kernel_size,\n",
+    "                           strides=1,\n",
+    "                           kernel_initializer='ones',\n",
+    "                           bias_initializer='zeros')\n",
+    "\n",
+    "padding = keras.layers.ZeroPadding2D(padding=(0, (1, 0)))\n",
+    "cropping = keras.layers.Cropping2D(cropping=(0, (0, 1)))\n",
+    "\n",
+    "res = conv(test_ones_2d)\n",
+    "res_2 = padding(res)\n",
+    "res_3 = cropping(res_2)\n",
+    "\n",
+    "print('INPUT')\n",
+    "print(test_ones_2d.squeeze())\n",
+    "print('')\n",
+    "print('CONV FILTER')\n",
+    "print(conv.weights[0].numpy().squeeze())\n",
+    "print('')\n",
+    "print('CONVOLUTION RESULT')\n",
+    "print(res.numpy().squeeze())\n",
+    "print('')\n",
+    "print('PADDED RESULT')\n",
+    "print(res_2.numpy().squeeze())\n",
+    "print('')\n",
+    "print('CROPPED RESULT')\n",
+    "print(res_3.numpy().squeeze())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note: since our test input has only one channel, the 1x1 convolution appears to leave the values unchanged."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "IvmGrDziEadf"
+   },
+   "source": [
+    "## Horizontal stack - convolution layer with mask type \"B\"\n",
+    "The step-by-step construction of the mask type \"B\" convolution layer is a little different.\n",
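+    "\n",
+    "Once the steps below are clear, a numerical check such as this sketch (reusing `test_ones_2d` and the layers already defined in this notebook) confirms that both implementations agree:\n",
+    "\n",
+    "```python\n",
+    "# Masked 'B' conv vs. (1, 2) 'valid' conv with one zero column on the left.\n",
+    "masked = MaskedConv2D(mask_type='B', filters=1, kernel_size=(1, 3),\n",
+    "                      kernel_initializer='ones', bias_initializer='zeros')\n",
+    "pad = keras.layers.ZeroPadding2D(padding=(0, (1, 0)))\n",
+    "conv = keras.layers.Conv2D(filters=1, kernel_size=(1, 2), padding='valid',\n",
+    "                           kernel_initializer='ones', bias_initializer='zeros')\n",
+    "print(np.allclose(masked(test_ones_2d).numpy(),\n",
+    "                  conv(pad(test_ones_2d)).numpy()))  # expected: True\n",
+    "```"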
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 170 }, + "colab_type": "code", + "id": "pRKJFE4TFx4I", + "outputId": "75d34ade-0983-49a5-f157-b98975c22560" + }, + "outputs": [ { - "cell_type": "code", - "metadata": { - "id": "bHiwKZniEk5A", - "colab_type": "code", - "outputId": "ebd659c5-d899-4d6f-9c81-5c821cf4ea61", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 170 - } - }, - "source": [ - "kernel_size=(1, 1)\n", - "kernel_h, kernel_w = kernel_size\n", - "\n", - "conv = keras.layers.Conv2D(filters=1,\n", - " kernel_size=kernel_size,\n", - " strides=1,\n", - " kernel_initializer='ones', \n", - " bias_initializer='zeros')\n", - "\n", - "padding = keras.layers.ZeroPadding2D(padding=(0,(1,0)))\n", - "cropping = keras.layers.Cropping2D(cropping=(0, (0, 1)))\n", - "\n", - "x = conv(test_ones_2d)\n", - "x = padding(x)\n", - "result = cropping(x)\n", - "\n", - "print('KERNEL')\n", - "print(conv.weights[0].numpy().squeeze())\n", - "print('')\n", - "print('OUTPUT')\n", - "print(result.numpy().squeeze())" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "KERNEL\n", - "1.0\n", - "\n", - "OUTPUT\n", - "[[0. 1. 1. 1. 1.]\n", - " [0. 1. 1. 1. 1.]\n", - " [0. 1. 1. 1. 1.]\n", - " [0. 1. 1. 1. 1.]\n", - " [0. 1. 1. 1. 1.]]\n" - ], - "name": "stdout" - } - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "INPUT\n", + "[[1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]]\n", + "\n", + "PADDED INPUT\n", + "[[0. 1. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1. 1.]]\n", + "\n", + "CONV FILTER\n", + "[1. 1.]\n", + "\n", + "RESULT\n", + "[[1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 
2.]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "kernel_size = (1, 2)\n",
+    "kernel_h, kernel_w = kernel_size\n",
+    "\n",
+    "padding = keras.layers.ZeroPadding2D(padding=((int((kernel_h - 1) / 2), int((kernel_h - 1) / 2)), (kernel_w - 1, 0)))\n",
+    "conv = keras.layers.Conv2D(filters=1,\n",
+    "                           kernel_size=kernel_size,\n",
+    "                           strides=1,\n",
+    "                           padding='valid',\n",
+    "                           kernel_initializer='ones',\n",
+    "                           bias_initializer='zeros')\n",
+    "\n",
+    "res = padding(test_ones_2d)\n",
+    "result = conv(res)\n",
+    "\n",
+    "print('INPUT')\n",
+    "print(test_ones_2d.squeeze())\n",
+    "print('')\n",
+    "print('PADDED INPUT')\n",
+    "print(res.numpy().squeeze())\n",
+    "print('')\n",
+    "print('CONV FILTER')\n",
+    "print(conv.weights[0].numpy().squeeze())\n",
+    "print('')\n",
+    "print('RESULT')\n",
+    "print(result.numpy().squeeze())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here too, we implement a layer version that encapsulates these operations."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class HorizontalConv2D(keras.layers.Conv2D):\n",
+    "    def __init__(self,\n",
+    "                 filters,\n",
+    "                 kernel_size,\n",
+    "                 **kwargs):\n",
+    "        if not isinstance(kernel_size, tuple):\n",
+    "            # Note: this int branch is never exercised in this notebook;\n",
+    "            # an explicit (1, n) tuple is always passed for the horizontal stack.\n",
+    "            kernel_size = (kernel_size // 2 + 1,) * 2\n",
+    "\n",
+    "        super(HorizontalConv2D, self).__init__(filters, kernel_size, **kwargs)\n",
+    "        self.pad = tf.keras.layers.ZeroPadding2D(\n",
+    "            (\n",
+    "                (kernel_size[0] - 1, 0),  # (Top, Bottom)\n",
+    "                (kernel_size[1] - 1, 0),  # (Left, Right)\n",
+    "            )\n",
+    "        )\n",
+    "\n",
+    "    def call(self, inputs):\n",
+    "        inputs = self.pad(inputs)\n",
+    "        outputs = super(HorizontalConv2D, self).call(inputs)\n",
+    "\n",
+    "        return outputs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "INPUT\n",
+      "[[1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]\n",
+      " [1. 1. 1. 1. 1.]]\n",
+      "\n",
+      "CONV FILTER\n",
+      "[1. 1.]\n",
+      "\n",
+      "RESULT\n",
+      "[[1. 2. 2. 2. 2.]\n",
+      " [1. 2. 2. 2. 2.]\n",
+      " [1. 2. 2. 2. 2.]\n",
+      " [1. 2. 2. 2. 2.]\n",
+      " [1. 2. 2. 2. 2.]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "kernel_size = (1, 2)\n",
+    "conv = HorizontalConv2D(filters=1,\n",
+    "                        kernel_size=kernel_size,\n",
+    "                        strides=1,\n",
+    "                        kernel_initializer='ones',\n",
+    "                        bias_initializer='zeros')\n",
+    "\n",
+    "result = conv(test_ones_2d)\n",
+    "\n",
+    "print('INPUT')\n",
+    "print(test_ones_2d.squeeze())\n",
+    "print('')\n",
+    "print('CONV FILTER')\n",
+    "print(conv.weights[0].numpy().squeeze())\n",
+    "print('')\n",
+    "print('RESULT')\n",
+    "print(result.numpy().squeeze())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Execution time\n",
+    "Now we will compare the time it takes to perform each convolutional operation.\n",
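+    "\n",
+    "One caveat about the methodology below: the first call to a `tf.function` also pays a one-off tracing cost. A warm-up call before the timed loop, as in this hypothetical snippet, would keep that cost out of the statistics:\n",
+    "\n",
+    "```python\n",
+    "# Trace the graph once so compilation time is not measured.\n",
+    "warmup = np.random.rand(1, 256, 256, 1).astype('float32')\n",
+    "conv_fn(warmup)  # conv_fn: any of the @tf.function wrappers below\n",
+    "```"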
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "def measure_time(conv_fn):\n", + " exec_time = []\n", + " n_iter = 100\n", + " for _ in range(n_iter):\n", + " test_input = np.random.rand(128, 256, 256, 1).astype('float32') \n", + " start = time.time()\n", + " conv_fn(test_input)\n", + " exec_time.append(time.time() - start)\n", + " exec_time = np.array(exec_time, dtype='float32')\n", + " return exec_time.mean(), exec_time.std()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Vertical stack" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "metadata": { - "id": "pRKJFE4TFx4I", - "colab_type": "code", - "outputId": "75d34ade-0983-49a5-f157-b98975c22560", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 170 - } - }, - "source": [ - "kernel_size=(1, 2)\n", - "kernel_h, kernel_w = kernel_size\n", - "\n", - "padding2 = keras.layers.ZeroPadding2D(padding=((int((kernel_h-1)/2),int((kernel_h-1)/2)), (kernel_w-1, 0)))\n", - "conv = keras.layers.Conv2D(filters=1,\n", - " kernel_size=kernel_size,\n", - " strides=1,\n", - " padding='valid',\n", - " kernel_initializer='ones', \n", - " bias_initializer='zeros')\n", - "\n", - "\n", - "x = padding2(test_ones_2d)\n", - "result = conv(x)\n", - "\n", - "print('KERNEL')\n", - "print(conv.weights[0].numpy().squeeze())\n", - "print('')\n", - "print('OUTPUT')\n", - "print(result.numpy().squeeze())" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "KERNEL\n", - "[1. 1.]\n", - "\n", - "OUTPUT\n", - "[[1. 2. 2. 2. 2.]\n", - " [1. 2. 2. 2. 2.]\n", - " [1. 2. 2. 2. 2.]\n", - " [1. 2. 2. 2. 2.]\n", - " [1. 2. 2. 2. 
2.]]\n" - ], - "name": "stdout" - } - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Vertical stack\n", + "Masked convolution: 0.01410292 +- 0.00891058 seconds\n", + "Cropped padded convolution: 0.01386628 +- 0.00675169 seconds\n" + ] + } + ], + "source": [ + "mask_type = 'V'\n", + "kernel_size = (3, 3)\n", + "masked_conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=32,\n", + " kernel_size=kernel_size,\n", + " padding='same',\n", + " kernel_initializer='ones',\n", + " bias_initializer='zeros')\n", + "\n", + "@tf.function\n", + "def test_masked_fn(x):\n", + " _ = masked_conv(x)\n", + " \n", + "\n", + "masked_time = measure_time(test_masked_fn)\n", + "# ----------------------------------------------------------------\n", + "\n", + "kernel_size = (2, 3)\n", + "cropped_conv = VerticalConv2D(filters=32,\n", + " kernel_size=kernel_size,\n", + " strides=1,\n", + " padding='valid',\n", + " kernel_initializer='ones',\n", + " bias_initializer='zeros')\n", + "\n", + "@tf.function\n", + "def test_cropped_fn(x):\n", + " _ = cropped_conv(x)\n", + "\n", + "cropped_time = measure_time(test_cropped_fn)\n", + "# ----------------------------------------------------------------\n", + "\n", + "print(\"Vertical stack\")\n", + "print(f\"Masked convolution: {masked_time[0]:.8f} +- {masked_time[1]:.8f} seconds\")\n", + "print(f\"Cropped padded convolution: {cropped_time[0]:.8f} +- {cropped_time[1]:.8f} seconds\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Horizontal stack - convolution layer with mask type \"A\"" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "swYJ4XMofUWv", - "colab_type": "text" - }, - "source": [ - "REFERENCES\n", - "\n", - "https://wiki.math.uwaterloo.ca/statwiki/index.php?title=STAT946F17/Conditional_Image_Generation_with_PixelCNN_Decoders#Gated_PixelCNN\n", - "\n", - "https://www.slideshare.net/suga93/conditional-image-generation-with-pixelcnn-decoders\n", - "\n", - "https://www.youtube.com/watch?v=1BURwCCYNEI" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Horizontal stack - convolution layer with mask type 'A'\n", + "Masked convolution: 0.01360846 +- 0.00381987 seconds\n", + "Cropped padded convolution: 0.01365352 +- 0.00476047 seconds\n" + ] + } + ], + "source": [ + "mask_type = 'A'\n", + "kernel_size = (1, 3)\n", + "masked_conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=1,\n", + " kernel_size=kernel_size,\n", + " padding='same',\n", + " kernel_initializer='ones',\n", + " bias_initializer='zeros')\n", + "\n", + "@tf.function\n", + "def test_masked_fn(x):\n", + " _ = masked_conv(x)\n", + " \n", + "masked_time = measure_time(test_masked_fn)\n", + "# ----------------------------------------------------------------\n", + "\n", + "kernel_size = (1, 1)\n", + "conv = keras.layers.Conv2D(filters=1,\n", + " kernel_size=kernel_size,\n", + " strides=1,\n", + " kernel_initializer='ones',\n", + " bias_initializer='zeros')\n", + "\n", + "padding = keras.layers.ZeroPadding2D(padding=(0, (1, 0)))\n", + "cropping = keras.layers.Cropping2D(cropping=(0, (0, 1)))\n", + "\n", + "@tf.function\n", + "def test_cropped_fn(x):\n", + " x = conv(x)\n", + " x = padding(x)\n", + " x = cropping(x)\n", + "\n", + "cropped_time = measure_time(test_cropped_fn)\n", + "# ----------------------------------------------------------------\n", + "\n", + "print(\"Horizontal stack - convolution layer with mask type 'A'\")\n", + 
"print(f\"Masked convolution: {masked_time[0]:.8f} +- {masked_time[1]:.8f} seconds\")\n", + "print(f\"Cropped padded convolution: {cropped_time[0]:.8f} +- {cropped_time[1]:.8f} seconds\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Horizontal stack - convolution layer with mask type \"B\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "metadata": { - "id": "jTI9ts7i7Wch", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": 0, - "outputs": [] + "name": "stdout", + "output_type": "stream", + "text": [ + "Horizontal stack - convolution layer with mask type 'B'\n", + "Masked convolution: 0.01353339 +- 0.00374499 seconds\n", + "Cropped padded convolution: 0.01384839 +- 0.00734248 seconds\n" + ] } - ] -} \ No newline at end of file + ], + "source": [ + "mask_type = 'B'\n", + "kernel_size = (1, 3)\n", + "masked_conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=1,\n", + " kernel_size=kernel_size,\n", + " padding='same',\n", + " kernel_initializer='ones',\n", + " bias_initializer='zeros')\n", + "\n", + "@tf.function\n", + "def test_masked_fn(x):\n", + " _ = masked_conv(x)\n", + " \n", + "masked_time = measure_time(test_masked_fn)\n", + "# ----------------------------------------------------------------\n", + "\n", + "kernel_size = (1, 2)\n", + "cropped_conv = HorizontalConv2D(filters=1,\n", + " kernel_size=kernel_size,\n", + " strides=1,\n", + " kernel_initializer='ones',\n", + " bias_initializer='zeros')\n", + "\n", + "@tf.function\n", + "def test_cropped_fn(x):\n", + " _ = cropped_conv(x)\n", + "\n", + "cropped_time = measure_time(test_cropped_fn)\n", + "# ----------------------------------------------------------------\n", + "\n", + "print(\"Horizontal stack - convolution layer with mask type 'B'\")\n", + "print(f\"Masked convolution: {masked_time[0]:.8f} +- {masked_time[1]:.8f} seconds\")\n", + "print(f\"Cropped padded convolution: {cropped_time[0]:.8f} +- {cropped_time[1]:.8f} seconds\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Altough its looks like cropped is better in the vertical convolution, the difference does not to look very significant." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "swYJ4XMofUWv" + }, + "source": [ + "# REFERENCES\n", + "\n", + "https://wiki.math.uwaterloo.ca/statwiki/index.php?title=STAT946F17/Conditional_Image_Generation_with_PixelCNN_Decoders#Gated_PixelCNN\n", + "\n", + "https://www.slideshare.net/suga93/conditional-image-generation-with-pixelcnn-decoders\n", + "\n", + "https://www.youtube.com/watch?v=1BURwCCYNEI" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "gated_pixelcnn_vs_cropped.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/WIP/6-gated_pixelcnn_cropped/masked_vs_cropped_even_filter_size.ipynb b/WIP/6-gated_pixelcnn_cropped/masked_vs_cropped_even_filter_size.ipynb new file mode 100644 index 0000000..42c3205 --- /dev/null +++ b/WIP/6-gated_pixelcnn_cropped/masked_vs_cropped_even_filter_size.ipynb @@ -0,0 +1,113 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tests with kernel_size 4" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mask_type = 'V'\n", + "kernel_size=(4, 4)\n", + "\n", + "padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n", + "\n", + "conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=1,\n", + " kernel_size=kernel_size, \n", + " padding='same',\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", + "\n", + "x = padding(test_ones_2d)\n", + "x = conv(x)\n", + "result = cropping(x)\n", + "\n", + "print('MASK')\n", + "print(conv.mask.numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result.numpy().squeeze())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mask_type = 'A'\n", + "kernel_size=(1, 4)\n", + "\n", + "conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=1,\n", + " kernel_size=kernel_size, \n", + " padding='same',\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "result = conv(test_ones_2d)\n", + "\n", + "print('MASK')\n", + "print(conv.mask.numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result.numpy().squeeze())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mask_type = 'B'\n", + "kernel_size=(1, 4)\n", + "\n", + "conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=1,\n", + " kernel_size=kernel_size, \n", + " padding='same',\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "result = conv(test_ones_2d)\n", + "\n", + "print('MASK')\n", + "print(conv.mask.numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result.numpy().squeeze())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": 
"text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}