diff --git a/3 - PixelCNNs Blind spot and Gated PixelCNNs/PixelCNN_Blind_spot.ipynb b/3 - PixelCNNs Blind spot and Gated PixelCNNs/PixelCNN_Blind_spot.ipynb new file mode 100644 index 0000000..d582176 --- /dev/null +++ b/3 - PixelCNNs Blind spot and Gated PixelCNNs/PixelCNN_Blind_spot.ipynb @@ -0,0 +1,1108 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "PixelCNN - Blind_spot", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "GDNiXfftap_W" + }, + "source": [ + "# Autoregressive models - PixelCNNs blind spot in the receptive field and how to fix it " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gU9fNJ8F2JP6" + }, + "source": [ + "In our previous notebooks, we trained a generative model to create black and white drawings of numbers and to generate coloured images. Now we are going to look into the blind spot, one of the biggest drawbacks of PixelCNNs, and how to fix it. We will also define the Gated PixelCNN and analyse its performance on the MNIST dataset." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Q6fjrw1KDC0p" + }, + "source": [ + "This implementation uses TensorFlow 2.0. 
We start by installing and importing the code dependencies.\n", + "\n", + "*Note: Here we are using float64 to get more precise values of the gradients and avoid false values.\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "QZxm9F41DBSG" + }, + "source": [ + "%tensorflow_version 2.x" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "w8mSPfpIdE7L" + }, + "source": [ + "import random as rn\n", + "\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.ticker import FixedLocator\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "from tensorflow import nn\n", + "from tensorflow.keras import initializers\n", + "from tensorflow.keras.utils import Progbar\n", + "\n", + "tf.keras.backend.set_floatx('float64')" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k110FfjlDuMx" + }, + "source": [ + "Like in the previous notebooks we are defining the Convolutional layers with masks and the ResidualBlocks. The Residual block is an important building block of Oord's seminal paper that originated the PixelCNN.\n", + "\n", + "*Note: Here we removed the ReLU activations to not mess with the gradients while we are investigating them." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "q-NdMaYseHKw" + }, + "source": [ + "class ResidualBlock(keras.Model):\n", + " \"\"\"Residual blocks that compose pixelCNN\n", + "\n", + " Blocks of layers with 3 convolutional layers and one residual connection.\n", + " Based on Figure 5 from [1] where h indicates number of filters.\n", + "\n", + " Refs:\n", + " [1] - Oord, A. V. D., Kalchbrenner, N., & Kavukcuoglu, K. (2016). Pixel recurrent\n", + " neural networks. 
arXiv preprint arXiv:1601.06759.\n", + " \"\"\"\n", + "\n", + " def __init__(self, h):\n", + " super(ResidualBlock, self).__init__(name='')\n", + "\n", + " self.conv2a = keras.layers.Conv2D(filters=h, kernel_size=1, strides=1)\n", + " self.conv2b = MaskedConv2D(mask_type='B', filters=h, kernel_size=3, strides=1)\n", + " self.conv2c = keras.layers.Conv2D(filters=2 * h, kernel_size=1, strides=1)\n", + "\n", + " def call(self, input_tensor):\n", + "# x = nn.relu(input_tensor)\n", + "# x = self.conv2a(x)\n", + " x = self.conv2a(input_tensor)\n", + "\n", + "# x = nn.relu(x)\n", + " x = self.conv2b(x)\n", + "\n", + "# x = nn.relu(x)\n", + " x = self.conv2c(x)\n", + "\n", + " x += input_tensor\n", + " return x" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "toxhCFUEbDEH" + }, + "source": [ + "## Blind Spots" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "00dbW6maaOad" + }, + "source": [ + "As mentioned above, the original PixelCNN had a problem with its masked convolution that led to a blind spot. Given a specific pixel, the masked convolution cannot capture the information of all previous pixels. As we can see in the image below, the information in *j* is never used to predict *m*, and the information in *j, n, o* is not used to predict *q*. \n", + "\n", + "![masks.png]()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QKIiWy5z7Nnf" + }, + "source": [ + "Let's start by implementing the convolution and looking at how the blind spot propagates. You should be familiar with the class below, as we have used it in the previous notebooks. 
" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "vhhJHjj05yhN" + }, + "source": [ + "class MaskedConv2D(keras.layers.Layer):\n", + " \"\"\"Convolutional layers with masks.\n", + "\n", + " Convolutional layers with simple implementation of masks type A and B for\n", + " autoregressive models.\n", + "\n", + " Arguments:\n", + " mask_type: one of `\"A\"` or `\"B\".`\n", + " filters: Integer, the dimensionality of the output space\n", + " (i.e. the number of output filters in the convolution).\n", + " kernel_size: An integer or tuple/list of 2 integers, specifying the\n", + " height and width of the 2D convolution window.\n", + " Can be a single integer to specify the same value for\n", + " all spatial dimensions.\n", + " strides: An integer or tuple/list of 2 integers,\n", + " specifying the strides of the convolution along the height and width.\n", + " Can be a single integer to specify the same value for\n", + " all spatial dimensions.\n", + " Specifying any stride value != 1 is incompatible with specifying\n", + " any `dilation_rate` value != 1.\n", + " padding: one of `\"valid\"` or `\"same\"` (case-insensitive).\n", + " kernel_initializer: Initializer for the `kernel` weights matrix.\n", + " bias_initializer: Initializer for the bias vector.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " mask_type,\n", + " filters,\n", + " kernel_size,\n", + " strides=1,\n", + " padding='same',\n", + " kernel_initializer='glorot_uniform',\n", + " bias_initializer='zeros'):\n", + " super(MaskedConv2D, self).__init__()\n", + "\n", + " assert mask_type in {'A', 'B'}\n", + " self.mask_type = mask_type\n", + "\n", + " self.filters = filters\n", + " self.kernel_size = kernel_size\n", + " self.strides = strides\n", + " self.padding = padding.upper()\n", + " self.kernel_initializer = initializers.get(kernel_initializer)\n", + " self.bias_initializer = initializers.get(bias_initializer)\n", + "\n", + " def build(self, input_shape):\n", + " self.kernel = 
self.add_weight('kernel',\n", + " shape=(self.kernel_size,\n", + " self.kernel_size,\n", + " int(input_shape[-1]),\n", + " self.filters),\n", + " initializer=self.kernel_initializer,\n", + " trainable=True)\n", + "\n", + " self.bias = self.add_weight('bias',\n", + " shape=(self.filters,),\n", + " initializer=self.bias_initializer,\n", + " trainable=True)\n", + "\n", + " center = self.kernel_size // 2\n", + "\n", + " mask = np.ones(self.kernel.shape, dtype=np.float64)\n", + " mask[center, center + (self.mask_type == 'B'):, :, :] = 0.\n", + " mask[center + 1:, :, :, :] = 0.\n", + "\n", + " self.mask = tf.constant(mask, dtype=tf.float64, name='mask')\n", + "\n", + " def call(self, input):\n", + " masked_kernel = tf.math.multiply(self.mask, self.kernel)\n", + " x = nn.conv2d(input,\n", + " masked_kernel,\n", + " strides=[1, self.strides, self.strides, 1],\n", + " padding=self.padding)\n", + " x = nn.bias_add(x, self.bias)\n", + " return x" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "tj1b_aLmeKl9" + }, + "source": [ + "def plot_receptive_field(model, data):\n", + " \"\"\"\n", + " This function allows the visualisation of the receptive field\n", + " \"\"\"\n", + " with tf.GradientTape() as tape:\n", + " tape.watch(data)\n", + " prediction = model(data)\n", + " loss = prediction[:,5,5,0]\n", + "\n", + " gradients = tape.gradient(loss, data)\n", + "\n", + " gradients = np.abs(gradients.numpy().squeeze())\n", + " gradients = (gradients > 0).astype('float64')\n", + " gradients[5, 5] = 0.5\n", + "\n", + " fig = plt.figure()\n", + " ax = fig.add_subplot(1, 1, 1)\n", + "\n", + " plt.xticks(np.arange(0, 10, step=1))\n", + " plt.yticks(np.arange(0, 10, step=1))\n", + " ax.xaxis.set_minor_locator(FixedLocator(np.arange(0.5, 10.5, step=1)))\n", + " ax.yaxis.set_minor_locator(FixedLocator(np.arange(0.5, 10.5, step=1)))\n", + " plt.grid(which=\"minor\")\n", + " plt.imshow(gradients, vmin=0, vmax=1)\n", + " plt.show()" + ], 
+ "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "F16Dw0EkdZAN" + }, + "source": [ + "# --------------------------------------------------------------------------------------------------------------\n", + "height = 10\n", + "width = 10\n", + "n_channel = 1\n", + "\n", + "data = tf.random.normal((1, height, width, n_channel))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "APMjCAwveKo6" + }, + "source": [ + "# 1 layer PixelCNN\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", + "model = tf.keras.Model(inputs=inputs, outputs=x)\n", + "\n", + "plot_receptive_field(model, data)\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "aUHo0N7-6IYa" + }, + "source": [ + "# 2 layer PixelCNN\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", + "x = ResidualBlock(h=1)(x)\n", + "\n", + "model = tf.keras.Model(inputs=inputs, outputs=x)\n", + "\n", + "plot_receptive_field(model, data)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "RFMjAntD6Kot" + }, + "source": [ + "# 3 layer PixelCNN\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", + "x = ResidualBlock(h=1)(x)\n", + "x = ResidualBlock(h=1)(x)\n", + "\n", + "model = tf.keras.Model(inputs=inputs, outputs=x)\n", + "\n", + "plot_receptive_field(model, data)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "zbsQRrol6O6W" + }, + "source": [ + "# 4 layer PixelCNN\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "x = 
MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", + "x = ResidualBlock(h=1)(x)\n", + "x = ResidualBlock(h=1)(x)\n", + "x = ResidualBlock(h=1)(x)\n", + "\n", + "model = tf.keras.Model(inputs=inputs, outputs=x)\n", + "\n", + "plot_receptive_field(model, data)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "uaDjY7gK6aBa" + }, + "source": [ + "# 5 layer PixelCNN\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", + "x = ResidualBlock(h=1)(x)\n", + "x = ResidualBlock(h=1)(x)\n", + "x = ResidualBlock(h=1)(x)\n", + "x = ResidualBlock(h=1)(x)\n", + "\n", + "model = tf.keras.Model(inputs=inputs, outputs=x)\n", + "\n", + "plot_receptive_field(model, data)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "deJlu0x4cesJ" + }, + "source": [ + "## Gated PixelCNN and the blind spot\n", + "\n", + "A big drawback of PixelCNNs is the blind spot in the receptive field. This blind spot results from the mask, whose effect is propagated through the stacked convolutions. To overcome the blind spot problem, the Gated PixelCNN was introduced. This new model removes the blind spot by using two separate convolution stacks (i.e., the horizontal and the vertical stack). \n", + "While the horizontal stack covers the pixels to the left of the predicted pixel in the same row, the vertical stack covers the rows above the predicted pixel.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2zK_0-Ck7_sW" + }, + "source": [ + "![stacks.png]()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_NfGQikYfrL2" + }, + "source": [ + "Let's re-implement the `MaskedConv2D` class, but now using the horizontal and vertical stacks to correct for the blind spot." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "8WqV14wAeC_k" + }, + "source": [ + "# --------------------------------------------------------------------------------------------------------------\n", + "\n", + "class MaskedConv2D(keras.layers.Layer):\n", + " \"\"\"Convolutional layers with masks.\n", + "\n", + " Convolutional layers with simple implementation of masks type A and B for\n", + " autoregressive models.\n", + "\n", + " Arguments:\n", + " mask_type: one of `\"A\"` or `\"B\".`\n", + " filters: Integer, the dimensionality of the output space\n", + " (i.e. the number of output filters in the convolution).\n", + " kernel_size: An integer or tuple/list of 2 integers, specifying the\n", + " height and width of the 2D convolution window.\n", + " Can be a single integer to specify the same value for\n", + " all spatial dimensions.\n", + " strides: An integer or tuple/list of 2 integers,\n", + " specifying the strides of the convolution along the height and width.\n", + " Can be a single integer to specify the same value for\n", + " all spatial dimensions.\n", + " Specifying any stride value != 1 is incompatible with specifying\n", + " any `dilation_rate` value != 1.\n", + " padding: one of `\"valid\"` or `\"same\"` (case-insensitive).\n", + " kernel_initializer: Initializer for the `kernel` weights matrix.\n", + " bias_initializer: Initializer for the bias vector.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " mask_type,\n", + " filters,\n", + " kernel_size,\n", + " strides=1,\n", + " padding='same',\n", + " kernel_initializer='glorot_uniform',\n", + " bias_initializer='zeros'):\n", + " super(MaskedConv2D, self).__init__()\n", + "\n", + " assert mask_type in {'A', 'B', 'V'}\n", + " self.mask_type = mask_type\n", + "\n", + " self.filters = filters\n", + " self.kernel_size = kernel_size\n", + " self.strides = strides\n", + " self.padding = padding.upper()\n", + " self.kernel_initializer = initializers.get(kernel_initializer)\n", + " 
self.bias_initializer = initializers.get(bias_initializer)\n", + "\n", + " def build(self, input_shape):\n", + " kernel_h = self.kernel_size\n", + " kernel_w = self.kernel_size\n", + " self.kernel = self.add_weight('kernel',\n", + " shape=(self.kernel_size,\n", + " self.kernel_size,\n", + " int(input_shape[-1]),\n", + " self.filters),\n", + " initializer=self.kernel_initializer,\n", + " trainable=True)\n", + "\n", + " self.bias = self.add_weight('bias',\n", + " shape=(self.filters,),\n", + " initializer=self.bias_initializer,\n", + " trainable=True)\n", + " mask = np.ones(self.kernel.shape, dtype=np.float64)\n", + "\n", + " # Get centre of the filter for even or odd dimensions\n", + " if kernel_h % 2 != 0:\n", + " center_h = kernel_h // 2\n", + " else:\n", + " center_h = (kernel_h - 1) // 2\n", + "\n", + " if kernel_w % 2 != 0:\n", + " center_w = kernel_w // 2\n", + " else:\n", + " center_w = (kernel_w - 1) // 2\n", + "\n", + " if self.mask_type == 'V':\n", + " mask[center_h + 1:, :, :, :] = 0.\n", + " else:\n", + " mask[:center_h, :, :] = 0.\n", + " mask[center_h, center_w + (self.mask_type == 'B'):, :, :] = 0.\n", + " mask[center_h + 1:, :, :] = 0.\n", + "\n", + " self.mask = tf.constant(mask, dtype=tf.float64, name='mask')\n", + "\n", + " def call(self, input):\n", + " masked_kernel = tf.math.multiply(self.mask, self.kernel)\n", + " x = nn.conv2d(input,\n", + " masked_kernel,\n", + " strides=[1, self.strides, self.strides, 1],\n", + " padding=self.padding)\n", + " x = nn.bias_add(x, self.bias)\n", + " return x" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "WP9iKPeueUQp" + }, + "source": [ + "class GatedBlock(tf.keras.Model):\n", + " \"\"\" Gated block that compose Gated PixelCNN.\"\"\"\n", + "\n", + " def __init__(self, mask_type, filters, kernel_size):\n", + " super(GatedBlock, self).__init__(name='')\n", + "\n", + " self.mask_type = mask_type\n", + " self.vertical_conv = MaskedConv2D(mask_type='V',\n", 
+ " filters=2 * filters,\n", + " kernel_size=kernel_size)\n", + "\n", + " self.horizontal_conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=2 * filters,\n", + " kernel_size=kernel_size)\n", + "\n", + " self.padding = keras.layers.ZeroPadding2D(padding=((1, 0), 0))\n", + " self.cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", + "\n", + " self.v_to_h_conv = keras.layers.Conv2D(filters=2 * filters, kernel_size=1)\n", + "\n", + " self.horizontal_output = keras.layers.Conv2D(filters=filters, kernel_size=1)\n", + "\n", + " def _gate(self, x):\n", + " tanh_preactivation, sigmoid_preactivation = tf.split(x, 2, axis=-1)\n", + " return tf.nn.tanh(tanh_preactivation) * tf.nn.sigmoid(sigmoid_preactivation)\n", + "\n", + " def call(self, input_tensor):\n", + " v = input_tensor[0]\n", + " h = input_tensor[1]\n", + "\n", + " vertical_preactivation = self.vertical_conv(v)\n", + "\n", + " # Shifting vertical stack feature map down before feed into horizontal stack to\n", + " # ensure causality\n", + " v_to_h = self.padding(vertical_preactivation)\n", + " v_to_h = self.cropping(v_to_h)\n", + " v_to_h = self.v_to_h_conv(v_to_h)\n", + "\n", + " horizontal_preactivation = self.horizontal_conv(h)\n", + "\n", + " v_out = self._gate(vertical_preactivation)\n", + "\n", + " horizontal_preactivation = horizontal_preactivation + v_to_h\n", + " h_activated = self._gate(horizontal_preactivation)\n", + " h_activated = self.horizontal_output(h_activated)\n", + "\n", + " if self.mask_type == 'A':\n", + " h_out = h_activated\n", + " elif self.mask_type == 'B':\n", + " h_out = h + h_activated\n", + "\n", + " return v_out, h_out\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "H8TFrKacgCD_" + }, + "source": [ + "# Define the data\n", + "height = 10\n", + "width = 10\n", + "n_channel = 1\n", + "\n", + "data = tf.random.normal((1, height, width, n_channel))" + ], + "execution_count": null, + "outputs": [] + }, + { + 
"cell_type": "code", + "metadata": { + "id": "JCBEaq0pcd9y" + }, + "source": [ + "# 1 layer Gated PixelCNN\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs])\n", + "model = tf.keras.Model(inputs=inputs, outputs=h)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "dhFgEQI6clrA" + }, + "source": [ + "plot_receptive_field(model, data)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jgLy1fCyg1c2" + }, + "source": [ + "Excellent! As we expected, the block considered all the previous pixels in the same row as the analysed pixel, and the two rows above it. The blind spot problem is fixed!\n", + "\n", + "Note that the receptive field is different from the original PixelCNN. In the original PixelCNN, only one row above the analysed pixel influenced its prediction (when using one masked convolution). In the Gated PixelCNN, the authors used a vertical stack with an effective area of 2x3 per vertical convolution. This is not a problem, since the considered pixels are still the ones in past positions. 
We believe the main reason for this choice is to allow an efficient implementation of the masked convolutions without explicit masking (which we will discuss in future posts).\n", + "\n", + "For the next step, we will verify models with 2, 3, 4, and 5 layers." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "S8TfMRxqg5tb" + }, + "source": [ + "# 2 layers Gated PixelCNN\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "model = tf.keras.Model(inputs=inputs, outputs=h)\n", + "\n", + "plot_receptive_field(model, data)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "8ljKvZsrg6yx" + }, + "source": [ + "# 3 layers Gated PixelCNN\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "model = tf.keras.Model(inputs=inputs, outputs=h)\n", + "\n", + "plot_receptive_field(model, data)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "kExc-22eg-MA" + }, + "source": [ + "# 4 layers Gated PixelCNN\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "model = tf.keras.Model(inputs=inputs, outputs=h)\n", + "\n", + "plot_receptive_field(model, data)" + ], + "execution_count": null, + "outputs": [] + }, + { + 
"cell_type": "code", + "metadata": { + "id": "wHImyVeXhBUT" + }, + "source": [ + "# 5 layers Gated PixelCNN\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "model = tf.keras.Model(inputs=inputs, outputs=h)\n", + "\n", + "plot_receptive_field(model, data)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dK7J1k81hJcG" + }, + "source": [ + "As you can notice, the Gated PixelCNN does not create blind spots when adding more and more layers." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Q8l736u3b-7n" + }, + "source": [ + "# Gated CNN \n", + "\n", + "We will now see how the Gated CNN works on real data (i.e., MNIST dataset). 
To make sure that this notebook is reproducible, we will fix the random seed.\n", + "\n", + "This implementation is similar to the PixelCNN implementation we described in the [first notebook](https://colab.research.google.com/github/Mind-the-Pineapple/Autoregressive-models/blob/master/1%20-%20Autoregressive%20Models%20-%20PixelCNN/pixelCNN.ipynb#scrollTo=bU25WyouYYE3).\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "xlN0lIKYhuoB" + }, + "source": [ + "# Defining random seeds\n", + "random_seed = 42\n", + "tf.random.set_seed(random_seed)\n", + "np.random.seed(random_seed)\n", + "rn.seed(random_seed)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Yf6g0Mqnh2JU" + }, + "source": [ + "# Loading data\n", + "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", + "\n", + "height = 28\n", + "width = 28\n", + "n_channel = 1\n", + "\n", + "x_train = x_train.astype('float32') / 255.\n", + "x_test = x_test.astype('float32') / 255.\n", + "\n", + "x_train = x_train.reshape(x_train.shape[0], height, width, n_channel)\n", + "x_test = x_test.reshape(x_test.shape[0], height, width, n_channel)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "H3dU2Ei1h8_5" + }, + "source": [ + "\n", + "In this example, to make the probability distribution of a single pixel easier to define, we quantise the pixels, reducing the number of possible values that a pixel can take. Originally, in the MNIST dataset the pixels are represented by a uint8 variable, which can take values in [0, 255]. In this example, we restrict the image to only 2 different values (0 and 1)." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ZS9zpZXaiBBt" + }, + "source": [ + "def quantise(images, q_levels):\n", + " \"\"\"Quantise image into q levels\"\"\"\n", + " return (np.digitize(images, np.arange(q_levels) / q_levels) - 1).astype('float32')" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "GQou3cV7iCNc" + }, + "source": [ + "# Quantise the input data in q levels\n", + "q_levels = 2\n", + "x_train_quantised = quantise(x_train, q_levels)\n", + "x_test_quantised = quantise(x_test, q_levels)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iKD9dgbGiK95" + }, + "source": [ + "Using the `tf.data` API, we define the input data streams for our model during training and evaluation. In these datasets, the inputs are the images with 2 levels, normalized to lie in [0, 1], and the targets are the categorical pixel values (0 or 1)." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "t1-_yo7XiJMQ" + }, + "source": [ + "# Creating input stream using tf.data API\n", + "batch_size = 192\n", + "train_buf = 10000\n", + "\n", + "train_dataset = tf.data.Dataset.from_tensor_slices((x_train_quantised / (q_levels - 1),\n", + " x_train_quantised.astype('int32')))\n", + "train_dataset = train_dataset.shuffle(buffer_size=train_buf)\n", + "train_dataset = train_dataset.batch(batch_size)\n", + "\n", + "test_dataset = tf.data.Dataset.from_tensor_slices((x_test_quantised / (q_levels - 1),\n", + " x_test_quantised.astype('int32')))\n", + "test_dataset = test_dataset.batch(batch_size)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mOm8Z4okiUqq" + }, + "source": [ + "### PixelCNN architecture" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "pHka3SW5iYIZ" + }, + "source": [ + "# Create Gated PixelCNN model\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=64, kernel_size=3)([inputs, inputs])\n", + "\n", + "for i in range(7):\n", + " v, h = GatedBlock(mask_type='B', filters=64, kernel_size=3)([v, h])\n", + "\n", + "x = keras.layers.Activation(activation='relu')(h)\n", + "x = keras.layers.Conv2D(filters=128, kernel_size=1, strides=1)(x)\n", + "\n", + "x = keras.layers.Activation(activation='relu')(x)\n", + "x = keras.layers.Conv2D(filters=q_levels, kernel_size=1, strides=1)(x)\n", + "\n", + "gated_pixelcnn = tf.keras.Model(inputs=inputs, outputs=x)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gU8Zctn3icaL" + }, + "source": [ + "In this implementation we use a simple Adam optimizer with learning rate decay to train the neural network. The loss function is defined by the cross-entropy (that in this case is equivalent to minimizing the negative log-likelihood of the training data)." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lCC2yDpHifhE" + }, + "source": [ + "# Prepare optimizer and loss function\n", + "lr_decay = 0.999995\n", + "learning_rate = 1e-3\n", + "optimizer = keras.optimizers.Adam(learning_rate=learning_rate)\n", + "\n", + "compute_loss = keras.losses.CategoricalCrossentropy(from_logits=True)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JdbGc1zqiwTr" + }, + "source": [ + "\n", + "The training step starts with the forward propagation through the model. Then, the gradients are calculated, clipped so that their global norm is at most 1, and applied to update the Gated PixelCNN parameters." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4-o_69bbiiO9" + }, + "source": [ + "@tf.function\n", + "def train_step(batch_x, batch_y):\n", + " with tf.GradientTape() as ae_tape:\n", + " logits = gated_pixelcnn(batch_x, training=True)\n", + "\n", + " loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits)\n", + "\n", + " gradients = ae_tape.gradient(loss, gated_pixelcnn.trainable_variables)\n", + " gradients, _ = tf.clip_by_global_norm(gradients, 1.0)\n", + " optimizer.apply_gradients(zip(gradients, gated_pixelcnn.trainable_variables))\n", + "\n", + " return loss" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oihYvXrDi0Ze" + }, + "source": [ + "In this implementation, the training loop runs for `n_epochs` epochs (set to 1 in the cell below).\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "9-D01ij3i2q2" + }, + "source": [ + "# Training loop\n", + "n_epochs = 1\n", + "n_iter = int(np.ceil(x_train_quantised.shape[0] / batch_size))\n", + "for epoch in range(n_epochs):\n", + " progbar = Progbar(n_iter)\n", + " print('Epoch {:}/{:}'.format(epoch + 1, n_epochs))\n", + "\n", + " for i_iter, (batch_x, batch_y) in enumerate(train_dataset):\n", + " optimizer.lr = 
optimizer.lr * lr_decay\n", + " loss = 
train_step(batch_x, batch_y)\n", + "\n", + " progbar.add(1, values=[('loss', loss)])" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VytiisKUi7xn" + }, + "source": [ + "To evaluate the performance of the model, we measured its negative log-likelihood (NLL) in the test set.\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ysZAvORejASO" + }, + "source": [ + "# Test set performance\n", + "test_loss = []\n", + "for batch_x, batch_y in test_dataset:\n", + " logits = gated_pixelcnn(batch_x, training=False)\n", + "\n", + " # Calculate cross-entropy (= negative log-likelihood)\n", + " loss = compute_loss(tf.one_hot(np.squeeze(batch_y), q_levels), logits)\n", + "\n", + " test_loss.append(loss)\n", + "print('nll : {:} nats'.format(np.array(test_loss).mean()))\n", + "print('bits/dim : {:}'.format(np.array(test_loss).mean() / np.log(2)))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "20hQ3iOEjUNX" + }, + "source": [ + "# Test set performance\n", + "test_loss = []\n", + "for batch_x, batch_y in test_dataset:\n", + " logits = gated_pixelcnn(np.squeeze(batch_x), training=False)\n", + "\n", + " # Calculate cross-entropy (= negative log-likelihood)\n", + " loss = compute_loss(tf.one_hot(np.squeeze(batch_y), q_levels), logits)\n", + "\n", + " test_loss.append(loss)\n", + "print('nll : {:} nats'.format(np.array(test_loss).mean()))\n", + "print('bits/dim : {:}'.format(np.array(test_loss).mean() / np.log(2)))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "P2_qYR4TjXgD" + }, + "source": [ + "# Generating new images\n", + "samples = np.zeros((100, height, width, n_channel), dtype='float32')\n", + "for i in range(height):\n", + " for j in range(width):\n", + " logits = gated_pixelcnn(samples)\n", + " next_sample = tf.random.categorical(logits[:, i, j, :], 1)\n", + " samples[:, i, j, 0] = 
(next_sample.numpy() / (q_levels - 1))[:, 0]\n", + "\n", + "fig = plt.figure(figsize=(10, 10))\n", + "for i in range(100):\n", + " ax = fig.add_subplot(10, 10, i + 1)\n", + " ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary)\n", + " plt.xticks(np.array([]))\n", + " plt.yticks(np.array([]))\n", + "plt.savefig('numbers1.png')\n", + "plt.show()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "tL8cNhDFjbOY" + }, + "source": [ + "# Filling occluded images\n", + "occlude_start_row = 14\n", + "num_generated_images = 10\n", + "samples = np.copy(x_test_quantised[0:num_generated_images, :, :, :])\n", + "samples = samples / (q_levels - 1)\n", + "samples[:, occlude_start_row:, :, :] = 0\n", + "\n", + "fig = plt.figure(figsize=(10, 10))\n", + "\n", + "for i in range(10):\n", + " ax = fig.add_subplot(1, 10, i + 1)\n", + " ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary)\n", + " plt.xticks(np.array([]))\n", + " plt.yticks(np.array([]))\n", + "\n", + "for i in range(occlude_start_row, height):\n", + " for j in range(width):\n", + " logits = gated_pixelcnn(samples)\n", + " next_sample = tf.random.categorical(logits[:, i, j, :], 1)\n", + " samples[:, i, j, 0] = (next_sample.numpy() / (q_levels - 1))[:, 0]\n", + "\n", + "fig = plt.figure(figsize=(10, 10))\n", + "\n", + "for i in range(10):\n", + " ax = fig.add_subplot(1, 10, i + 1)\n", + " ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary)\n", + " plt.xticks(np.array([]))\n", + " plt.yticks(np.array([]))\n", + "plt.savefig('numbers2.png')\n", + "plt.show()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LengLN21jB_R" + }, + "source": [ + "Finally, we sampled some images from the trained model. 
First, we sampled from scratch, then we completed partially occluded images.\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "psOigzt7hF4R" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/3 - PixelCNNs Blind spot and Gated PixelCNNs/pixelcnn_blind_spot.py b/3 - PixelCNNs Blind spot and Gated PixelCNNs/pixelcnn_blind_spot.py new file mode 100644 index 0000000..906cedf --- /dev/null +++ b/3 - PixelCNNs Blind spot and Gated PixelCNNs/pixelcnn_blind_spot.py @@ -0,0 +1,650 @@ +# -*- coding: utf-8 -*- +"""PixelCNN - Blind_spot + +Automatically generated by Colaboratory. + +Original file is located at + https://colab.research.google.com/drive/1qENjDpnBADXAfLwU9rla66W51g3FX9Hn + +# Autoregressive models - PixelCNNs blind spot in the receptive field and how to fix it + +In our previous notebooks, we trained a generative model to create black and white drawings of numbers and to generate coloured images. Now we are going to look into the blind spot, one of the biggest drawbacks of PixelCNNs, and how to fix it. We will also define the Gated PixelCNN and analyse its performance on the MNIST dataset. + +This implementation uses TensorFlow 2.0. We start by installing and importing the code dependencies. + +*Note: Here we are using float64 to get more precise values of the gradients and avoid false values. +""" + +# Commented out IPython magic to ensure Python compatibility. +# %tensorflow_version 2.x + +import random as rn + +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.ticker import FixedLocator +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow import nn +from tensorflow.keras import initializers +from tensorflow.keras.utils import Progbar + +tf.keras.backend.set_floatx('float64') + +"""As in the previous notebooks, we define the masked convolutional layers and the residual blocks. 
The residual block is an important building block in the seminal paper by Oord et al. that introduced the PixelCNN. + +*Note: Here we removed the ReLU activations so that they do not interfere with the gradients while we investigate them. +""" + +class ResidualBlock(keras.Model): + """Residual blocks that compose pixelCNN + + Blocks of layers with 3 convolutional layers and one residual connection. + Based on Figure 5 from [1] where h indicates number of filters. + + Refs: + [1] - Oord, A. V. D., Kalchbrenner, N., & Kavukcuoglu, K. (2016). Pixel recurrent + neural networks. arXiv preprint arXiv:1601.06759. + """ + + def __init__(self, h): + super(ResidualBlock, self).__init__(name='') + + self.conv2a = keras.layers.Conv2D(filters=h, kernel_size=1, strides=1) + self.conv2b = MaskedConv2D(mask_type='B', filters=h, kernel_size=3, strides=1) + self.conv2c = keras.layers.Conv2D(filters=2 * h, kernel_size=1, strides=1) + + def call(self, input_tensor): +# x = nn.relu(input_tensor) +# x = self.conv2a(x) + x = self.conv2a(input_tensor) + +# x = nn.relu(x) + x = self.conv2b(x) + +# x = nn.relu(x) + x = self.conv2c(x) + + x += input_tensor + return x + +"""## Blind Spots + +As we saw above, the original PixelCNN had a problem with its masked convolution that led to a blind spot. Given a specific pixel, the masked convolution could not capture the information from all previous pixels. As we can see in the image below, the information in *j* is never used to predict *m*, and the information in *j, n, o* is not used to predict *q*. + +![masks.png]() + +Let's start by implementing the convolution and looking at how the blind spot propagates. You should be familiar with the class below, as we used it in the previous notebooks. +""" + +class MaskedConv2D(keras.layers.Layer): + """Convolutional layers with masks. + + Convolutional layers with simple implementation of masks type A and B for + autoregressive models. 
+ + Arguments: + mask_type: one of `"A"` or `"B".` + filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: one of `"valid"` or `"same"` (case-insensitive). + kernel_initializer: Initializer for the `kernel` weights matrix. + bias_initializer: Initializer for the bias vector. + """ + + def __init__(self, + mask_type, + filters, + kernel_size, + strides=1, + padding='same', + kernel_initializer='glorot_uniform', + bias_initializer='zeros'): + super(MaskedConv2D, self).__init__() + + assert mask_type in {'A', 'B'} + self.mask_type = mask_type + + self.filters = filters + self.kernel_size = kernel_size + self.strides = strides + self.padding = padding.upper() + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + + def build(self, input_shape): + self.kernel = self.add_weight('kernel', + shape=(self.kernel_size, + self.kernel_size, + int(input_shape[-1]), + self.filters), + initializer=self.kernel_initializer, + trainable=True) + + self.bias = self.add_weight('bias', + shape=(self.filters,), + initializer=self.bias_initializer, + trainable=True) + + center = self.kernel_size // 2 + + mask = np.ones(self.kernel.shape, dtype=np.float64) + mask[center, center + (self.mask_type == 'B'):, :, :] = 0. + mask[center + 1:, :, :, :] = 0. 
+ + self.mask = tf.constant(mask, dtype=tf.float64, name='mask') + + def call(self, input): + masked_kernel = tf.math.multiply(self.mask, self.kernel) + x = nn.conv2d(input, + masked_kernel, + strides=[1, self.strides, self.strides, 1], + padding=self.padding) + x = nn.bias_add(x, self.bias) + return x + +def plot_receptive_field(model, data): + """ + This function allows the visualisation of the receptive field + """ + with tf.GradientTape() as tape: + tape.watch(data) + prediction = model(data) + loss = prediction[:,5,5,0] + + gradients = tape.gradient(loss, data) + + gradients = np.abs(gradients.numpy().squeeze()) + gradients = (gradients > 0).astype('float64') + gradients[5, 5] = 0.5 + + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + + plt.xticks(np.arange(0, 10, step=1)) + plt.yticks(np.arange(0, 10, step=1)) + ax.xaxis.set_minor_locator(FixedLocator(np.arange(0.5, 10.5, step=1))) + ax.yaxis.set_minor_locator(FixedLocator(np.arange(0.5, 10.5, step=1))) + plt.grid(which="minor") + plt.imshow(gradients, vmin=0, vmax=1) + plt.show() + +# -------------------------------------------------------------------------------------------------------------- +height = 10 +width = 10 +n_channel = 1 + +data = tf.random.normal((1, height, width, n_channel)) + +# 1 layer PixelCNN +inputs = keras.layers.Input(shape=(height, width, n_channel)) +x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs) +model = tf.keras.Model(inputs=inputs, outputs=x) + +plot_receptive_field(model, data) + +# 2 layer PixelCNN +inputs = keras.layers.Input(shape=(height, width, n_channel)) +x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs) +x = ResidualBlock(h=1)(x) + +model = tf.keras.Model(inputs=inputs, outputs=x) + +plot_receptive_field(model, data) + +# 3 layer PixelCNN +inputs = keras.layers.Input(shape=(height, width, n_channel)) +x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs) +x = ResidualBlock(h=1)(x) +x = 
ResidualBlock(h=1)(x) + +model = tf.keras.Model(inputs=inputs, outputs=x) + +plot_receptive_field(model, data) + +# 4 layer PixelCNN +inputs = keras.layers.Input(shape=(height, width, n_channel)) +x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs) +x = ResidualBlock(h=1)(x) +x = ResidualBlock(h=1)(x) +x = ResidualBlock(h=1)(x) + +model = tf.keras.Model(inputs=inputs, outputs=x) + +plot_receptive_field(model, data) + +# 5 layer PixelCNN +inputs = keras.layers.Input(shape=(height, width, n_channel)) +x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs) +x = ResidualBlock(h=1)(x) +x = ResidualBlock(h=1)(x) +x = ResidualBlock(h=1)(x) +x = ResidualBlock(h=1)(x) + +model = tf.keras.Model(inputs=inputs, outputs=x) + +plot_receptive_field(model, data) + +"""## Gated PixelCNN and the blind spot + +A big drawback of PixelCNNs is the blind spot in the receptive field. This blind spot is a result of using a mask that gets propagated via convolution. To overcome the blind spot problem, the Gated PixelCNN was introduced. This new model solves the blind spot by using two separate convolutions (i.e., the horizontal and the vertical stack). +While the horizontal stack consists of the pixels just before the predicted pixel in the same row, the vertical stack represents the rows above the predicted pixel. + +![stacks.png]() + +Let's re-implement the `MaskedConv2D` class, but now using the horizontal and vertical stacks to correct for the blind spot. +""" + +# -------------------------------------------------------------------------------------------------------------- + +class MaskedConv2D(keras.layers.Layer): + """Convolutional layers with masks. + + Convolutional layers with simple implementation of masks type A and B for + autoregressive models. + + Arguments: + mask_type: one of `"A"`, `"B"` or `"V"`. + filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the convolution). 
+ kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: one of `"valid"` or `"same"` (case-insensitive). + kernel_initializer: Initializer for the `kernel` weights matrix. + bias_initializer: Initializer for the bias vector. + """ + + def __init__(self, + mask_type, + filters, + kernel_size, + strides=1, + padding='same', + kernel_initializer='glorot_uniform', + bias_initializer='zeros'): + super(MaskedConv2D, self).__init__() + + assert mask_type in {'A', 'B', 'V'} + self.mask_type = mask_type + + self.filters = filters + self.kernel_size = kernel_size + self.strides = strides + self.padding = padding.upper() + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + + def build(self, input_shape): + kernel_h = self.kernel_size + kernel_w = self.kernel_size + self.kernel = self.add_weight('kernel', + shape=(self.kernel_size, + self.kernel_size, + int(input_shape[-1]), + self.filters), + initializer=self.kernel_initializer, + trainable=True) + + self.bias = self.add_weight('bias', + shape=(self.filters,), + initializer=self.bias_initializer, + trainable=True) + mask = np.ones(self.kernel.shape, dtype=np.float64) + + # Get centre of the filter for even or odd dimensions + if kernel_h % 2 != 0: + center_h = kernel_h // 2 + else: + center_h = (kernel_h - 1) // 2 + + if kernel_w % 2 != 0: + center_w = kernel_w // 2 + else: + center_w = (kernel_w - 1) // 2 + + if self.mask_type == 'V': + mask[center_h + 1:, :, :, :] = 0. 
+ else: + mask[:center_h, :, :] = 0. + mask[center_h, center_w + (self.mask_type == 'B'):, :, :] = 0. + mask[center_h + 1:, :, :] = 0. + + self.mask = tf.constant(mask, dtype=tf.float64, name='mask') + + def call(self, input): + masked_kernel = tf.math.multiply(self.mask, self.kernel) + x = nn.conv2d(input, + masked_kernel, + strides=[1, self.strides, self.strides, 1], + padding=self.padding) + x = nn.bias_add(x, self.bias) + return x + +class GatedBlock(tf.keras.Model): + """ Gated block that compose Gated PixelCNN.""" + + def __init__(self, mask_type, filters, kernel_size): + super(GatedBlock, self).__init__(name='') + + self.mask_type = mask_type + self.vertical_conv = MaskedConv2D(mask_type='V', + filters=2 * filters, + kernel_size=kernel_size) + + self.horizontal_conv = MaskedConv2D(mask_type=mask_type, + filters=2 * filters, + kernel_size=kernel_size) + + self.padding = keras.layers.ZeroPadding2D(padding=((1, 0), 0)) + self.cropping = keras.layers.Cropping2D(cropping=((0, 1), 0)) + + self.v_to_h_conv = keras.layers.Conv2D(filters=2 * filters, kernel_size=1) + + self.horizontal_output = keras.layers.Conv2D(filters=filters, kernel_size=1) + + def _gate(self, x): + tanh_preactivation, sigmoid_preactivation = tf.split(x, 2, axis=-1) + return tf.nn.tanh(tanh_preactivation) * tf.nn.sigmoid(sigmoid_preactivation) + + def call(self, input_tensor): + v = input_tensor[0] + h = input_tensor[1] + + vertical_preactivation = self.vertical_conv(v) + + # Shifting vertical stack feature map down before feed into horizontal stack to + # ensure causality + v_to_h = self.padding(vertical_preactivation) + v_to_h = self.cropping(v_to_h) + v_to_h = self.v_to_h_conv(v_to_h) + + horizontal_preactivation = self.horizontal_conv(h) + + v_out = self._gate(vertical_preactivation) + + horizontal_preactivation = horizontal_preactivation + v_to_h + h_activated = self._gate(horizontal_preactivation) + h_activated = self.horizontal_output(h_activated) + + if self.mask_type == 'A': + h_out = 
h_activated + elif self.mask_type == 'B': + h_out = h + h_activated + + return v_out, h_out + +# Define the data +height = 10 +width = 10 +n_channel = 1 + +data = tf.random.normal((1, height, width, n_channel)) + +# 1 layer Gated PixelCNN +inputs = keras.layers.Input(shape=(height, width, n_channel)) +v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs]) +model = tf.keras.Model(inputs=inputs, outputs=h) + +plot_receptive_field(model, data) + +"""Excellent! As we expected, the block considered the previous pixels in the same row as the analysed pixel, and the two rows above it. The blind spot problem is fixed! + +Note that the receptive field is different from that of the original PixelCNN. In the original PixelCNN, only one row above the analysed pixel influenced its prediction (when using one masked convolution). In the Gated PixelCNN, the authors used a vertical stack with an effective area of 2x3 per vertical convolution. This is not a problem, since the considered pixels are still the ones in past positions. We believe the main reason for this format is that it allows an efficient implementation of the masked convolutions without using masking (which we will discuss in future posts). 
+ +For the next step, we will verify models with 2, 3, 4, and 5 layers. +""" + +# 2 layers Gated PixelCNN +inputs = keras.layers.Input(shape=(height, width, n_channel)) +v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs]) +v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h]) +model = tf.keras.Model(inputs=inputs, outputs=h) + +plot_receptive_field(model, data) + +# 3 layers Gated PixelCNN +inputs = keras.layers.Input(shape=(height, width, n_channel)) +v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs]) +v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h]) +v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h]) +model = tf.keras.Model(inputs=inputs, outputs=h) + +plot_receptive_field(model, data) + +# 4 layers Gated PixelCNN +inputs = keras.layers.Input(shape=(height, width, n_channel)) +v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs]) +v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h]) +v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h]) +v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h]) +model = tf.keras.Model(inputs=inputs, outputs=h) + +plot_receptive_field(model, data) + +# 5 layers Gated PixelCNN +inputs = keras.layers.Input(shape=(height, width, n_channel)) +v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs]) +v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h]) +v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h]) +v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h]) +v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h]) +model = tf.keras.Model(inputs=inputs, outputs=h) + +plot_receptive_field(model, data) + +"""As you can see, the Gated PixelCNN does not create blind spots as more and more layers are added. + +# Gated PixelCNN + +We will now see how the Gated PixelCNN works on real data (i.e., the MNIST dataset). 
To make sure that this notebook is reproducible, we will fix the random seed. + +This implementation is similar to the PixelCNN implementation we described in the [first notebook](https://colab.research.google.com/github/Mind-the-Pineapple/Autoregressive-models/blob/master/1%20-%20Autoregressive%20Models%20-%20PixelCNN/pixelCNN.ipynb#scrollTo=bU25WyouYYE3). +""" + +# Defining random seeds +random_seed = 42 +tf.random.set_seed(random_seed) +np.random.seed(random_seed) +rn.seed(random_seed) + +# Loading data +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + +height = 28 +width = 28 +n_channel = 1 + +x_train = x_train.astype('float32') / 255. +x_test = x_test.astype('float32') / 255. + +x_train = x_train.reshape(x_train.shape[0], height, width, n_channel) +x_test = x_test.reshape(x_test.shape[0], height, width, n_channel) + +""" +In this example, to make the probability distribution of a single pixel easier to define, we quantise the number of possible values that a pixel can take. Originally, in the MNIST dataset, the pixels are represented by a uint8 variable, which can take values in [0, 255]. In this example, we restrict the image to only 2 different values (0 and 1).""" + +def quantise(images, q_levels): + """Quantise image into q levels""" + return (np.digitize(images, np.arange(q_levels) / q_levels) - 1).astype('float32') + +# Quantise the input data in q levels +q_levels = 2 +x_train_quantised = quantise(x_train, q_levels) +x_test_quantised = quantise(x_test, q_levels) + +"""Using the `tf.data` API, we defined the input data streams for our model during training and evaluation. 
In these datasets, the inputs are the 2-level images normalised to be between [0, 1], and the target values are the categorical pixel values (0 or 1).""" + +# Creating input stream using tf.data API +batch_size = 192 +train_buf = 10000 + +train_dataset = tf.data.Dataset.from_tensor_slices((x_train_quantised / (q_levels - 1), + x_train_quantised.astype('int32'))) +train_dataset = train_dataset.shuffle(buffer_size=train_buf) +train_dataset = train_dataset.batch(batch_size) + +test_dataset = tf.data.Dataset.from_tensor_slices((x_test_quantised / (q_levels - 1), + x_test_quantised.astype('int32'))) +test_dataset = test_dataset.batch(batch_size) + +"""### Gated PixelCNN architecture""" + +# Create Gated PixelCNN model +inputs = keras.layers.Input(shape=(height, width, n_channel)) +v, h = GatedBlock(mask_type='A', filters=64, kernel_size=3)([inputs, inputs]) + +for i in range(7): + v, h = GatedBlock(mask_type='B', filters=64, kernel_size=3)([v, h]) + +x = keras.layers.Activation(activation='relu')(h) +x = keras.layers.Conv2D(filters=128, kernel_size=1, strides=1)(x) + +x = keras.layers.Activation(activation='relu')(x) +x = keras.layers.Conv2D(filters=q_levels, kernel_size=1, strides=1)(x) + +gated_pixelcnn = tf.keras.Model(inputs=inputs, outputs=x) + +"""In this implementation, we use a simple Adam optimizer with learning rate decay to train the neural network. The loss function is the cross-entropy (minimising it is, in this case, equivalent to minimising the negative log-likelihood of the training data).""" + +# Prepare optimizer and loss function +lr_decay = 0.999995 +learning_rate = 1e-3 +optimizer = keras.optimizers.Adam(learning_rate=learning_rate) + +compute_loss = keras.losses.CategoricalCrossentropy(from_logits=True) + +""" +The training step is defined by the forward propagation through the model. 
Then, the gradients are calculated, clipped so that their global norm is at most 1, and applied to update the Gated PixelCNN parameters.""" + +@tf.function +def train_step(batch_x, batch_y): + with tf.GradientTape() as ae_tape: + logits = gated_pixelcnn(batch_x, training=True) + + loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits) + + gradients = ae_tape.gradient(loss, gated_pixelcnn.trainable_variables) + gradients, _ = tf.clip_by_global_norm(gradients, 1.0) + optimizer.apply_gradients(zip(gradients, gated_pixelcnn.trainable_variables)) + + return loss + +"""In this implementation, we defined the training loop with 50 epochs (the code below sets `n_epochs = 1`). + + +""" + +# Training loop +n_epochs = 1 +n_iter = int(np.ceil(x_train_quantised.shape[0] / batch_size)) +for epoch in range(n_epochs): + progbar = Progbar(n_iter) + print('Epoch {:}/{:}'.format(epoch + 1, n_epochs)) + + for i_iter, (batch_x, batch_y) in enumerate(train_dataset): + optimizer.lr = optimizer.lr * lr_decay + loss = train_step(batch_x, batch_y) + + progbar.add(1, values=[('loss', loss)]) + +"""To evaluate the performance of the model, we measured its negative log-likelihood (NLL) on the test set. 
+ + +""" + +# Test set performance +test_loss = [] +for batch_x, batch_y in test_dataset: + logits = gated_pixelcnn(batch_x, training=False) + + # Calculate cross-entropy (= negative log-likelihood) + loss = compute_loss(tf.one_hot(np.squeeze(batch_y), q_levels), logits) + + test_loss.append(loss) +print('nll : {:} nats'.format(np.array(test_loss).mean())) +print('bits/dim : {:}'.format(np.array(test_loss).mean() / np.log(2))) + +# Test set performance +test_loss = [] +for batch_x, batch_y in test_dataset: + logits = gated_pixelcnn(np.squeeze(batch_x), training=False) + + # Calculate cross-entropy (= negative log-likelihood) + loss = compute_loss(tf.one_hot(np.squeeze(batch_y), q_levels), logits) + + test_loss.append(loss) +print('nll : {:} nats'.format(np.array(test_loss).mean())) +print('bits/dim : {:}'.format(np.array(test_loss).mean() / np.log(2))) + +# Generating new images +samples = np.zeros((100, height, width, n_channel), dtype='float32') +for i in range(height): + for j in range(width): + logits = gated_pixelcnn(samples) + next_sample = tf.random.categorical(logits[:, i, j, :], 1) + samples[:, i, j, 0] = (next_sample.numpy() / (q_levels - 1))[:, 0] + +fig = plt.figure(figsize=(10, 10)) +for i in range(100): + ax = fig.add_subplot(10, 10, i + 1) + ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary) + plt.xticks(np.array([])) + plt.yticks(np.array([])) +plt.savefig('numbers1.png')  # save before show(), which may clear the figure +plt.show() + +# Filling occluded images +occlude_start_row = 14 +num_generated_images = 10 +samples = np.copy(x_test_quantised[0:num_generated_images, :, :, :]) +samples = samples / (q_levels - 1) +samples[:, occlude_start_row:, :, :] = 0 + +fig = plt.figure(figsize=(10, 10)) + +for i in range(10): + ax = fig.add_subplot(1, 10, i + 1) + ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary) + plt.xticks(np.array([])) + plt.yticks(np.array([])) + +for i in range(occlude_start_row, height): + for j in range(width): + 
logits = gated_pixelcnn(samples) + next_sample = 
tf.random.categorical(logits[:, i, j, :], 1) + samples[:, i, j, 0] = (next_sample.numpy() / (q_levels - 1))[:, 0] + +fig = plt.figure(figsize=(10, 10)) + +for i in range(10): + ax = fig.add_subplot(1, 10, i + 1) + ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary) + plt.xticks(np.array([])) + plt.yticks(np.array([])) +plt.savefig('numbers2.png')  # save before show(), which may clear the figure +plt.show() + +"""Finally, we sampled some images from the trained model. First, we sampled from scratch, then we completed partially occluded images. + + +""" +