diff --git a/WIP/3 -/Receptive_fields.ipynb b/WIP/3 -/Receptive_fields.ipynb new file mode 100644 index 0000000..237f045 --- /dev/null +++ b/WIP/3 -/Receptive_fields.ipynb @@ -0,0 +1,251 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Receptive fields.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "YqTKIYLooHsq", + "colab_type": "text" + }, + "source": [ + "# Comparing receptive fields" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "wM-m3Z8CiLXU", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "gf5wwqP3ozaN", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import random as rn\n", + "\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "from tensorflow import nn\n", + "from tensorflow.keras import initializers\n", + "from tensorflow.keras.utils import Progbar" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "jEkll1yno2Vb", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Defining random seeds\n", + "random_seed = 42\n", + "tf.random.set_seed(random_seed)\n", + "np.random.seed(random_seed)\n", + "rn.seed(random_seed)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "yJ_JlzWco7ci", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class MaskedConv2D(keras.layers.Layer):\n", + " \"\"\"Convolutional layers with masks.\n", + "\n", + " Convolutional layers with simple implementation of masks type A and B for\n", + " autoregressive models.\n", + "\n", + " Arguments:\n", + " mask_type: one of `\"A\"` or `\"B\".`\n", + " filters: Integer, the dimensionality of the output space\n", + " (i.e. the number of output filters in the convolution).\n", + " kernel_size: An integer or tuple/list of 2 integers, specifying the\n", + " height and width of the 2D convolution window.\n", + " Can be a single integer to specify the same value for\n", + " all spatial dimensions.\n", + " strides: An integer or tuple/list of 2 integers,\n", + " specifying the strides of the convolution along the height and width.\n", + " Can be a single integer to specify the same value for\n", + " all spatial dimensions.\n", + " Specifying any stride value != 1 is incompatible with specifying\n", + " any `dilation_rate` value != 1.\n", + " padding: one of `\"valid\"` or `\"same\"` (case-insensitive).\n", + " kernel_initializer: Initializer for the `kernel` weights matrix.\n", + " bias_initializer: Initializer for the bias vector.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " mask_type,\n", + " filters,\n", + " kernel_size,\n", + " strides=1,\n", + " padding='same',\n", + " kernel_initializer='glorot_uniform',\n", + " bias_initializer='zeros'):\n", + " super(MaskedConv2D, self).__init__()\n", + "\n", + " assert mask_type in {'A', 'B'}\n", + " self.mask_type = mask_type\n", + "\n", + " self.filters = filters\n", + " self.kernel_size = kernel_size\n", + " self.strides = strides\n", + " self.padding = padding.upper()\n", + " self.kernel_initializer = initializers.get(kernel_initializer)\n", + " self.bias_initializer = initializers.get(bias_initializer)\n", + "\n", + " def build(self, input_shape):\n", + " self.kernel = self.add_weight('kernel',\n", + " shape=(self.kernel_size,\n", + " self.kernel_size,\n", + " int(input_shape[-1]),\n", + " self.filters),\n", + " initializer=self.kernel_initializer,\n", + " trainable=True)\n", + "\n", + " self.bias = self.add_weight('bias',\n", + " shape=(self.filters,),\n", + " initializer=self.bias_initializer,\n", + " trainable=True)\n", + "\n", + " center = self.kernel_size // 2\n", + "\n", + " mask = np.ones(self.kernel.shape, dtype=np.float32)\n", + " mask[center, center + (self.mask_type == 'B'):, :, :] = 0.\n", + " mask[center + 1:, :, :, :] = 0.\n", + "\n", + " self.mask = tf.constant(mask, dtype=tf.float32, name='mask')\n", + "\n", + " def call(self, input):\n", + " masked_kernel = tf.math.multiply(self.mask, self.kernel)\n", + " x = nn.conv2d(input,\n", + " masked_kernel,\n", + " strides=[1, self.strides, self.strides, 1],\n", + " padding=self.padding)\n", + " x = nn.bias_add(x, self.bias)\n", + " return x" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "jxCLMYc-FxdJ", + "colab_type": "code", + "colab": {} + }, + "source": [ + "def plot_receptive_field(model, data):\n", + " out = model(data)\n", + "\n", + " with tf.GradientTape() as tape:\n", + " tape.watch(data)\n", + " prediction = model(data)\n", + " loss = prediction[:,5,5,0]\n", + "\n", + " gradients = tape.gradient(loss, data)\n", + "\n", + " gradients = np.abs(gradients.numpy().squeeze())\n", + " gradients = (gradients > 1e-8).astype('float32')\n", + " gradients[5, 5] = 0.5\n", + "\n", + " plt.figure()\n", + " plt.imshow(gradients)\n", + " plt.title(f'Receptive field from pixel layers')\n", + " plt.show()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "0qpDtNuvo9NL", + "colab_type": "code", + "outputId": "926434e2-44bf-40d3-a39f-6c80ebc880b6", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 + } + }, + "source": [ + "height = 10\n", + "width = 10\n", + "n_channel = 1\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", + "x = MaskedConv2D(mask_type='B', filters=1, kernel_size=3, strides=1)(x)\n", + "x = MaskedConv2D(mask_type='B', filters=1, kernel_size=3, strides=1)(x)\n", + "\n", + "model = keras.Model(inputs=inputs, outputs=x)\n", + "data = tf.random.normal((1,10,10,1))\n", + "\n", + "plot_receptive_field(model, data)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAARBElEQVR4nO3dfZBV9X3H8fdHWEGQKkbzwEOQiUYltoLdxKdqjWI1PmaaptVEU50axjaKWic+dVqcxDTTiSY6k2iCD9gIVVPiRGuNmgjGaCsjIBkFtCWIgkDEKgo2Aazf/nF+K5dl796zy72cuz8/r5mdufeec8/5nofP/n7nt/eeVURgZvnYpeoCzKy5HGqzzDjUZplxqM0y41CbZcahNsuMQ12SpKsl3dqC5X5I0uOSNki6vi/rkfSYpPPrTNtXUkgaXGf6AZIWpfVO3ZFtaDZJX5T0SBOWs0LS5DrT7pB07Y6uox31eMB3NkkrgA8B/wdsBB4CLoyIjRXVcywwMyLGdL0WEf/YotVNAV4Dfi927ocGLgfmRsTEnbjOUiJiFjCr6joGqnZqqU+LiN2BicAk4KqK69lZxgFLdnKgu9a7uN5ESYN2Yi3vC/V6Tc3WTqEGICLWAg9ThBsASYdL+g9J6yX9KrWkXdP2kjRD0mpJb0j6Sc20U1MXc316/x/UTFsh6SpJS9L7ZkgaKmk48FNglKSN6WeUpGskzUzv/amkC2vrTnX9aXp8oKSfSXpd0guS/rynbZV0B/CXwOVpPZNr19No27sta5Ck6yS9Jmk5cEq9fSxpDvBp4LtpvR9P3dGbJT0o6W3g05IOSl389ZIWSzq9tnZJN6V9sVHSk5I+LOmGtD+flzSplxpC0lRJy1PN35K0S5p2rqQn0uMj0/Sx6fkhafkHpud1j3FZkkZKekDSurTsBySNSdM+L2lBt/n/VtJ96fGQtN9flvQbSd+XtFuadqykVZKukLQWmCFp77T89en8+GXXdjdNRFT+A6wAJqfHY4BngRvT89HA/wAnU/wSOiE93ydN/3fgHmAk0AH8cXp9EvAqcBgwiCI8K4AhNet8DhgL7AU8CVybph0LrOpW4zUUXXKALwFP1kybAKwHhgDDgZXAeRSXN5MoutcT6mz7HV3r7WE9jbb9MeD89PgC4Pma7ZkLBDC4znrfe29NHW8CR6V1jQCWAVcDuwLHARuAA2rmfw34Q2AoMAd4Me2bQcC1FN37esc8Uo17AR8F/qtmW84FnqiZ9xtp+bulc+PCPhzjyY32O/AB4HPAsLTd/wr8JE0bArwOHFTz3meAz6XH3wHuT9sxAvg34Js159E7wD+l5ewGfBP4PsW52gEcDaipeao60DU7f2M6aQJ4FNgzTbsCuLPb/A+nA/gR4F1gZA/LvBn4erfXXmBr6FcAF9RMOxn4dclQjwDeBsbVnHS3p8d/Afyy23t/AEzrR6jrbnsPoZ7TbXv+hL6H+oc1z48G1gK71Lx2F3BNzfy31Ey7CFha8/z3gfW9HPMATqp5/jfAo+nxuWwb6g5gAUWgH+oKQclj3DDUPUybCLzR7Vz6Rnr8CeANipAqnQcfq5n3CODFmvNoMzC0ZvrXgPuA/VqVp3bqfn82IkZQ7IgDgb3T6+OAz6fuynpJ64E/ogj0WOD1iHijh+WNAy7r9r6xwKiaeVbWPH6p27S6ImIDRQ/hzPTSWWwd2BkHHNZtvV8EPlxm2T1sQ71t724U229PX9W+fxSwMiLe7bbM0TXPf1Pz+Lc9PN+9D+uru/8jYgtFCA8Gro+UDsod44YkDZP0A0kvSXoLeBzYU1vHFf4Z+IIkAecAP4qITcA+FK37gpr1P5Re77IuIn5X8/xbFD2gR9Klx5V9qbWMthj9rhURv0jXmtcBn6U48HdGxJe7zyvpI8BekvaMiPXdJq+k+O36jV5WN7bm8UeB1V1llCj1LmCapMcpup9za9b7i4g4ocQyGqm77T1Yw/bb01e1270aGCtpl5pgd3WTm2UsWwfravf/NiSNBqYBM4DrJX0yharMMS7jMuAA4LCIWCtpIkUXWwAR8ZSkzRS9ly+kHyguP34LfCIiXqmz7G3OpdQgXEbxy+hgYI6kpyPi0R3chve0U0td6wbgBEmHADOB0ySdmAaDhqYBiDERsYZiUOumNNjRIemYtIxbgAskHabCcEmnSBpRs56vSBojaS/g7yiuzaFocT4gaY9eanyQoqX4GnBPzYn/APBxSeekejokfVLSQf3YD3W3vYd5fwRMTdszEtjRFmAe8L8Ug3gdKgboTgPu3sHl1vpqOm5jgYvZuv/fk1rHO4DbgL+i+OX19TS5zDEuYwRFONenc2FaD/P8EPgusCUingBIx/wW4DuSPpjqHS3pxHorSgN7+6XtepPiz7jv1pu/P9oy1BGxjmIn/kNErATOoBiwWUfx2/mrbK39HGALxSDRq8AlaRnzgS9THIg3KLo853Zb1b8AjwDLgV9TDO4QEc9TtMTLU7dqu+5cainuBSan5XS9voHievZMipZnLVsHSvq6Hxpte61bKK63fwUsTLX1W0RspgjxZyhapJuAL6V90yz3UVwrL6K4nLmth3mmAh8E/j51u88DzpN0dMljXMYNFINYrwFPUXShu7uTovs/s9vrV6T1PpW67j+naPXr2T/NsxH4T+CmiJjby/x91jXg8L6j4gMv50fEz6uu5f1IUgD7R8SyqmspI/2Z6lXg0Ij476rr6U1bttRmbeivgafbPdDQhgNlZu0m9epEMXDb9t633W+zXLn7bZaZlnS/d9WQGMrwVizazIDf8TabY5N6mtaSUA9lOIfp+FYs2syAeb18VsXdb7PMONRmmXGozTLjUJtlxqE2y4xDbZaZUqGWdJKKe20ta8WXus2seRqGOt394XsUX8GbAJwlaUKrCzOz/inTUn8KWBYRy9N3bO+m+I6vmbWhMqEezbb3klrFtvepAkDSFEnzJc3fwqZm1WdmfdS0gbKImB4RnRHR2dH3m3yYWZOUCfUrbHtDuzHpNTNrQ2VC/TSwv6TxknaluPfW/a0ty8z6q+G3tCLiHRX/YuZhiv+CcHtE1P0fTGZWrVJfvYyIByluiWtmbc6fKDPLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzpf6XVu4eXr2o6hIsYyeOmrhT1+eW2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLTMNQSxoraa6kJZIWS7p4ZxRmZv1T5sMn7wCXRcRCSSOABZJ+FhFLWlybmfVDw5Y6ItZExML0eAOwFBjd6sLMrH/69DFRSfsCk4B5PUybAkwBGMqwJpRmZv1ReqBM0u7Aj4FLIuKt7tMjYnpEdEZEZwdDmlmjmfVBqVBL6qAI9KyIuLe1JZnZjigz+i3gNmBpRHy79SWZ2Y4o01IfBZwDHCdpUfo5ucV1mVk/NRwoi4gnAO2EWsysCfyJMrPMONRmmXGozTLjUJtlxjceNEt29g0CW8UttVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcZ3E7UBKZc7f7aCW2qzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDOlQy1pkKRnJD3QyoLMbMf0paW+GFjaqkLMrDlKhVrSGOAU4NbWlmNmO6psS30DcDnwbr0ZJE2RNF/S/C1sakpxZtZ3DUMt6VTg1YhY0Nt8ETE9IjojorODIU0r0Mz6pkxLfRRwuqQVwN3AcZJmtrQqM+u3hqGOiKsiYkxE7AucCcyJiLNbXpmZ9Yv/Tm2WmT59nzoiHgMea0klZtYUbqnNMuNQm2XGoTbLjENtlhmH2iwzvpuoAfCxey5oyXL3u/SplizX6nNLbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmfHdRIETR02suoTK7Yfv+pkLt9RmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWZKhVrSnpJmS3pe0lJJR7S6MDPrn7IfPrkReCgi/kzSrsCwFtZkZjugYagl7QEcA5wLEBGbgc2tLcvM+qtM93s8sA6YIekZSbdKGt59JklTJM2XNH8Lm5peqJmVUybUg4FDgZsjYhLwNnBl95kiYnpEdEZEZwdDmlymmZVVJtSrgFURMS89n00RcjNrQw1DHRFrgZWSDkgvHQ8saWlVZtZvZUe/LwJmpZHv5cB5rSvJzHZEqVBHxCKgs8W1mFkT+BNlZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0yUyrUki6VtFjSc5LukjS01YWZWf80DLWk0cBUoDMiDgYGAWe2ujAz65+y3e/BwG6SBgPDgNWtK8nMdkTDUEfEK8B1wMvAGuDNiHik+3ySpkiaL2n+FjY1v1IzK6VM93skcAYwHhgFDJd0dvf5ImJ6RHRGRGcHQ5pfqZmVUqb7PRl4MSLWRcQW4F7gyNaWZWb9VSbULwOHSxomScDxwNLWlmVm/VXmmnoeMBtYCDyb3jO9xXWZWT8NLjNTREwDprW4FjNrAn+izCwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8woIpq/UGkd8FKJWfcGXmt6Aa0zkOodSLXCwKq3HWodFxH79DShJaEuS9L8iOisrIA+Gkj1DqRaYWDV2+61uvttlhmH2iwzVYd6oP3z+oFU70CqFQZWvW1da6XX1GbWfFW31GbWZA61WWYqC7WkkyS9IGmZpCurqqMRSWMlzZW0RNJiSRdXXVMZkgZJekbSA1XX0htJe0qaLel5SUslHVF1Tb2RdGk6D56TdJekoVXX1F0loZY0CPge8BlgAnCWpAlV1FLCO8BlETEBOBz4ShvXWutiYGnVRZRwI/BQRBwIHEIb1yxpNDAV6IyIg4FBwJnVVrW9qlrqTwHLImJ5RGwG7gbOqKiWXkXEmohYmB5voDjpRldbVe8kjQFOAW6tupbeSNoDOAa4DSAiNkfE+mqramgwsJukwcAwYHXF9WynqlCPBlbWPF9FmwcFQNK+wCRgXrWVNHQDcDnwbtWFNDAeWAfMSJcKt0oaXnVR9UTEK8B1wMvAGuDNiHik2qq254GykiTtDvwYuCQi3qq6nnoknQq8GhELqq6lhMHAocDNETEJeBto5/GVkRQ9yvHAKGC4pLOrrWp7VYX6FWBszfMx6bW2JKmDItCzIuLequtp4CjgdEkrKC5rjpM0s9qS6loFrIqIrp7PbIqQt6vJwIsRsS4itgD3AkdWXNN2qgr108D+ksZL2pVisOH+imrplSRRXPMtjYhvV11PIxFxVUSMiYh9KfbrnIhou9YEICLWAislHZBeOh5YUmFJjbwMHC5pWDovjqcNB/YGV7HSiHhH0oXAwxQjiLdHxOIqainhKOAc4FlJi9JrV0fEgxXWlJOLgFnpl/ty4LyK66krIuZJmg0spPiryDO04UdG/TFRs8x4oMwsMw61WWYcarPMONRmmXGozTLjUJtlxqE2y8z/A1DIMnrGMqKcAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "I2APaCzDGeqP", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/WIP/4 - Gated_PixelCNN/Gated_Receptive_fields.ipynb b/WIP/4 - Gated_PixelCNN/Gated_Receptive_fields.ipynb new file mode 100644 index 0000000..c068cee --- /dev/null +++ b/WIP/4 - Gated_PixelCNN/Gated_Receptive_fields.ipynb @@ -0,0 +1,321 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Gated Receptive fields.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "YqTKIYLooHsq", + "colab_type": "text" + }, + "source": [ + "# Comparing receptive fields" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "gf5wwqP3ozaN", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import random as rn\n", + "\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "from tensorflow import nn\n", + "from tensorflow.keras import initializers\n", + "from tensorflow.keras.utils import Progbar" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "jEkll1yno2Vb", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Defining random seeds\n", + "random_seed = 42\n", + "tf.random.set_seed(random_seed)\n", + "np.random.seed(random_seed)\n", + "rn.seed(random_seed)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "yJ_JlzWco7ci", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class MaskedConv2D(keras.layers.Layer):\n", + " \"\"\"Convolutional layers with masks for Gated PixelCNN.\n", + "\n", + " Masked convolutional layers used to implement Vertical and Horizontal\n", + " stacks of the Gated PixelCNN.\n", + "\n", + " Note: This implementation is different from the normal PixelCNN.\n", + "\n", + " Arguments:\n", + " mask_type: one of `\"V\"`, `\"A\"` or `\"B\".`\n", + " filters: Integer, the dimensionality of the output space\n", + " (i.e. the number of output filters in the convolution).\n", + " kernel_size: An integer or tuple/list of 2 integers, specifying the\n", + " height and width of the 2D convolution window.\n", + " Can be a single integer to specify the same value for\n", + " all spatial dimensions.\n", + " strides: An integer or tuple/list of 2 integers,\n", + " specifying the strides of the convolution along the height and width.\n", + " Can be a single integer to specify the same value for\n", + " all spatial dimensions.\n", + " Specifying any stride value != 1 is incompatible with specifying\n", + " any `dilation_rate` value != 1.\n", + " padding: one of `\"valid\"` or `\"same\"` (case-insensitive).\n", + " kernel_initializer: Initializer for the `kernel` weights matrix.\n", + " bias_initializer: Initializer for the bias vector.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " mask_type,\n", + " filters,\n", + " kernel_size,\n", + " strides=1,\n", + " padding='same',\n", + " kernel_initializer='glorot_uniform',\n", + " bias_initializer='zeros'):\n", + " super(MaskedConv2D, self).__init__()\n", + "\n", + " assert mask_type in {'A', 'B', 'V'}\n", + " self.mask_type = mask_type\n", + "\n", + " self.filters = filters\n", + "\n", + " if isinstance(kernel_size, int):\n", + " kernel_size = (kernel_size, kernel_size)\n", + " self.kernel_size = kernel_size\n", + "\n", + " self.strides = strides\n", + " self.padding = padding.upper()\n", + " self.kernel_initializer = initializers.get(kernel_initializer)\n", + " self.bias_initializer = initializers.get(bias_initializer)\n", + "\n", + " def build(self, input_shape):\n", + " kernel_h, kernel_w = self.kernel_size\n", + "\n", + " self.kernel = self.add_weight('kernel',\n", + " shape=(kernel_h,\n", + " kernel_w,\n", + " int(input_shape[-1]),\n", + " self.filters),\n", + " initializer=self.kernel_initializer,\n", + " trainable=True)\n", + "\n", + " self.bias = self.add_weight('bias',\n", + " shape=(self.filters,),\n", + " initializer=self.bias_initializer,\n", + " trainable=True)\n", + "\n", + " mask = np.ones(self.kernel.shape, dtype=np.float32)\n", + "\n", + " if kernel_h % 2 != 0: \n", + " center_h = kernel_h // 2\n", + " else:\n", + " center_h = (kernel_h - 1) // 2\n", + "\n", + " if kernel_w % 2 != 0: \n", + " center_w = kernel_w // 2\n", + " else:\n", + " center_w = (kernel_w - 1) // 2\n", + "\n", + " if self.mask_type == 'V':\n", + " mask[center_h + 1:, :, :, :] = 0.\n", + " else:\n", + " mask[:center_h, :, :] = 0.\n", + " mask[center_h, center_w + (self.mask_type == 'B'):, :, :] = 0.\n", + " mask[center_h + 1:, :, :] = 0. \n", + "\n", + " self.mask = tf.constant(mask, dtype=tf.float32, name='mask')\n", + "\n", + " def call(self, input):\n", + " masked_kernel = tf.math.multiply(self.mask, self.kernel)\n", + " x = nn.conv2d(input,\n", + " masked_kernel,\n", + " strides=[1, self.strides, self.strides, 1],\n", + " padding=self.padding)\n", + " x = nn.bias_add(x, self.bias)\n", + " return x" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "K5topys_HW7-", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class GatedBlock(tf.keras.Model):\n", + " \"\"\" Gated block of the Gated PixelCNN.\"\"\"\n", + "\n", + " def __init__(self, mask_type, filters, kernel_size):\n", + " super(GatedBlock, self).__init__(name='')\n", + "\n", + " self.mask_type = mask_type\n", + " self.vertical_conv = MaskedConv2D(mask_type='V',\n", + " filters=2 * filters,\n", + " kernel_size=kernel_size)\n", + " \n", + " self.horizontal_conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=2 * filters,\n", + " kernel_size=kernel_size)\n", + "\n", + " self.padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n", + " self.cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", + "\n", + " self.v_to_h_conv = keras.layers.Conv2D(filters=2 * filters, kernel_size=1)\n", + "\n", + " self.horizontal_output = keras.layers.Conv2D(filters=filters, kernel_size=1)\n", + "\n", + " def _gate(self, x):\n", + " tanh_preactivation, sigmoid_preactivation = tf.split(x, 2, axis=-1)\n", + " return tf.nn.tanh(tanh_preactivation) * tf.nn.sigmoid(sigmoid_preactivation)\n", + "\n", + " def call(self, input_tensor):\n", + " v = input_tensor[0]\n", + " h = input_tensor[1]\n", + "\n", + " vertical_preactivation = self.vertical_conv(v) # NxN\n", + "\n", + " # Shifting feature map down to ensure causality\n", + " v_to_h = self.padding(vertical_preactivation)\n", + " v_to_h = self.cropping(v_to_h)\n", + " v_to_h = self.v_to_h_conv(v_to_h) # 1x1\n", + "\n", + " horizontal_preactivation = self.horizontal_conv(h) # 1xN\n", + " \n", + " v_out = self._gate(vertical_preactivation)\n", + "\n", + " horizontal_preactivation = horizontal_preactivation + v_to_h\n", + " h_activated = self._gate(horizontal_preactivation)\n", + " h_activated = self.horizontal_output(h_activated)\n", + "\n", + " if self.mask_type == 'A':\n", + " h_out = h_activated\n", + " elif self.mask_type == 'B':\n", + " h_out = h + h_activated\n", + "\n", + " return v_out, h_out" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "jxCLMYc-FxdJ", + "colab_type": "code", + "colab": {} + }, + "source": [ + "def plot_receptive_field(model, data):\n", + " out = model(data)\n", + "\n", + " with tf.GradientTape() as tape:\n", + " tape.watch(data)\n", + " prediction = model(data)\n", + " loss = prediction[:,5,5,0]\n", + "\n", + " gradients = tape.gradient(loss, data)\n", + "\n", + " gradients = np.abs(gradients.numpy().squeeze())\n", + " gradients = (gradients > 1e-8).astype('float32')\n", + " gradients[5, 5] = 0.5\n", + "\n", + " plt.figure()\n", + " plt.imshow(gradients)\n", + " plt.title(f'Receptive field from pixel layers')\n", + " plt.show()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "0qpDtNuvo9NL", + "colab_type": "code", + "outputId": "b8142536-f9db-4ce3-ed7e-03741da11a3f", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 + } + }, + "source": [ + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "model = tf.keras.Model(inputs=inputs, outputs=h)\n", + "\n", + "data = tf.random.normal((1,10,10,1))\n", + "\n", + "plot_receptive_field(model, data)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAARAUlEQVR4nO3dfbBU9X3H8fdHQBAkCkoegBtgolGJrWJvfKzWKFbjY6ZpWk001alhbKOodeJTp8VJTDOdaKIziSb4gI1QNSVOtNaoiWCMtlIBzSigKUEUFCKoKNpUUL/94/yQ5XL37rn37nKWn5/XzJ3ZPb+z53zPw2fP7/zu3r2KCMwsHztUXYCZNZdDbZYZh9osMw61WWYcarPMONRmmXGoS5J0uaQbW7Dcj0h6WNJ6SVf3Zj2SHpJ0dp228ZJC0sA67XtJejKtd2p/tqHZJH1J0gNNWM5ySZPrtN0i6cr+rqMddXvAtzVJy4GPAO8CbwL3AedGxJsV1XMkMDMixm6aFhH/1KLVTQHWAh+KbfuhgYuBuRGx/zZcZykRMQuYVXUd26t2ulKfFBE7A/sDk4DLKq5nWxkHLN7Ggd603kX1GiUN2Ia1fCDU6zU1WzuFGoCIWA3cTxFuACQdLOk/Ja2T9Ot0Jd3UNlLSDEkvSXpN0k9r2k5MXcx16fV/WNO2XNJlkhan182QNETSMOBnwGhJb6af0ZKukDQzvfZnks6trTvV9Wfp8d6Sfi7pVUnPSvqL7rZV0i3AXwEXp/VMrl1Po23vsqwBkq6StFbSMuCEevtY0hzgM8D30no/mbqj10u6V9JbwGck7ZO6+OskLZJ0cm3tkq5L++JNSY9K+qika9L+fEbSpB5qCElTJS1LNX9b0g6p7UxJj6THh6b2jvR8v7T8vdPzuse4LEkjJN0jaU1a9j2Sxqa2L0ha0GX+v5N0V3o8OO33FyT9TtIPJO2U2o6UtFLSJZJWAzMk7Z6Wvy6dH7/atN1NExGV/wDLgcnp8VjgKeDa9HwM8ApwPMWb0DHp+ajU/h/AHcAIYBDwJ2n6JOBl4CBgAEV4lgODa9b5NNABjAQeBa5MbUcCK7vUeAVFlxzgy8CjNW0TgXXAYGAYsAI4i+L2ZhJF93pinW2/ZdN6u1lPo21/CDg7PT4HeKZme+YCAQyss973X1tTx+vAYWldw4GlwOXAjsBRwHpgr5r51wJ/BAwB5gDPpX0zALiSontf75hHqnEk8HHgNzXbcibwSM2830zL3ymdG+f24hhPbrTfgd2AzwND03b/G/DT1DYYeBXYp+a1TwCfT4+/C9ydtmM48O/At2rOo3eAf07L2Qn4FvADinN1EHA4oKbmqepA1+z8N9NJE8CDwK6p7RLg1i7z358O4MeA94AR3SzzeuAbXaY9y+bQLwfOqWk7HvhtyVAPB94CxtWcdDenx38J/KrLa38ITOtDqOtuezehntNle/6U3of6RzXPDwdWAzvUTLsNuKJm/htq2s4DltQ8/wNgXQ/HPIDjap7/LfBgenwmW4Z6ELCAItD3bQpByWPcMNTdtO0PvNblXPpmevwp4DWKkCqdB5+omfcQ4Lma82gDMKSm/evAXcAercpTO3W/PxcRwyl2xN7A7mn6OOALqbuyTtI64I8pAt0BvBoRr3WzvHHARV1e1wGMrplnRc3j57u01RUR6yl6CKemSaexeWBnHHBQl/V+CfhomWV3sw31tr2r0Wy9Pb1V+/rRwIqIeK/LMsfUPP9dzePfd/N8516sr+7+j4iNFCHcF7g6Ujood4wbkjRU0g8lPS/pDeBhYFdtHlf4F+CLkgScAfw4It4GRlFc3RfUrP++NH2TNRHxfzXPv03RA3og3Xpc2ptay2iL0e9aEfHLdK95FfA5igN/a0R8peu8kj4GjJS0a0Ss69K8guLd9Zs9rK6j5vHHgZc2lVGi1NuAaZIepuh+zq1Z7y8j4pgSy2ik7rZ3YxVbb09v1W73S0CHpB1qgr2pm9wsHWwerKvd/1uQNAaYBswArpb06RSqMse4jIuAvYCDImK1pP0putgCiIjHJG2g6L18Mf1Acfvxe+BTEfFinWVvcS6lC8JFFG9G+wJzJD0eEQ/2cxve105X6lrXAMdI2g+YCZwk6dg0GDQkDUCMjYhVFINa16XBjkGSjkjLuAE4R9JBKgyTdIKk4TXr+aqksZJGAn9PcW8OxRVnN0m79FDjvRRXiq8Dd9Sc+PcAn5R0RqpnkKRPS9qnD/uh7rZ3M++Pgalpe0YA/b0CzAP+l2IQb5CKAbqTgNv7udxaX0vHrQM4n837/33p6ngLcBPw1xRvXt9IzWWOcRnDKcK5Lp0L07qZ50fA94CNEfEIQDrmNwDflfThVO8YScfWW1Ea2NsjbdfrFL/Gfa/e/H3RlqGOiDUUO/EfI2IFcArFgM0ainfnr7G59jOAjRSDRC8DF6RlzAe+QnEgXqPo8pzZZVX/CjwALAN+SzG4Q0Q8Q3ElXpa6VVt159KV4k5gclrOpunrKe5nT6W48qxm80BJb/dDo22vdQPF/favgYWptj6LiA0UIf4sxRXpOuDLad80y10U98pPUtzO3NTNPFOBDwP/kLrdZwFnSTq85DEu4xqKQay1wGMUXeiubqXo/s/sMv2StN7HUtf9FxRX/Xr2TPO8CfwXcF1EzO1h/l7bNODwgaPiAy9nR8Qvqq7lg0hSAHtGxNKqaykj/ZrqZeCAiPifquvpSVteqc3a0N8Aj7d7oKENB8rM2k3q1Yli4LbtfWC732a5cvfbLDMt6X7vPnJAjO8Y1IpFmxmwfMVG1r76rrpra0mox3cM4r/v72g8o5n1yYHHrqjb5u63WWYcarPMONRmmXGozTLjUJtlxqE2y0ypUEs6TsV3bS1txR91m1nzNAx1+vaH71P8Cd5E4DRJE1tdmJn1TZkr9YHA0ohYlv7G9naKv/E1szZUJtRj2PK7pFay5fdUASBpiqT5kuaveeXdZtVnZr3UtIGyiJgeEZ0R0TlqN38PvFlVyoT6Rbb8QruxaZqZtaEyoX4c2FPSBEk7Unz31t2tLcvM+qrhX2lFxDsq/sXM/RT/BeHmiKj7P5jMrFql/vQyIu6l+EpcM2tz/kSZWWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmWkYakkdkuZKWixpkaTzt0VhZtY3A0vM8w5wUUQslDQcWCDp5xGxuMW1mVkfNLxSR8SqiFiYHq8HlgBjWl2YmfVNr+6pJY0HJgHzummbImm+pPlrXnm3OdWZWa+VDrWknYGfABdExBtd2yNiekR0RkTnqN0GNLNGM+uFUqGWNIgi0LMi4s7WlmRm/VFm9FvATcCSiPhO60sys/4oc6U+DDgDOErSk+nn+BbXZWZ91PBXWhHxCKBtUIuZNYE/UWaWGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDOlQy1pgKQnJN3TyoLMrH96c6U+H1jSqkLMrDlKhVrSWOAE4MbWlmNm/VX2Sn0NcDHwXr0ZJE2RNF/S/DWvvNuU4sys9xqGWtKJwMsRsaCn+SJiekR0RkTnqN0GNK1AM+udMlfqw4CTJS0HbgeOkjSzpVWZWZ81DHVEXBYRYyNiPHAqMCciTm95ZWbWJ/49tVlmBvZm5oh4CHioJZWYWVP4Sm2WGYfaLDMOtVlmHGqzzDjUZpnp1ei35esTd5zTkuXuceFjLVnuB91v4pW6bb5Sm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhQRTV/ohzQyDtLRTV+umRXmxYO8Ea+quzZfqc0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzJQKtaRdJc2W9IykJZIOaXVhZtY3Zf+V7bXAfRHx55J2BIa2sCYz64eGoZa0C3AEcCZARGwANrS2LDPrqzLd7wnAGmCGpCck3ShpWNeZJE2RNF/S/I283fRCzaycMqEeCBwAXB8Rk4C3gEu7zhQR0yOiMyI6BzG4yWWaWVllQr0SWBkR89Lz2RQhN7M21DDUEbEaWCFprzTpaGBxS6sysz4rO/p9HjArjXwvA85qXUlm1h+lQh0RTwKdLa7FzJrAnygzy4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZaZUqGWdKGkRZKelnSbpCGtLszM+qZhqCWNAaYCnRGxLzAAOLXVhZlZ35Ttfg8EdpI0EBgKvNS6ksysPxqGOiJeBK4CXgBWAa9HxANd55M0RdJ8SfM38nbzKzWzUsp0v0cApwATgNHAMEmnd50vIqZHRGdEdA5icPMrNbNSynS/JwPPRcSaiNgI3Akc2tqyzKyvyoT6BeBgSUMlCTgaWNLassysr8rcU88DZgMLgafSa6a3uC4z66OBZWaKiGnAtBbXYmZN4E+UmWXGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhlFRPMXKq0Bni8x6+7A2qYX0DrbU73bU62wfdXbDrWOi4hR3TW0JNRlSZofEZ2VFdBL21O921OtsH3V2+61uvttlhmH2iwzVYd6e/vn9dtTvdtTrbB91dvWtVZ6T21mzVf1ldrMmsyhNstMZaGWdJykZyUtlXRpVXU0IqlD0lxJiyUtknR+1TWVIWmApCck3VN1LT2RtKuk2ZKekbRE0iFV19QTSRem8+BpSbdJGlJ1TV1VEmpJA4DvA58FJgKnSZpYRS0lvANcFBETgYOBr7ZxrbXOB5ZUXUQJ1wL3RcTewH60cc2SxgBTgc6I2BcYAJxabVVbq+pKfSCwNCKWRcQG4HbglIpq6VFErIqIhenxeoqTbky1VfVM0ljgBODGqmvpiaRdgCOAmwAiYkNErKu2qoYGAjtJGggMBV6quJ6tVBXqMcCKmucrafOgAEgaD0wC5lVbSUPXABcD71VdSAMTgDXAjHSrcKOkYVUXVU9EvAhcBbwArAJej4gHqq1qax4oK0nSzsBPgAsi4o2q66lH0onAyxGxoOpaShgIHABcHxGTgLeAdh5fGUHRo5wAjAaGSTq92qq2VlWoXwQ6ap6PTdPakqRBFIGeFRF3Vl1PA4cBJ0taTnFbc5SkmdWWVNdKYGVEbOr5zKYIebuaDDwXEWsiYiNwJ3BoxTVtpapQPw7sKWmCpB0pBhvurqiWHkkSxT3fkoj4TtX1NBIRl0XE2IgYT7Ff50RE211NACJiNbBC0l5p0tHA4gpLauQF4GBJQ9N5cTRtOLA3sIqVRsQ7ks4F7qcYQbw5IhZVUUsJhwFnAE9JejJNuzwi7q2wppycB8xKb+7LgLMqrqeuiJgnaTawkOK3Ik/Qhh8Z9cdEzTLjgTKzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDP/D4wuQNxPJtmXAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "I2APaCzDGeqP", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/WIP/6-gated_pixelcnn_cropped/cropped_gated_pixelcnn.ipynb b/WIP/6-gated_pixelcnn_cropped/cropped_gated_pixelcnn.ipynb new file mode 100644 index 0000000..2b30544 --- /dev/null +++ b/WIP/6-gated_pixelcnn_cropped/cropped_gated_pixelcnn.ipynb @@ -0,0 +1,624 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "cropped gated_pixelcnn.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "cells": [ + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "k1uZnxh4Xz9Z", + "colab": {} + }, + "source": [ + "import tensorflow as tf\n", + "gpu_devices = tf.config.experimental.list_physical_devices('GPU')\n", + "for device in gpu_devices: tf.config.experimental.set_memory_growth(device, True)\n", + "\n", + "import random as rn\n", + "import time\n", + "\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from tensorflow.keras.utils import Progbar\n", + "from tensorflow import keras\n", + "from tensorflow.keras import initializers\n", + "from tensorflow import nn" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "NN6vJl7eVnZ4", + "colab": {} + }, + "source": [ + "# Defining random seeds\n", + "random_seed = 42\n", + "tf.random.set_seed(random_seed)\n", + "np.random.seed(random_seed)\n", + "rn.seed(random_seed)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "8BnkhgCjVpJu", + "colab": {} + }, + "source": [ + "# Loading data\n", + "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", + "\n", + "height = 28\n", + "width = 28\n", + "n_channel = 1\n", + "\n", + "x_train = x_train.astype('float32') / 255.\n", + "x_test = x_test.astype('float32') / 255.\n", + "\n", + "x_train = x_train.reshape(x_train.shape[0], height, width, n_channel)\n", + "x_test = x_test.reshape(x_test.shape[0], height, width, n_channel)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "3ne-qY7JVZaB", + "colab": {} + }, + "source": [ + "def quantise(images, q_levels):\n", + " \"\"\"Quantise image into q levels\"\"\"\n", + " return (np.digitize(images, np.arange(q_levels) / q_levels) - 1).astype('float32')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "3QVhnMymVrzc", + "colab": {} + }, + "source": [ + "# Quantise the input data in q levels\n", + "q_levels = 2\n", + "x_train_quantised = quantise(x_train, q_levels)\n", + "x_test_quantised = quantise(x_test, q_levels)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "ZObIXqzNGwmo", + "colab": {} + }, + "source": [ + "# Creating input stream using tf.data API\n", + "batch_size = 192\n", + "train_buf = 10000\n", + "\n", + "train_dataset = tf.data.Dataset.from_tensor_slices((x_train_quantised / (q_levels - 1),\n", + " x_train_quantised.astype('int32')))\n", + "train_dataset = train_dataset.shuffle(buffer_size=train_buf)\n", + "train_dataset = train_dataset.batch(batch_size)\n", + "\n", + "test_dataset = tf.data.Dataset.from_tensor_slices((x_test_quantised / (q_levels - 1),\n", + " x_test_quantised.astype('int32')))\n", + "test_dataset = test_dataset.batch(batch_size)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "75VTDkK8VZLA", + "colab": {} + }, + "source": [ + "class VerticalConv2D(keras.layers.Conv2D):\n", + " \"\"\"https://github.com/JesseFarebro/PixelCNNPP/blob/master/layers/VerticalConv2D.py\"\"\"\n", + " def __init__(self,\n", + " filters,\n", + " kernel_size,\n", + " **kwargs):\n", + " if not isinstance(kernel_size, tuple):\n", + " kernel_size = (kernel_size // 2 + 1, kernel_size)\n", + "\n", + " super(VerticalConv2D, self).__init__(filters, kernel_size, **kwargs)\n", + "\n", + " self.pad = tf.keras.layers.ZeroPadding2D(\n", + " (\n", + " (kernel_size[0] - 1, 0), # Top, Bottom\n", + " (kernel_size[1] // 2, kernel_size[1] // 2), # Left, Right\n", + " )\n", + " )\n", + "\n", + " def call(self, inputs):\n", + " inputs = self.pad(inputs)\n", + " output = super(VerticalConv2D, self).call(inputs)\n", + "\n", + " return output\n", + "\n", + "\n", + "class HorizontalConv2D(keras.layers.Conv2D):\n", + " def __init__(self,\n", + " filters,\n", + " kernel_size,\n", + " **kwargs):\n", + "\n", + " if not isinstance(kernel_size, tuple):\n", + " kernel_size = (kernel_size // 2 + 1,) * 2\n", + "\n", + " super(HorizontalConv2D, self).__init__(filters, kernel_size, **kwargs)\n", + " self.pad = tf.keras.layers.ZeroPadding2D(\n", + " (\n", + " (kernel_size[0] - 1, 0), # (Top, Bottom)\n", + " (kernel_size[1] - 1, 0), # (Left, Right)\n", + " )\n", + " )\n", + "\n", + " def call(self, inputs):\n", + " inputs = self.pad(inputs)\n", + " outputs = super(HorizontalConv2D, self).call(inputs)\n", + "\n", + " return outputs" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "g625fSNYR9I6", + "colab_type": "code", + "colab": {} + }, + "source": [ + "filters = 1\n", + "kernel_size = 3\n", + "vertical_conv = VerticalConv2D(filters=2 * filters,\n", + " kernel_size=kernel_size)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "DqsfC9JnS8C8", + "colab_type": "code", + "outputId": "ef53f88a-c437-4aa4-b40f-0184e140b1b3", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "vertical_conv.kernel_size\n" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(2, 3)" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 9 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "PTUN4s52Nu3w", + "colab": {} + }, + "source": [ + "class GatedBlock(tf.keras.Model):\n", + " \"\"\" Gated block of the Gated PixelCNN.\"\"\"\n", + "\n", + " def __init__(self,\n", + " mask_type,\n", + " filters,\n", + " kernel_size):\n", + " super(GatedBlock, self).__init__(name='')\n", + "\n", + " self.mask_type = mask_type\n", + " self.vertical_conv = VerticalConv2D(filters=2 * filters,\n", + " kernel_size=kernel_size)\n", + " \n", + "\n", + " if mask_type =='A':\n", + " self.horizontal_conv = keras.layers.Conv2D(filters=2 * filters, \n", + " kernel_size=1)\n", + "\n", + " else: \n", + " self.horizontal_conv = HorizontalConv2D(filters=2 * filters,\n", + " kernel_size=kernel_size)\n", + "\n", + " self.padding_A = keras.layers.ZeroPadding2D(padding=(0, (1,0)))\n", + " self.cropping_A = keras.layers.Cropping2D(cropping=(0, (0, 1)))\n", + "\n", + " self.padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n", + " self.cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", + "\n", + " self.v_to_h_conv = keras.layers.Conv2D(filters=2 * filters, kernel_size=1)\n", + "\n", + " self.horizontal_output = keras.layers.Conv2D(filters=filters, kernel_size=1)\n", + "\n", + " def _gate(self, x):\n", + " tanh_preactivation, sigmoid_preactivation = tf.split(x, 2, axis=-1)\n", + " return tf.nn.tanh(tanh_preactivation) * tf.nn.sigmoid(sigmoid_preactivation)\n", + "\n", + " def call(self, input_tensor):\n", + " v = input_tensor[0]\n", + " h = input_tensor[1]\n", + "\n", + " vertical_preactivation = self.vertical_conv(v) # NxN\n", + "\n", + " # Shifting feature map down to ensure causality\n", + " v_to_h = self.padding(vertical_preactivation)\n", + " v_to_h = self.cropping(v_to_h)\n", + " v_to_h = self.v_to_h_conv(v_to_h) # 1x1\n", + "\n", + " horizontal_preactivation = self.horizontal_conv(h) # 1xN\n", + " if self.mask_type == 'A':\n", + " horizontal_preactivation = self.padding_A(horizontal_preactivation)\n", + " horizontal_preactivation = self.cropping_A(horizontal_preactivation)\n", + " \n", + " \n", + " v_out = self._gate(vertical_preactivation)\n", + "\n", + " horizontal_preactivation = horizontal_preactivation + v_to_h\n", + " h_activated = self._gate(horizontal_preactivation)\n", + " h_activated = self.horizontal_output(h_activated)\n", + "\n", + " if self.mask_type == 'A':\n", + " h_out = h_activated\n", + " elif self.mask_type == 'B':\n", + " h_out = h + h_activated\n", + "\n", + " return v_out, h_out" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "WB57YufrVxn2", + "colab": {} + }, + "source": [ + "# Create Gated PixelCNN model\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=64, kernel_size=3)([inputs, inputs])\n", + "\n", + "for i in range(7):\n", + " v, h = GatedBlock(mask_type='B', filters=64, kernel_size=3)([v, h])\n", + "\n", + "x = keras.layers.Activation(activation='relu')(h)\n", + "x = keras.layers.Conv2D(filters=128, kernel_size=1, strides=1)(x)\n", + "\n", + "x = keras.layers.Activation(activation='relu')(x)\n", + "x = keras.layers.Conv2D(filters=q_levels, kernel_size=1, strides=1)(x)\n", + "\n", + "gated_pixelcnn = tf.keras.Model(inputs=inputs, outputs=x)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "_LnzHUaqV77d", + "colab": {} + }, + "source": [ + "# Prepare optimizer and loss function\n", + "lr_decay = 0.999995\n", + "learning_rate = 1e-3\n", + "optimizer = keras.optimizers.Adam(lr=learning_rate)\n", + "\n", + "compute_loss = keras.losses.CategoricalCrossentropy(from_logits=True)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "CsAgEKVzLCJD", + "colab": {} + }, + "source": [ + "@tf.function\n", + "def train_step(batch_x, batch_y):\n", + " with tf.GradientTape() as ae_tape:\n", + " logits = gated_pixelcnn(batch_x, training=True)\n", + "\n", + " loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits)\n", + "\n", + " gradients = ae_tape.gradient(loss, gated_pixelcnn.trainable_variables)\n", + " gradients, _ = tf.clip_by_global_norm(gradients, 1.0)\n", + " optimizer.apply_gradients(zip(gradients, gated_pixelcnn.trainable_variables))\n", + "\n", + " return loss" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "NoEPrfwQNM-s", + "outputId": "4a9422c4-8ce1-4310-8783-882be0c7c924", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 476 + } + }, + "source": [ + "# Training loop\n", + "n_epochs = 50\n", + "n_iter = int(np.ceil(x_train_quantised.shape[0] / batch_size))\n", + "for epoch in range(n_epochs):\n", + " progbar = Progbar(n_iter)\n", + " print('Epoch {:}/{:}'.format(epoch + 1, n_epochs))\n", + "\n", + " for i_iter, (batch_x, batch_y) in enumerate(train_dataset):\n", + " optimizer.lr = optimizer.lr * lr_decay\n", + " loss = train_step(batch_x, batch_y)\n", + "\n", + " progbar.add(1, values=[('loss', loss)])" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Epoch 1/50\n", + "313/313 [==============================] - 167s 533ms/step - loss: 0.1132\n", + "Epoch 2/50\n", + "313/313 [==============================] - 160s 512ms/step - loss: 0.0888\n", + "Epoch 3/50\n", + "188/313 [=================>............]" + ], + "name": "stdout" + }, + { + "output_type": "error", + "ename": "KeyboardInterrupt", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mprogbar\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py\u001b[0m in \u001b[0;36madd\u001b[0;34m(self, n, values)\u001b[0m\n\u001b[1;32m 676\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 677\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 678\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_seen_so_far\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 679\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 680\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, current, values, finalize)\u001b[0m\n\u001b[1;32m 638\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m' - %s:'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 640\u001b[0;31m \u001b[0mavg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 641\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mavg\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1e-3\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 642\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m' %.4f'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mavg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36mmean\u001b[0;34m(*args, **kwargs)\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36mmean\u001b[0;34m(a, axis, dtype, out, keepdims)\u001b[0m\n\u001b[1;32m 3333\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3334\u001b[0m return _methods._mean(a, axis=axis, dtype=dtype,\n\u001b[0;32m-> 3335\u001b[0;31m out=out, **kwargs)\n\u001b[0m\u001b[1;32m 3336\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3337\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/_methods.py\u001b[0m in \u001b[0;36m_mean\u001b[0;34m(a, axis, dtype, out, keepdims)\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_mean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0marr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0masanyarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0mis_float16_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/_asarray.py\u001b[0m in \u001b[0;36masanyarray\u001b[0;34m(a, dtype, order)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \"\"\"\n\u001b[0;32m--> 138\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msubok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 139\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "ue0vZbitSNmz", + "outputId": "42de5250-7404-4fb2-daff-e5283be9f53e", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + } + }, + "source": [ + "# Test set performance\n", + "test_loss = []\n", + "for batch_x, batch_y in test_dataset:\n", + " logits = gated_pixelcnn(batch_x, training=False)\n", + "\n", + " # Calculate cross-entropy (= negative log-likelihood)\n", + " loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits)\n", + "\n", + " test_loss.append(loss)\n", + "print('nll : {:} nats'.format(np.array(test_loss).mean()))\n", + "print('bits/dim : {:}'.format(np.array(test_loss).mean() / np.log(2)))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "nll : 0.08501224964857101 nats\n", + "bits/dim : 0.12264675098280793\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "-Ia9VXYySkuW", + "outputId": "ad1075b1-70cc-4012-851e-925cae32ac2c", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 581 + } + }, + "source": [ + "# Generating new images\n", + "samples = np.zeros((100, height, width, n_channel), dtype='float32')\n", + "for i in range(height):\n", + " for j in range(width):\n", + " logits = gated_pixelcnn(samples)\n", + " next_sample = tf.random.categorical(logits[:, i, j, :], 1)\n", + " samples[:, i, j, 0] = (next_sample.numpy() / (q_levels - 1))[:, 0]\n", + "\n", + "fig = plt.figure(figsize=(10, 10))\n", + "for i in range(100):\n", + " ax = fig.add_subplot(10, 10, i + 1)\n", + " ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary)\n", + " plt.xticks(np.array([]))\n", + " plt.yticks(np.array([]))\n", + "plt.show()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj4AAAI0CAYAAAAdqSPKAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3dW3LkOLKtYfDYGYLyuWIOrfmPQJqD6rk0B+6HzOimongBQFyWu/+fWZvtnRWS4ATIcDpAcFnXNQEAAETw/2Y3AAAAYBQSHwAAEAaJDwAACIPEBwAAhEHiAwAAwiDxAQAAYZD4AACAMEh8AABAGCQ+AAAgjP9f8uG3t7f18Xh0ako/X19f6fv7e7n6nNX4Ukrp8/Pze13XX1efsxpjbh+mRIzKvJ+L9OFPxKgr8rlYlPg8Ho/08fHRplUDvb+/Z33OanwppbQsy985n7MaY24fpkSMyryfi/ThT8SoK/K5yFQXAAAIg8QHAACEQeIDAADCIPEBAABhFC1ubmVZshbLn1rXtUFLgGN745RxB2AWrkltdE98WiQ5Ob93VufnxOdxYPboV6XjdBTf89+V2opjJeOUPoUqEp62mOoCAABhTJnq8iL3bnJZFhfZea/q3fb3ezhOmKv3OAVGYSz3QcUHAACE0bXi4zVbzZ1vtRb/7PYqVHtmHwMrtsdJod+OKLcNOMO1qJ8uiU+kDsu9sCovih3RX4px41zJVG5K8/s40nUHx47Gwezx2YKHGJ5Kz9eWsTPVBQAAwmBx8yDKmfpZ266ycuW4SuTefXiIN0plxENf4VjpOFapTF7xen7ejavl9DoVHwAAEEbTio/XTLWUp+OgfnfUQoRqT88xqXBcPJ1zOEY/23RUrantz7vVO/NTXbMvupyIdkVIeO54jZuxrqN1XyiPccadfUfjq2TctRwHTHUBAIAwTFZ8Zt+dPP9+SQaaU5rjfSw6vB13b/FEFan6ESnWK5y/bVHxAQAAYZio+Khmu3cWaVnbHA5xnY1VlXG5ruu/2hmhgloTj3Il5U7brq7H3vreu57XHdnEx9ogzX1lRc6JbS12+KL8xXhX9BfhekwIchbhW48Rv7XqR6a6AABAGHIVH0+Zee5UmKeYYZvlR/zv7EDunXoVpMV7m6L3sSe9xysVHwAAEEbTik/NY94RcDygzvsY3au+KlU8cr1eYy1X6FJqM+6opvsxqjJJxQcAAIQht8YnGu5I/LF0F+690nPGU+Xn6L9HZTF+9XVYPY2OvUvi0+IlZJbxyLoNLfrA8viuaTvjVku0/sj5bol2TCyb1YdMdQEAgDC6T3VFe8PzVXxe70aiVA88VPNaPDps1TZ2T3FZlfNADP2E1qj4AACAMFjc3IjnSk+vKp21xaXWFx9aWnTdivU+i6K0Tzz2a7TtYGb2IYnPIK+dbOUkjXISnvFwkY2Q9DBd4l+E61HEsTo6Zqa6AABAGMMqPhEy9RLWpnl6Uj0WVqt0NTzG5jEm/Bv9bNPMfqPiAwAAwuhe8YlS6dlbmJaz2Zbnx2stvkHZw3qeXJ7i8nweRcdGhWita+Kj/iXXw9HJmLNiX3XKJ9fRnk0W4vGe8HiKBTGwWB29MNUFAADCkHmcPUoGb22vhlbvs3r+nrPpv+3nRmKaxIfcvou0aD2HwviPsN0CdFDxAQAAYUhUfCJm8dYqP3eo9q/CnS7aOVtT5n0Nl3cR+4oqWD9TE5/IHeY54bG0qBn+bMcfCc+xWdegyC/JfWp57LmBK8dUFwAACGMpyRCXZfknpfR3v+Z089e6rr+uPmQ4vpT8x5gVX0rEKI5x+gcxSiPGPzzGV5T4AAAAWMZUFwAACIPEBwAAhEHiAwAAwiDxAQAAYZD4AACAMEh8AABAGCQ+AAAgjKJXVry9va2Px6NTU/r5+vpK39/fl3uEW40vpZQ+Pz+/czajshpjbh+mRIzKvJ+L9OFPxKgr8rlYlPg8Ho/08fHRplUDvb+/Z33OanwppbQsS9bOmlZjzO3DlIhRmfdzkT78iRh1RT4XmeoCAABhkPgAAIAwSHxQbFmypr4BAJBD4gMAAMIoWtyM2Kj0AABGOvveWde16nfKJj7bYGuDAwD41uOLcaacG0yvcW09Y9yL9fm7lmWpOhZMdQEAgDBkKz5R7GXBatk8U1yxHPW32rhEPKXXIiszB16vsb3ievblsiz//Rsl/UvFBwAAhCFX8fGa+b46i1O9CqTUFtRTH2dAStffCWdjVvX75G67ate2jNDimOfGtq5r1d+TS3xeqXZuqdoFeHcXcd1lpVSsSDWxsJx0w7fW4+/ulMgoV21STeBSatu20r6p/dtMdQEAgDBkKj6qGe2disedxxJnH4/Zf98yS8due0d89N9gn3r/jmyfcuXnyN55qhDHrGvd3ZkIKj4AACAMiYqPpTvkK7mxlC7Im5nVW7ozUqVwDL1t9Fai9hpj/bjkrOdSjbHldbB2EWwvqse81Ojj2upvTU18rF2IW3bw3kVHLeFBPot73yi37Y7W56nF41RyDBRiPNud9+jfStt8Nq2rTvm7oeVxzf1efP3bpZjqAgAAYUyp+NzN4BTuUGrl3tnMiu/qLislnTsNdUrHyeJdbqkIMeaoOQ6K015Xd/211yWlGD2pOa5HYzV3DPOuLgAAgAvDKz531kLMuKNrMX9ZOxetSqlCNZvFKoPHvrLYD62NvEYp8rjZqvdxPWvd1bDEp3R6a1t6Vej8knaUnHQKsW2VDkQvF5iWlI6J2vhS8No/V8dIcRroKUKyc5TQeN+B3NrDP3ecnZNnfc7iZgAAgAvdKz61lZ6cnx+td5atlMUrtWWEu+P06HPQcNY3lh9zrmFlnNa+nPRqwax6/JEqPUdyr7m1qPgAAIAwulZ8ShYyl2RzHrJejwvxvOl91zGLlTvfFrzG6Hkn6rvV1KvzVnX8U+nJd/d4UPEBAABhDH+cveYu2lO266FikMtyVat0nKrGF239inel/ag6LkdTeTr4iMVX3lgm8ZLSM9473mN8yheYGhYTnlfqF35ci7IcYC9ZV52eusvDtcUiproAAEAYcu/q8sp7zNbje73L9Dj92qLNHh/jV+/rqJuJep6mpdJzz90KIBUfAAAQhsy7ul55yno93iU/ncXm8U7NS7/VsL4xnFWez6eIWMg8n9ziZqudHymh83ji5vSf5fjuyB3b1hIgL4mEleM9k8ITpkxv3dPyGDHVBQAAwuha8cl9k67VbDfK46VPHvswJZ8VrBYsVUSOHtffq0JZimsr2njM7c/cn53F63XTMio+AAAgjGFrfDxltp4XK+8pjff186rHRumuUInH3YHp6zgsVHCV2hKR3OJmayIO4NzXOUQ8Nt4p92mPp5+U4/Uqp/+sJOck3JqY6gIAAGFQ8amwvbO09ghvqas7Fio9vnhYiNliIbOVWD1qWblT7EfFNkVDxQcAAIRBxeeG7SO0VuacW7CweBBlvK5FeB2T6u/lmmFZFhdxe4jhLu8zEK3IJT7WdkAuvbBaw0Jm36LsWO3tvGxhe0wsf2FabHNLd/Y7ioqpLgAAEMZSkhEuy/JPSunvfs3p5q91XX9dfchwfCn5jzErvpSIURzj9A9ilEaMf3iMryjxAQAAsIypLgAAEAaJDwAACIPEBwAAhEHiAwAAwiDxAQAAYZD4AACAMEh8AABAGEWvrHh7e1sfj0enpvTz9fWVvr+/L/estxpfSil9fn5+52xGZTXG3D5MiRiVeT8X6cOfiFFX5HOxKPF5PB7p4+OjTasGen9/z/qc1fhSSmlZlqydNa3GmNuHKRGjMu/nIn34EzHqinwuMtUFAMBAy7Lw4tyJSHwAAEAYRVNdAGJ7vUvlXX9Ame059Py/a84jzsV6VHwAAEAYJD4AYNRzrQjrRWLZ62/GQD6munBL7slGGVbHXp/l9A8XVh1nX3yca355Owdz4ukxnqn4AACAMKj4INuduw3uRuc767/tf8vpI/pxjpxzsLaih/Fy+8VbpSel/Jh6jGcqPgAAIAyTFZ+jTNHCXc1Zlqvafo93G1HU9J3F/o6w1qxFxTUl28fg6epYKMYY5Vy8Uvsd2GobgJSMJD6zFkC1ZLVE3erEmx1HNL0umGr9WBrnneMyK3aPX36lWvSz2thNqb5v13U1t4/P3T7Zi7kWU10AACAM+YqPxZLmq9JMV+EOr7QNFvqhN2vTmNs2WauqKpwjyp595eE4tYrBwwMW27Zb6uPWMwfP37csS1V/UvEBAABhyFd8jljI2mvnNLdzmep3Kart6qX2zmVkP5ZUSa1VenDNalWgRITHwEvPu9rqxyyt2lpzbZVNfKxNG2xZPtlyqB//1nKTA/V+f/0StNyPOV/oLV78qM5yH+a6E6NS4nfVFo992WuR+d1+ZaoLAACEIVvx2WM5I7bc9icPMeSqWVT/+m9H71NSeSy6ZudUtTHQ8o5eoSqw506MqjH1phy32jnUw4ztBEoq2VR8AABAGHIVH+VMPYeVjbOubO8yLbb/rqM1Oy023PKwxkbNnWPpYcuMI0prXPb0OBciraFRNOs70MXi5i2rA7b0yYO9i9Ts2Gf//VksJ7B3FlqXPvFlUUlcVvrcopavIIjI2s7NLd29NjHVBQAAwpCp+Hi9uzxyNgWSUqzsXdFe1aS0f2aO6bMpjjvTmNbvzC1fZ7w8nq/YJvShep2g4gMAAMKQqfgcUc0Yj1wtaM39HZjvrGpi5S3fZ3+rph3Wx2bu+ifrcSrrteDa8qa3paxUzVpXiHk7OwAAQCH5io81e08q5PJ2V+JFyyed6OP5zjaa9N4/6vFF3T4jl+WnTUvlXmtr4pdIfKyU7V61arfXgbvHejnaQht78Rq7x7gsXFP3ph1L+sLTlKWnm+Tc6eQWY7T2ODDVBQAAwpCo+FhTk6kqZ+gjWLgDxW/0FWY5mnaMPiatf3+07r+7x4OKDwAACEO24qOc4V49jqnc9hk8vwvJm0iLJ72y1IcetozoyWIcatWdPc0Sn4hJgOfYWiDhAZCjxd4+lq8nltt+x1G/9z4eTHUBAIAwbld87jyeHH3BmleeHjONwtL0CPZ5uJ5ePQrNmLSr9S7yd1DxAQAAYdyu+Nx50zN8sVzpsdz2O6j02He2PkK9CtRzd17oUOu/ZoubWwWmdoCQR/0Ce+XsNQZ7/+ZhnFrvM9jqQ0ttHY3CwVhMdQEAgDBk9/GBttq7t+fPqd/dXE0XjH78sgXuuP0onaK0dN6dnW/q7b/D4jVlT+67umai4gMAAMKg4gNcyNlcTX2OvuQOLMLdtVWWF6PntjNnN2crMd+hfk0poRYLiQ+6OZsuUjsRclgo4ZbYO/7b+CJ9yaRkM6m4at9rEmHpvMt9nYWVeM4cXSutnoN7facUC1NdAAAgjKUk+1qW5Z+U0t/9mtPNX+u6/rr6kOH4UvIfY1Z8KRGjOMbpH8QojRj/8BhfUeIDAABgGVNdAAAgDBIfAAAQBokPAAAIg8QHAACEQeIDAADCIPEBAABhkPgAAIAwSHwAAEAYRe/qent7Wx+PR6em9PP19ZW+v78vX7JkNb6UUvr8/PzO2YXTaoy5fZgSMSrzfi7Shz8Ro67I52JR4vN4PNLHx0ebVg30/v6e9Tmr8aWU0rIsWVuKW40xtw9TIkZl3s9F+vCn0hhVXj5KP/6Px/h4OzsQ2Oubr3mFDWZQenM3/GONDwAACIOKDxDMa5UHmIWxiBmo+AAAgDDkKj5ndwDM/6I3r2sNuLMGMNqdxeo9r8USiU/uRdnal1Ltl42V+LaWZTltt7WE9ioeD9Z1DZ0Q5cSuNgasnUeIaW+c7v3b0ZjtfV1iqgsAAIQxpeJzN5tTuhvvkZkqxXdlG3/tsbBWybPk7C4rQrXH6rUmQt9s1U6DjPhbPUSaDTgzq4JJxQcAAIQxpeKTc8e5l+2p3QWVtOc1HrVYSllv/6voG/m1iFetcmdxjN5p853jP6PvSmJt2ZezxmmLGErWyeDY1MXNpR22TZjULrJ7ztq2/W8WL9Atp0uU+9CTXsdZ5Vxs8aUw+1yccV3Ym66e3Zc5ats4uo97/z1LfXZlVAxMdQEAgDAkHmf3piRrnX2HOYvS3YnHPhg1dady7DxOAZRO97eIV/GYterbGWN15N+09FDM06z2UvEBAABhUPGpcOeuS+UOeSRLdyGW2lqr9SJLhWO2txljxHPNk15rthTG65OVdUqlah9gGsVU4qPc2bW7Fpf+LjWW94Sx2OYcXuO64u3JyVxWvzxLdu0l4fltdp+VOvt+mDk1x1QXAAAIw1TFZ0ttN9XtI4WeKzwlLMXVsq2zHy/tOe1jab+j3EfDLewZFgXVnn3Wx+PR9+KsayUVHwAAEIZ8xcdapmutva1sM3qlu6tRlPr9zo7irX6vmtw4rcXo5Vzzuqvx3b/f61yeQWk9qGzio7gavLTjckrt6oO1hKdYStVMcfZyNe68TR3Avu219fU6W3NeeRifKtcTj5jqAgAAYchWfPZYyeJzHqu1EgvOqfdt67YoxRZRy5fJtv69LWzbcadNlqfdIz4ckzMmWz7+TsUHAACEYaLiYy2jVVyfhHOW7xB7sfToOrB1th5TYRyzfufa2ePvWzX9ScUHAACEIVvxUcjKW/IWD/yi0uNHTmXBc//uPeE4u7rbutozO57ZauKXTXyUKe1HgHp7j9DmUi2h17IUj9oLUmfjOpRn7zF5xk9MTHUBAIAwqPjcoLRpHepFvuuzPq0VtfrT4roT6XhtzY77zvdG7vvnImFxMwAAwImmFZ+IGy+x3icWT68esVotiVppbfHqhojHTVHO98bVOWl1zVKL78y7cTZLfGqCsHrh3UMC5J+XvvV03kXVYjE+5ot8/tVM27FzMwAAQKFmFZ+j8nNuqdXzbsfW2x+Rxztk6wuZr3jfz8RrXMBoVHwAAEAYTRc359yR1Mzreb9T9eCqL9X77G6FR31hIet6YlKuXHp6UGAG5b7tpVVVd/g+Pi06i4u4PcpTmVH2RLHQRvhUco5xo3vN0k7rLbG4GQAAoJDszs1Wd6iMuseIN7njz8pd1nZcWmkz6qmM2c/PT/fTyCPx3dIGFR8AABDGsIrPncWv1nYfjX53cha/pb5TuWtuxWKbW1Efd7W87ZZPxfynnGOh3qettIyzS8VnWZZ//e9MlI7zbF3X//7Psqsx6yFG70qvP2e89PWscfuf//znx7Xh6n+R1Xxvejlmo2NhqgsAAISxlGRZy7L8k1L6u19zuvlrXddfVx8yHF9K/mPMii8lYhTHOP2DGKUR4x8e4ytKfAAAACxjqgsAAIRB4gMAAMIg8QEAAGGQ+AAAgDBIfAAAQBgkPgAAIAwSHwAAEEbRu7re3t7Wx+PRqSn9fH19pe/v78t9663Gl1JKn5+f3zmbUVmNMbcPUyJGZd7PRfrwJ2LUFflcLEp8Ho9H+vj4aNOqgd7f37M+ZzW+lFJaliVrZ02rMeb2YUrEqMz7uUgf/kSMuiKfi0x1AQCAMEh8AABAGCQ+AAAgDBIfAAAQRtHi5tmW5d8LtNXeLr/XxpT02gkA0Mb3SR9yic9RR79S6/irdj//u1q7AfzP2XnMuYvWcr/vzn6OcVmOqS4AABDGlIpPSZarns2WZuzeKj8R7jwsTLHiPqq2GKW20nP1uxibeaj4AACAMIZXfDwt1iqJ5fWzFrP0lncpqnJjtLwWpGUFa1kW+XhbiRIn0Eqv74y75+KwxMfiF32No9jOYlb/Es35orScFLVuu+JYz53GiWxd11DHgak7f2b06Yxz5u41lqkuAAAQhtzj7BEpV09yM2uLd401x/koTqU+2+rdLov9fuYZj2p/5rDc9ihqx1nOMoqeWvytFteMbTtqqlxUfAAAQBhUfBrz8uiz4jqVVlrcZe19Zu/3zlpHwV3/PVbG/J1+VozxarG8t+vSNobaa4VilXJk39T8LRKfG/YG7dF/v/O7R3qNw8PFZevq4uAt3j0RYvSqR9KuYG/qouRnUrIT65GW7Z9xwzXib7VK7pjqAgAAYQyr+FjPxq+0XrCFsbyNz73yt7cYo2m9B5rS4+wtFu0qxRNRy+Pf+52dVHwAAEAYrPGZjCrPfBGqddwF+xKhP7dVS+UtPyLpvZB61BpMEp/B1E/Y14Ed6ZUEV9T7bo/XvsvpCy+xe52uzL22eIrZixlPkrUcB0x1AQCAMKj4dHQnG569UG+b0c9uyywqu5Tit9rFriWU+6tX26jqznNVzfP2yP6R0e+rpOIDAADCoOLTSMsdVFXXkvR6XFHxLqZ0Q0rVPvOkZpy0qBKpbCaK8ZWBu3IX624r63d+3+vvVTdrjFPxAQAAYZiq+KhVCbytIejp7FiNvMvee6dW643hWv087pv9NmsLZo7T3CqytT6ree1GiyelLF1zzr7P7759/YpM4uN9l86cdqqf3D2TkZFyLjC5bVLvM/wbU5a/zb52ej7HSpOY7bG487MKjhbLl9zgHt2gtoqVqS4AABDGlIqPxQy+RGlWqrSgsgfV+O60wdoiS1zzOE4jUDo+LTbWVFvSUaPXBqOtZnqo+AAAgDCGVXxqqzx7c30qrGbjOVSP+WxUevxRHevRx5Nqv5zpVUVWM2IBds/XYnRPfCK9U6cVtePR4ukmSyf1EY9Jz17pOMo5qzq1Bajy8sQZU10AACCMKYubWzzarZQ91lCtgPRul/V+e2UpntK9lGp/V0rax6XXvk0AflM/l6j4AACAMExuYKieTe6Jsm7iiOXYVKtzCpTWb6m0A2NZ28zWsr0FxxaPu0Ti4/WCZXXnTfzmdVzWUhynLfpIMa7IOO/0jThneo4DproAAEAY3Ss+OS8fs+zO/kRWRCwl5y7etfRIdI92zTqHvTxWC2D87AgVHwAAEMbwNT4tMraWb2kdzWq7o8i988h5NNxzX8/czNHyo/S4djZLkPN5q/aqyJ6uJXerwy2PwdTFzbmB7B0wTwNCXfRjvBd/zklsOUHfM3tq2uPO2aNYfwrnzOu48BafZaoPHzDVBQAAwlhKsqllWf5JKf3drznd/LWu66+rDxmOLyX/MWbFlxIximOc/kGM0ojxD4/xFSU+AAAAljHVBQAAwiDxAQAAYZD4AACAMEh8AABAGCQ+AAAgDBIfAAAQBokPAAAIo+iVFW9vb+vj8ejUlH6+vr7S9/f35d7ZVuNLKaXPz8/vnM2orMaY24cpEaMy7+ciffgTMeqKfC4WJT6PxyN9fHy0adVA7+/vWZ+zGl9KKS3LkrWzptUYc/swJWJU5v1cpA9/IkZdkc/FqS8pBTCG55dUAkAJ1vgAAIAwqPgAQVDpwbbyt8XYQCRUfAAAQBjDKz6sNfDh2Y/0ITDOUcUGti3LwrV0oKlTXXx52sTFF/Bl9rU495pS276936/wvVN73FXjsYKpLgAAEMawig8Zqn3RKj2v8Vocr63v5D0ck8i2/aVwPp+14dnW7Wdatnnm9JLCsY+Mig8AAAhjyhof7hJtiVit24uZBYj/5uGYeH7g4iyes4rKiOOQe13Z+zcvFZMWa3tQrnviE6U0nlOyrfnZKzOOZ69pkx5/o1RpX8xeFIoyV/3roT9btH1EQnvn99/52dl9TPLyb6O/C5jqAgAAYbBz8w25mbvlDH9025XuwnovrvRAvTJS218Wp7+YNtF2NH3+dNZ/itXxV70qaT3ORSo+AAAgjGEVH5WstIUId0q9FzQrLZi+asvZI8CexrU33s7TFvF4OyYl1GNXb9+eqypWr79397rbNfGx2JEtePsy7J3wzKKUfFlwp1Q/ktIYG6HHcVfpy1YUx+ldM+NofY7VTvPVYqoLAACE0aXi4zG73lLb/bSnkv47OhbrupqprnjvT89q+079fG6xsLX3u7AsmB2b+jgbZXY/pETFBwAABNK04hM5i42qxcaNCncAuSy1tQfV+FtUNKxUJVNqW+lRjbEF1WvM3TUts+PZbvXRqy0984mixOfz8/M00Nd/X5bF/VMwe3u9eHAnLit9fLUVvpU49vQYj5aPh8W23+lDC1+eI3iJUTWO0UkP+/gAAAAUKkp8/vOf/xRlXOu6/vd/T94qI1Ft+/b1f6W/Q0lNHNCV05dHlWnFMdCiXaqxRWdpHLb2jH3vGDy1PA5UfAAAQBjDd27eZnQeMlnvFaxefeSh72Gb8rnba30P5x1UlI7xlgupqfgAAIAwhr+dfbuZnafKD44p31l7xTl1zGpFJHd7iJqfxVylWylcXVMV+7vlu+ak39WVq+deANBCP2Okkout0tjc2+187zpp8Qswujt9Zm27jV43vXcTIKa6AABAGBIVH6ss7fQ6A1NcsMDSORv1nVt3ryWzjkdNu61uPnmnj3Liafl9QsUHAACEQcUHgHl7a2Ksq31tjGI1oEaLCsL2QRql4+J5rVbuuVga095GyLX9SuJTwdsFthev7zHzzvKC4JT8jTelY9xbq6k8lTEQ6Um8bXtHtr1moTNTXQAAIIxhFR9Pme0eDzGMwN5NOq7eibP33qC936HWl3sl8b2Y1NodWcupHgsPnXj/PuylVSWPig8AAAija8XH04ItwLKatRM556f6OfzaPvX24n9K+8pCpSclO+1U0/Kt7V0SH+9lPA8xjJS7C60F1hb+cpGFBTVjUmUBcy5r7VXR47gx1QUAAMLoUvE5W1yIa1arIaWsjI3aO44o/QiMlntOKixm9z4DUqtFJYd3dQEAAFzo/jh75Iy2FHPAWugPAD1E/l7s/U6vHOzcDDSiuHOwYpugRWE6qKW9pRZKPBzjGVoeN6a6AABAGEvhjpj/pJT+7tecbv5a1/XX1YcMx5eS/xiz4kuJGMUxTv8gRmnE+IfH+IoSHwAAAMuY6gIAAGGQ+AAAgDBIfAAAQBgkPgAAIAwSHwAAEAaJDwAACIPEBwAAhFH0yoq3t7f18Xh0ako/X19f6fv7+3LvcqvxpZTS5+fnd85mVFZjzO3DlIhRmfdzkT78iRh1RT4XixKfx+ORPj4+2rRqoPf396zPWY0vpZSWZcnaWdNqjLl9mBIxKvN+LtKHPxGjrsjnIi8pFeftBYIAAMzEGh8AABAGFR9R20pPSlR7AKBUhIr5M0av8fVAxQcAAIQhX/F5rXycIeMFEHkOW9MAABS+SURBVJ3HKkfJ98DV7+CYQC7xKe1ML4MYAO7wlhz0+mJXirGlZVncxdQLU10AACAMmYrPWXYfLYu1VKq+c1emHhti2BvDXsZmbnVDfdqktD/U47nLe3y9UfEBAABhTK34UOX5SbXSw1z7NdW+K+W1+lE6hi31593zU6160KI9zz7zNp7V+sqqKYmPt8F4l+KePa1PsLMLkaUEKOe4ePvStBTPK74oyljrX6sUr/mRMNUFAADCGF7x8V7tKX2kUOluuubueHabR8qdmrVSZcg9F7efs3CnauX4j3DWP0rXnq2z6rAVltseARUfAAAQxrCKz1kGrHrnUcvSmpUcXuIoVbv4fnvHqjgWSs+3s2qWtU3TctvKHTtq9H4QBG3I7OPzpPhFkaM2eVvX9b8/Ozv2bVt6/o2Ujhc5q/e7evvOWJimQn8Wv0RLr41KMeZMNx5d+zy9smlUn+QcB6a6AABAGN0rPjULKLf/v3oW21KkWNXdWYSvdLf55G06Gb9djbWcsag6Hvaqw2fnpcJ5d+dYenpP5cy+yJk5oOIDAADCkFnjo5S1j6acuacUa21Ir40bZ2lZ6fG+FUVKMca6t5gsf2d4/N6zEEvXxCfihdKy3Fg8Tpu0eH2K0pdmy3Fp7TyuXUBv6Vz2NC3y6m4/WNxTa8SDJb3ltr/XWLyaEt1iqgsAAIQhM9XlhaU7qxY8VH+O7g6s7sB9V5SXB3uO02r7W+8eb+k45O6ybdGIfiiZNqTiAwAAwuhS8SldE2BtDcGT9Sz81dUxL310Vn2uvfW6HvUx23LthAVH6yYsV3pK1+Gpx5PLSxy1lONXOZ9er8Xv7++Hn6XiAwAAwmha8Sm9o7Ra6YmqtIJj+ckTy22/w0scTyXz/t5it47+sEu977ovbs49AOoHao+HRxBr3X2vzNXv6qlln6mOW9V2zcLx0Bb1Oor7arYRYaoLAACEwePsjVl4w7gC1WOk2i4gCs5B5KqtFFLxAQAAYUyp+Cht7Y/+LPSvhTYCXnH++TLjwaWS3z888fG2iM3jS+Za4WLmbz8V2MB4g5qW+/3c/b5lqgsAAIQxrOJDRQQqRlXpGPOYgWoPRmpxPb3zszXjnYoPAAAIo3vFR+U9Hpawo/V4rMWBdYxdzNRyU9s7fzNH08SHhb51cl//wIWtrb3x2uKlo4ovLmX8ABgtNyd4vTYdfb7VNYypLgAAEEaXqa6rLM/TnWdtdSv35zwdK1VHL18tmf5SnZ5kCg+wxePykKP3WuZWevY+ewcVHwAAEEbXxc1Ws9MrpVkpa57suFr3U/I7FCi1BX3Qxz7krvVMyV+fj65yyb6ktHQx1Ai5nRM10XlNGiyfpDUJrMU4vfM2bWCxzb1Yvr68ivCdUXpT2bNfmeoCAABhLCVZ1bIs/6SU/u7XnG7+Wtf119WHDMeXkv8Ys+JLiRjFMU7/IEZpxPiHx/iKEh8AAADLmOoCAABhkPgAAIAwSHwAAEAYJD4AACAMEh8AABAGiQ8AAAiDxAcAAIRB4gMAAMIoelfX29vb+ng8OjWln6+vr/T9/X35MhSr8aWU0ufn53fOLpxWY8ztw5SIUZn3c5E+/IkYdUU+F4sSn8fjkT4+Ptq0aqD39/esz1mNL6WUlmXJ2lLcaoy5fZgSMSrzfi7Shz8Ro67I56Ls29mBGbZvC470OpeocQOIhzU+AAAgDCo+QPpZ8UiJqgcAeEXFBwAAhEHFZ5LXCsMrKg7jRF3fcjUGrdmL56w/qfL5wjVVV+m1pndfkfgMUPMF8/wZTtYxIh3ns/Ho/Th4S/YiWJalKIHN/Zz3sT5Li3Os980oU10AACAM+YqPxfJlSca71/7tzytWfnLj27ZZ8W4r2t2/53ivYotc5bKqtk+vrqnbf6Pv22l9fenZN1R8AABAGLIVn9J529mZe8s7ynVd//X7Zt6d1GbyZz83c0GxYvWpt5K7Yy+uYrMcu8p1T8nZsdi7pqqzMtuhtnA5h1zik/ulpJIY9PoSff6e2Sdr6ZMyR7/j9WdmJnIt/m5tv8xO8rYULkA9qE+x3nH3eqD+ZXp3mUDO52dfU/fceQDmaXbf7an5rmj1u84w1QUAAMKQqfiUZK8qGbvn/V9aVHrOWD1eKmPvjpbHXnmBqIe+quUpdtXxVatX3yiei1dTsjnHgsfZAQAAbpCo+LSonIzMdEfeTW3npUcsaOxd6bHM0130XRaOhcUFraVaLCxVP0atKhmz45z992e4M3PT83tnauLjYarIarv3kPT4GJNbrWOwsKByS719vViLO3fPntq4ZiYd3hOevUXj6ktVmOoCAABhSEx15WbxCpniLNuyfespr97H1drd56vSBXheK2feqmGWRTr+udc+xf1kIn1nWRqTVHwAAEAYUys+LSo9Mxc1W8pwz5xt7HVUvfC6c+xV5ebu5y2yVunxOjbV3+HXyl5Mr9co9XNtxqJz9bFQGk/P7/0piU9OAqE4sGcOrBHHg5f7/bS32/TRf8v9HZZYSvQjPMGlfPxbOkvqWuy+PPqp2Nw2WNkNvoTqOclUFwAACENicfOWaoaYy9q0wJWjOxPPpfYa1sftlpdYolQmPah975aSO9WeGorH4En9GkLFBwAAhDGs4rOtELRYqKaQ7fbedOvobyjE7lVpxU5l4X0rVseb+h2mIm/VaYs8jFuLMVDxAQAAYcit8ckx6+4k9/HmFhmwehYd/Q6RSo8+6+2HLblbW6hf2xX03iZkeOJjtdOP2j3y5XmzL+RW+65WhHitJzwR+gj2tHjsvufvUzXq2sNUFwAACMPEVJfCHWhJG2rfTBthF2B1uXdWHu7ArFd7RohwjLzFo6D1dUH5OlN7LZw57qj4AACAMIZVfEqzQk93IdYXwlp6dUErEWL0qrTvlO+mYUuLsWS1mmzpmjl8qsvzCy5ftZzWUn6SLSLLxyXC9M2rll9IwJGW0z4R3j83C1NdAAAgjCmLmz3fOZU+9p67iFaBUluAI63vkr2Ne6oIOlq+pR35qPgAAIAwTDzOrqomM4+4uHsPb84er/edpKe1Q1bbDQ17a32uxlTJ+RlhfPa8XpH4VDi7wB91lvIC5j2UW2MoGXOKY6LXDrmwQ/namvt3uSHO0yp+proAAEAYVHwqnGWdNYvV1LN49fapUD5ONdNQd6ooo49FyaO/yv00gsX4a/dCe/6c9al1y22v0bu6TMUHAACE0b3i42nBYy3eweWHtf5rsQDfyvmas97HSiyRKa/ZGS1KnE+jzt2miY+1LwWgFbUL1N2pgavPKbPabuy705/bZFjxjQF8Z/42+maFqS4AABDGUvg46z8ppb/7Naebv9Z1/XX1IcPxpeQ/xqz4UiJGcYzTP4hRGjH+4TG+osQHAADAMqa6AABAGCQ+AAAgDBIfAAAQBokPAAAIg8QHAACEQeIDAADCIPEBAABhFL2y4u3tbX08Hp2a0s/X11f6/v6+3BvcanwppfT5+fmdsxmV1Rhz+zAlYlTm/VykD38iRl2Rz8WixOfxeKSPj482rRro/f0963NW40sppWVZsnbWtBpjbh+mRIzKvJ+L9OFPxKgr8rnIVBcAAAiDxAcAAIRB4gMAAMIg8QEAAGEULW5GH8vy74Xn67pOaAm2nv1CXwCAHyYSn73E4EnlS6n0S/IsJmiJlADVjEu142LhegFgHqa6AABAGLIVn9w7z9l346/tXJblX22hugMlkccj08qY7er88z4eFc5BKj4AACAMuYpP6d2oYnbc4o56diULPrQ4n6xUiF7beXTu7FVpzz6v5G5fXPWv0jFQbVet0lmMLevxX8U++hyUSXxKTmjrgyCl8xhmftHkfnl4Z+XL/kjLBb7Pz1s/Jk9H8exNUytoedzVvoDO2uBFbUJe8ztUqBcwmOoCAABhSFR8vD5+yqPtx7yVsZV4LJX34KGSta5rUZVm77xTil+pLa1d9U/ONLPl81ip7VR8AABAGBIVn+1dizW1WaxivIptGq22WnJ07Gbf5cz++7hvrzKz7deSPt5ea5Uqgy0fCLli5Zywej1WGldHJBKfPWoH6i6rg7g11eNQ2y7FeFqdOzmxWT5PrU0j3Gmf5b1jahYDnxm9gLv071hIHHKptpupLgAAEIZcxUc1QxyN4zBf1D6wWOlRrLzNZGV7kNJ+a/1ov8o4Vp0qL2HpgRUqPgAAIAy5io9XJY+OqmfL3tTMqXvrR+sVk5IHJKyt7cmlvmlcrTtjU3Uncq9buFhB4iOEAa+v5KKpVErf8vpKldebi6sXBiu1vVRtH1qO+cheH5fuodbzuHh72uyIpfYz1QUAAMKg4jOYpawY9vUu66tWtV4pTG+MYqE/Umo/hXX132aNASv9EQkVHwAAEIZcxUdx7QDa2VvkPavPe98BWnpMeI/qwtASnjaDu0JcQB4qPgAAIAy5is8TlR89Xp6IqWWl2tFz8zoL/X62dcRZ+y291kGpLb20jNHDBoFoRybxKdnnBhgpZ88NS+M2ysW+9JHm3M9FOX4esF8O9jDVBQAAwpCp+GC+vd1vIy0OLaF4DLi7vWapMuddyW7brf8u7Gm51IKKDwAACEOu4mNxzYQnHP/fchdDRj9O6lr2D5WCse6sqeK89GOvL++ut5NLfKDvzm691qfOvL2c1LM770jii9MeS0/l4drey4RbnZdMdQEAgDCo+GBXyykv7p7niXiXy2PnduRcZ64qzNGrsFEeamBxMwAAQAW5ig/VAS1H86qvd9V3+m3WHXrtDr/Q1Ks6yVjob3uMSzeUzP29lvG9+Fur/pRJfKx3bGn7o52Qyk+LeemLPVGmfVqUwS0mPT3bvLe4dJbSGy31fssRIcYjvb8nmOoCAABhyFR8XlnJZmsfmbV0J15arTmLSbny49mdLQhyfveWhTGdUpxFoTVUz89I01tbHmPaM2rcUfEBAABhyFZ8rCrNzHveibe2rdZYaXMkIxdrK6+HOaumelk30evOeHa/3onLSt9dOVuv5rVKObrCKJf4WO68lNruNaDKa1xe9Oyf2V+MR3JfsHv1O9SNTHhGuhpXXr/wr+S8Ouf5GYs3pLPGHVNdAAAgjKUkQ1yW5Z+U0t/9mtPNX+u6/rr6kOH4UvIfY1Z8KRGjOMbpH8QojRj/8BhfUeIDAABgGVNdAAAgDBIfAAAQBokPAAAIg8QHAACEQeIDAADCIPEBAABhkPgAAIAwil5Z8fb2tj4ej05N6efr6yt9f39f7o1tNb6UUvr8/PzO2YzKaoy5fZgSMSrzfi7Shz8Ro67I52JR4vN4PNLHx0ebVg30/v6e9Tmr8aWU0rIsWTtrWo0xtw9TIkZl3s9F+vAnYtQV+VxkqgsAAIRB4gMAAMIg8QEAAGGQ+AAAgDCKFjfPtiy/F2irv1H+2c49Z23f+zn1WBGTlXMRwDy134W9ySY+ZwdsWRa5C+5Ze2s+50luzGp9Wuo1TqvxlIxRxXPxzDY2S+1+5SWO3lS/eD2r+S4c3RdMdQEAgDDkKj6l2eLMrL1n9Ub5bmQvU29xLNTvYi1X63q23Uuly4K9fqSq8ZvlyrKHZQ6l43D7+dHf51R8AABAGDIVn6NscZsBnt3tjMyOr+4sStpioYpwtd7Kq7uxKax/mdE/CnF7cqcPPVQSrliseOX0qeUK1lNp20Z9n0skPrknZ8tplRoWT7C7PCc2Z0qT26jHae/GxOu5oOTOtTBCP3mO7dXsJQK1yfW6rtOum0x1AQCAMCQqPqVmZoqtqFePWk7nbX/f7KrdHWcxH03Jeru7zj0Gyqy0c0/Ltnsbm0/q8ZRc+yxfL3Mcxdd7upyKDwAACEO24lN6N9I7Q7zKuK/am5Oxq9+pPN1pZ+tKEupdPThw9nnLLFc6ctZTlFYHWIyujzWFbUkkPmfP+M8+KXOnpLbtLWXlolPbTotPJ+xNp85eRFgjt51W4sFPZw+B7OELc57SY8852Q9TXQAAIAyJis+VnIWxI7NjxUfte6pdyNzr949S2qce+94j1fF2pGV7PTwY4p218XlHzv59PVDxAQAAYchVfEp3Ce6VGd7ZlKn291sWYW3T2WPqs+5cAPw0e13okdyHCc6+eyx8b9z5DhzVb8MTn1YdN3pgK55ICkoXLlt+ouZqwTNsiPQEGzT1mEafNWZbxDK67Ux1AQCAMOSmus5wFzZHbVWj9FFbbyLFasVe1S5SP1GhRC+WziMqPgAAIIxhFR9Lm9j1vCvydsdVuvtvJKqLLKOL2Cecm35ZGc8Ka3ueqPgAAIAwJNb4KGesym3roeY9P6+8HrOau2bLT7HBPuWngdDmMXUL1xi174lhiY+1fQhe373VqpPUj0PO7ti5L2x9/TkAY5Dw2MKSgbGY6gIAAGEMn+qy+q6YngtVFUuVZ20pvTvxssj36lFoi+Ma/lyNQw/nomelswKK/XnWdoX2UvEBAABhSCxutlIRaLlAS32tz5XaRdAW+vlMTiXMap8+WXzv2LbNyu3soWS8RTs2Flmu9JxRau+UxOfspY9nnx+Jhbz/Y/2LXIF64mepj0vOSdXj3YrFJBX7PCewau1lqgsAAIQxdarLQlWlZiHv2e/I+Vml7Lh1JUAxxlYsVU2eLLW5dNHusiwhKrKvIsTozV6V0tK5uZVbcW0dX8m4p+IDAADCkFjcvJepKWa7e3eUORRjuVJzd11yfKwsaM9l8RHiq+qk8rjNOZ57sURa/wNdyudWbwqxSyQ+e84uwCpfmla+JEbY64+rkq2HaS+PCY+qO220Pn3wqkUfWh0HVuVeK9Qe9Gml17lXc0yY6gIAAGHIVnysaTFdZyWbb/3eMossvgvJS7UD9RgDujxWelT3JKLiAwAAwpCv+Fieh859lM9DPF60uCNWPE5sdOfD1QMDJej7sWoe7bbYR7nv45wZm1ziQynWDyulW6/JTkpt+0DloYIaFm80WooWs7UHJ6y0M5f6wwRMdQEAgDCWkkxzWZZ/Ukp/92tON3+t6/rr6kOG40vJf4xZ8aVEjOIYp38QozRi/MNjfEWJDwAAgGVMdQEAgDBIfAAAQBgkPgAAIAwSHwAAEAaJDwAACIPEBwAAhEHiAwAAwiDxAQAAYZD4AACAMP4P6a4zeY5hWWgAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "KBuWx-FtSouR", + "outputId": "3498bba6-068c-413a-b1a9-74e98192b543", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 153 + } + }, + "source": [ + "# Filling occluded images\n", + "occlude_start_row = 14\n", + "num_generated_images = 10\n", + "samples = np.copy(x_test_quantised[0:num_generated_images, :, :, :])\n", + "samples = samples / (q_levels - 1)\n", + "samples[:, occlude_start_row:, :, :] = 0\n", + "\n", + "fig = plt.figure(figsize=(10, 10))\n", + "\n", + "for i in range(10):\n", + " ax = fig.add_subplot(1, 10, i + 1)\n", + " ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary)\n", + " plt.xticks(np.array([]))\n", + " plt.yticks(np.array([]))\n", + "\n", + "for i in range(occlude_start_row, height):\n", + " for j in range(width):\n", + " logits = gated_pixelcnn(samples)\n", + " next_sample = tf.random.categorical(logits[:, i, j, :], 1)\n", + " samples[:, i, j, 0] = (next_sample.numpy() / (q_levels - 1))[:, 0]\n", + "\n", + "fig = plt.figure(figsize=(10, 10))\n", + "\n", + "for i in range(10):\n", + " ax = fig.add_subplot(1, 10, i + 1)\n", + " ax.matshow(samples[i, :, :, 0], cmap=matplotlib.cm.binary)\n", + " plt.xticks(np.array([]))\n", + " plt.yticks(np.array([]))\n", + "plt.show()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAABECAYAAABu1lQcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAADy0lEQVR4nO3cQVLkOBAF0PTEHIFeTx2C+58A7kCvhztoFkwHBOFy2WW7UCrfW3Y4OpSWbL5Shqm1FgAAVfz10wMAAHgk4QcAKEX4AQBKEX4AgFKEHwCgFOEHAChF+AEAShF+AIBShB8AoJS/t1z89PTULpfLSUM519vbW7y/v09L12SuLyLi9fX1vbX2a+mazDWumcMINfbOs/ghc43W6afRa8xcX8T1Z3FT+LlcLvHy8nLcqB7o+fn55jWZ64uImKbp961rMte4Zg4j1Ng7z+KHzDVap59GrzFzfRHXn0XHXgBAKcIPAFDKpmMvtpum+ePU1tqDRwIAROj8AADF6Pwc5FqH59b1FTpAo9Q6N8fZawLGNE2T99MCnR8AoBSdnwNs7fpAT0bpzO2V8T58f/dkGjvb3XvC8If18Un4OcnSIvu6IDO+cNcaKRSOVMtRvt6TzOv3+9xmeCbPWI9Z5jPLOPfyzjmXYy8AoJRTOz8VPxDdWt/o9yO7W7uvDF2Ca/buLFtrdqc/4Oj3apajEWttbI/OCzo/AEApp3R+lhL6kem9lx3K2nFU3bn0Mk9b+LCwhuzd6UxjvdeanydZ78O9489a7x9r369L1+29B6eEn7Uf++6VfeGPLGvQO2rcPf+NjaOCWtY5zuqMD317/nh4y/rq+Xm75lZ92epZ0uO7wrEXAFDKw3/VfW+a7Xmnck321jrzvs5hjzubORmfnzNkma+IXGP9SVlOAqrM59o65+br+1zO/V97u306PwBAKf7I4cmqpPwR7NmpzO1QsuxEtxp1TY82T3Myd6GXOgRz/9ZbXZnv/dHu+cD76PeOzg8AUEqazk+2bxXO/BW9nmWbp7XW1tJrB+iseRlpjjPZ8ht72Tt1o66xUes6osO19K3PUboPP9kf3K9GXewjM2efRnoWM1nzg+Deuel1ffe2edjLs/Ohp7/z59gLACil+85P1rSfddx7Za575OOgI8bTW017ZKwl45jvtfWYuXdZxnmEpVp7ug86PwBAKcIPAFCK8AMAlCL8AAClCD8AQCnCDwBQivADAJQi/AAApQg/AEApwg8AUIrwAwCUIvwAAKUIPwBAKcIPAFCK8AMAlCL8AAClCD8AQCnCDwBQivADAJQi/AAApQg/AEApwg8AUIrwAwCUMrXW1l88Tf9GxO/zhnOqf1prv5YuSF5fxPg13qwvQo0JjL5OI8av0Tr93+g1Jq8v4kqNm8IPAEB2jr0AgFKEHwCgFOEHAChF+AEAShF+AIBShB8AoBThBwAoRfgBAEoRfgCAUv4Dzge+SeXbn58AAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAABECAYAAABu1lQcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAFtklEQVR4nO3dS5LbNhAAUDCVI9jrzCF8/xPYd7DXmTswC8U18hRFkRQ/3ej3li55Ck0AVKMBUsM4jg0AoIq/rm4AAMCZJD8AQCmSHwCgFMkPAFCK5AcAKEXyAwCUIvkBAEqR/AAApUh+AIBS/l7z4S9fvoxvb28HNeVYP3/+bO/v78PcZzLH11prP378eB/H8evcZzLHuKQPWxNjdObiTeYYjdMPvceYOb7WHs/FVcnP29tb+/79+36tOtG3b9+efiZzfK21NgzDr2efyRzjkj5sTYzRmYs3mWM0Tj/0HmPm+Fp7PBdtewEApUh+AIBSVm17sd4wTG+njuN4cksAgNZUfgCAYlR+dvKowvPs8xUqQL3EOtXH2WMC+jQMg/vTDJUfAKAUlZ8drK36QCS9VOZelfE6fL73ZGo7623dYfjN+Pgg+TnI3CC7H5AZb7hL9ZQU9hTLXu6vSebx+7lvM8zJI8Zjlv7M0s5Xueccy7YXAFDKoZWfigdE18bX+/XI7tnqK0OV4JFXV5bjOFqdXmDv+2qWrRFjrW9n5wsqPwBAKYdUfuYy9D2z9ygrlKXtqLpyidJPazhYWEP26nSmtm615Psk63XY2v6s8f629P4697lXr8Ehyc/Sw76vyj7we5Y10dur3ZHfsbFXopa1j7M64qBv5MPDa8ZX5Pn2yLP4ssUzJ+K9wrYXAFDK6Y+6v5rNRl6pPJK9tM60+z6MuLKZknH+HCFLf7WWq61XyrITUKU/l8Y51V+f+3Lqb71a7VP5AQBK8ZLDg1XJ8nvwykplaoWSZSW6Vq9jurd+mpK5Cj1XIZj6t2hxZb72e9tywHvv+47KDwBQSprKT7azCkc+ohdZtn5aamksUStAR/VLT32cyZon9rJX6nodY73GtUeFa+6sz17CJz/ZJ+69Xgd7z/TZh57mYiZLvgi29k3U8R1t8fAqc+cm0nv+bHsBAKWErfxsXeVkXh30Tr/FUPlt1JlX4EsP/G79W5yj4rWPOO9UfgCAUkJUfvbICiNn05Hbtpe5R06X/v8K1ymbCtW6zHGsrQZliXXq0eb7sz8RKwncHHVYee+xe1ny02u5tuqkXDPgH70ZuYeDjVE9e1Pqlr9FTGc8KXOljHFVnDN7vgH/iOtn2wsAKOXQyk+kx9q4zlzfRV+lrn1nRdQ4ftuyGutl7kXvm71Fn1vPbG1/L+O1J0v65Ow3YKv8AAClhDjw3JPqq461q7WMq9JeVqJR28U+en3b+iMVYmQ/Kj8AQCkqP+xqSVXk0Wd6W7n1Fg85ZKymftZLdZXnrvq1e8kPu5r7Yc8l/y+7XuKAK1Q7hM+fzuxX214AQClhKz8y+9yWvoU1Qj/3+sJNbnrYBnpV9BeIrumjqDHwp2dvGr96Xqr8AAClHFr5WfvYs4y+hqz9nLXdVWWpPO6p54cJeoihgq0PvZzdv6dse0UocXG9aDevil+OW2T6QrXQ6o++unk2tq+8TmvvEWvzgSOSJdteAEApoQ48Rz+Ut6dMq2nq6L1Ca37lU+l74d7W6siZ1+no7asj/77KDwBQymmVnzWHn4dh6CbL35K9R4z9ld8Jih7PlIhtvsLUdYhYHXr2WG3vIvbJGku+HyIckj3aHv0YuVIWqQ9VfgCAUk4/8/M5y3uU6UbOXqeseXLo2WezxT4l60o08zXfw5L478fx1deresVniUzXYe3rUXr55fqs98tnts7PMypElx94fvYYfNZEYK69WbYSWptu11yfRI1jqc/tzzbuzhC9j/XZTebrcN/2tYnQkrijJE2vJu9Xz8W5ZHVLbGc+CGTbCwAoZViTUQ3D8G9r7ddxzTnUP+M4fp37QPL4Wus/xqfxtSbGBHofp631H6Nx+r/eY0weX2sPYlyV/AAAZGfbCwAoRfIDAJQi+QEASpH8AAClSH4AgFIkPwBAKZIfAKAUyQ8AUIrkBwAo5T903wuZqmcetwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "W7bzeLMjPNcI", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/WIP/6-gated_pixelcnn_cropped/gated_pixelcnn_vs_cropped.ipynb b/WIP/6-gated_pixelcnn_cropped/gated_pixelcnn_vs_cropped.ipynb new file mode 100644 index 0000000..aede3e3 --- /dev/null +++ b/WIP/6-gated_pixelcnn_cropped/gated_pixelcnn_vs_cropped.ipynb @@ -0,0 +1,1135 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "gated_pixelcnn_vs_cropped.ipynb", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "HgFGN07idT26", + "colab_type": "text" + }, + "source": [ + "# Converting masked-based implementation to cropping-based implementation" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ObS7YqtCbC33", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import tensorflow as tf\n", + "import tensorflow.keras as keras\n", + "import numpy as np\n", + "import math" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "iCMR2mKLbt_l", + "colab_type": "code", + "colab": {} + }, + "source": [ + "test_ones_2d = np.ones([1, 5, 5, 1], dtype='float32')\n", + "test_ones_3d = np.ones([1, 5, 5, 5, 1], dtype='float32')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "NycU0IQZb1X1", + "colab_type": "code", + "colab": {} + }, + "source": [ + "def print_3d(matrix_3d):\n", + " for i in range(matrix_3d.shape[0]):\n", + " print(f'Depth {i}')\n", + " print(matrix_3d[i,...])" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "AeH21Zkzcrt5", + "colab_type": "code", + "outputId": "d47da908-4659-45aa-f540-e01d3e510bdc", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 527 + } + }, + "source": [ + "print_3d(test_ones_3d.squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Depth 0\n", + "[[1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]]\n", + "Depth 1\n", + "[[1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]]\n", + "Depth 2\n", + "[[1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]]\n", + "Depth 3\n", + "[[1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]]\n", + "Depth 4\n", + "[[1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "xqrqiDnbfoqW", + "colab_type": "code", + "outputId": "12aef41a-4992-464c-8d38-8fd732286c57", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 102 + } + }, + "source": [ + "print(test_ones_2d[0,:,:,0].squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[[1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]\n", + " [1. 1. 1. 1. 1.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1UcUkEj0d7wh", + "colab_type": "text" + }, + "source": [ + "## Creating 2D masked solution to check results with cropped solution later" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "83mZFyondaAT", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class MaskedConv2D(tf.keras.layers.Layer):\n", + " def __init__(self,\n", + " mask_type,\n", + " filters,\n", + " kernel_size,\n", + " strides=1,\n", + " padding='same',\n", + " kernel_initializer='glorot_uniform',\n", + " bias_initializer='zeros'):\n", + " super(MaskedConv2D, self).__init__()\n", + "\n", + " assert mask_type in {'A', 'B', 'V'}\n", + " self.mask_type = mask_type\n", + "\n", + " self.filters = filters\n", + "\n", + " if isinstance(kernel_size, int):\n", + " kernel_size = (kernel_size, kernel_size)\n", + " self.kernel_size = kernel_size\n", + "\n", + " self.strides = strides\n", + " self.padding = padding.upper()\n", + " self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)\n", + " self.bias_initializer = tf.keras.initializers.get(bias_initializer)\n", + "\n", + " def build(self, input_shape):\n", + " kernel_h, kernel_w = self.kernel_size\n", + "\n", + " self.kernel = self.add_weight(\"kernel\",\n", + " shape=(kernel_h,\n", + " kernel_w,\n", + " int(input_shape[-1]),\n", + " self.filters),\n", + " initializer=self.kernel_initializer,\n", + " trainable=True)\n", + "\n", + " self.bias = self.add_weight(\"bias\",\n", + " shape=(self.filters,),\n", + " initializer=self.bias_initializer,\n", + " trainable=True)\n", + "\n", + " mask = np.ones(self.kernel.shape, dtype=np.float32)\n", + "\n", + " if kernel_h % 2 != 0: \n", + " center_h = kernel_h // 2\n", + " else:\n", + " center_h = (kernel_h - 1) // 2\n", + "\n", + " if kernel_w % 2 != 0: \n", + " center_w = kernel_w // 2\n", + " else:\n", + " center_w = (kernel_w - 1) // 2\n", + "\n", + " if self.mask_type == 'V':\n", + " mask[center_h + 1:, :, :, :] = 0.\n", + " else:\n", + " mask[:center_h, :, :] = 0.\n", + " mask[center_h, center_w + (self.mask_type == 'B'):, :, :] = 0.\n", + " mask[center_h + 1:, :, :] = 0.\n", + "\n", + " self.mask = tf.constant(mask, dtype=tf.float32, name='mask')\n", + "\n", + " def call(self, input):\n", + " masked_kernel = tf.math.multiply(self.mask, self.kernel)\n", + " x = tf.nn.conv2d(input, masked_kernel, strides=[1, self.strides, self.strides, 1], padding=self.padding)\n", + " x = tf.nn.bias_add(x, self.bias)\n", + " return x" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nMXqIeSUkdcV", + "colab_type": "text" + }, + "source": [ + "### Tests with kernel_size 3" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M62GZQe8ixvy", + "colab_type": "text" + }, + "source": [ + "#### Vertical stack" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "kjUrpEtIg7p9", + "colab_type": "code", + "outputId": "8df4b5ee-1d0e-4f7d-9009-beaf8c8133ee", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 204 + } + }, + "source": [ + "mask_type = 'V'\n", + "kernel_size=(3, 3)\n", + "\n", + "conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=1,\n", + " kernel_size=kernel_size, \n", + " padding='same',\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "result_v = conv(test_ones_2d)\n", + "\n", + "print('MASK')\n", + "print(conv.mask.numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result_v.numpy().squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "MASK\n", + "[[1. 1. 1.]\n", + " [1. 1. 1.]\n", + " [0. 0. 0.]]\n", + "\n", + "OUTPUT\n", + "[[2. 3. 3. 3. 2.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PFqvGa439Z2o", + "colab_type": "text" + }, + "source": [ + "#### Feeding horizontal" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "oq_JTwdE9lr6", + "colab_type": "code", + "outputId": "a5eecd56-e068-4d9f-ff4b-01bd048f0bbb", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 119 + } + }, + "source": [ + "padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n", + "cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", + "\n", + "x = padding(result_v)\n", + "result = cropping(x)\n", + "\n", + "print('OUTPUT')\n", + "print(result.numpy().squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "OUTPUT\n", + "[[0. 0. 0. 0. 0.]\n", + " [2. 3. 3. 3. 2.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iHY6UE2_p5oc", + "colab_type": "text" + }, + "source": [ + "#### Horizontal stack A" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4_q_IunZkFmj", + "colab_type": "code", + "outputId": "1bfa1a88-7ae3-4a7e-d43e-46e7e683cce4", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 204 + } + }, + "source": [ + "mask_type = 'A'\n", + "kernel_size=(3, 3)\n", + "\n", + "conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=1,\n", + " kernel_size=kernel_size, \n", + " padding='same',\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "result = conv(test_ones_2d)\n", + "\n", + "print('MASK')\n", + "print(conv.mask.numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result.numpy().squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "MASK\n", + "[[0. 0. 0.]\n", + " [1. 0. 0.]\n", + " [0. 0. 0.]]\n", + "\n", + "OUTPUT\n", + "[[0. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jMuS-vgWqAWK", + "colab_type": "text" + }, + "source": [ + "#### Horizontal stack B" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "5yeB5h2tkSs_", + "colab_type": "code", + "outputId": "9e7346b9-d360-42dd-85b4-a9a22ba2bacb", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 204 + } + }, + "source": [ + "mask_type = 'B'\n", + "kernel_size=(3, 3)\n", + "\n", + "conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=1,\n", + " kernel_size=kernel_size, \n", + " padding='same',\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "result = conv(test_ones_2d)\n", + "\n", + "print('MASK')\n", + "print(conv.mask.numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result.numpy().squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "MASK\n", + "[[0. 0. 0.]\n", + " [1. 1. 0.]\n", + " [0. 0. 0.]]\n", + "\n", + "OUTPUT\n", + "[[1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1NxkQ3U1knbE", + "colab_type": "text" + }, + "source": [ + "### Tests with kernel_size 4" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WNykK-WpqMlu", + "colab_type": "text" + }, + "source": [ + "#### Vertical stack" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "p3DTiFYCk57Y", + "colab_type": "code", + "outputId": "01d4c588-acde-43ba-8b52-7b5b2fddc52a", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 221 + } + }, + "source": [ + "mask_type = 'V'\n", + "kernel_size=(4, 4)\n", + "\n", + "padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n", + "\n", + "conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=1,\n", + " kernel_size=kernel_size, \n", + " padding='same',\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", + "\n", + "x = padding(test_ones_2d)\n", + "x = conv(x)\n", + "result = cropping(x)\n", + "\n", + "print('MASK')\n", + "print(conv.mask.numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result.numpy().squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "MASK\n", + "[[1. 1. 1. 1.]\n", + " [1. 1. 1. 1.]\n", + " [0. 0. 0. 0.]\n", + " [0. 0. 0. 0.]]\n", + "\n", + "OUTPUT\n", + "[[0. 0. 0. 0. 0.]\n", + " [3. 4. 4. 3. 2.]\n", + " [6. 8. 8. 6. 4.]\n", + " [6. 8. 8. 6. 4.]\n", + " [6. 8. 8. 6. 4.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "E5jUGK3_qbT8", + "colab_type": "text" + }, + "source": [ + "#### Horizontal stack A" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "J5HSUo7Rk5xO", + "colab_type": "code", + "outputId": "d2eabc6f-1b49-43f9-f7d8-1cf13021c47b", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 221 + } + }, + "source": [ + "mask_type = 'A'\n", + "kernel_size=(4, 4)\n", + "\n", + "conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=1,\n", + " kernel_size=kernel_size, \n", + " padding='same',\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "result = conv(test_ones_2d)\n", + "\n", + "print('MASK')\n", + "print(conv.mask.numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result.numpy().squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "MASK\n", + "[[0. 0. 0. 0.]\n", + " [1. 0. 0. 0.]\n", + " [0. 0. 0. 0.]\n", + " [0. 0. 0. 0.]]\n", + "\n", + "OUTPUT\n", + "[[0. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KqORK7mLqvPP", + "colab_type": "text" + }, + "source": [ + "#### Horizontal stack B" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_2V51aerk5l1", + "colab_type": "code", + "outputId": "b54f7147-0080-48a2-a84c-d3ca557c31d9", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 221 + } + }, + "source": [ + "mask_type = 'B'\n", + "kernel_size=(4, 4)\n", + "\n", + "conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=1,\n", + " kernel_size=kernel_size, \n", + " padding='same',\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "result = conv(test_ones_2d)\n", + "\n", + "print('MASK')\n", + "print(conv.mask.numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result.numpy().squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "MASK\n", + "[[0. 0. 0. 0.]\n", + " [1. 1. 0. 0.]\n", + " [0. 0. 0. 0.]\n", + " [0. 0. 0. 0.]]\n", + "\n", + "OUTPUT\n", + "[[1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kZmui789Br2B", + "colab_type": "text" + }, + "source": [ + "## Creating 2D cropped solution" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "PxBNsvzhB1ec", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class VerticalCroppedConv2d(tf.keras.Model):\n", + " def __init__(self,\n", + " filters,\n", + " kernel_size,\n", + " kernel_initializer, \n", + " bias_initializer):\n", + " super(VerticalCroppedConv2d, self).__init__(name='')\n", + "\n", + " if isinstance(kernel_size, int):\n", + " kernel_size = (kernel_size, kernel_size)\n", + "\n", + " kernel_h, kernel_w = kernel_size\n", + "\n", + " self.padding = keras.layers.ZeroPadding2D(padding=((kernel_h-1, 0),(int((kernel_w-1)/2),int((kernel_w-1)/2))))\n", + "\n", + " self.conv = keras.layers.Conv2D(filters=filters,\n", + " kernel_size=kernel_size,\n", + " strides=1,\n", + " padding='valid',\n", + " kernel_initializer=kernel_initializer, \n", + " bias_initializer=bias_initializer)\n", + "\n", + " def call(self, input_value):\n", + "\n", + " x = self.padding(input_value)\n", + " x = self.conv(x)\n", + " out = self.cropping(x)\n", + "\n", + " return out\n", + "\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RfFuwKlrP2JU", + "colab_type": "text" + }, + "source": [ + "Example step by step" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fH3I0lfoPdcH", + "colab_type": "code", + "outputId": "f142569c-f99d-4244-ab5d-1887243d0d29", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 204 + } + }, + "source": [ + "kernel_h = 2\n", + "kernel_w = 3\n", + "\n", + "kernel_size = (kernel_h, kernel_w)\n", + "\n", + "padding = keras.layers.ZeroPadding2D(padding=((kernel_h-1, 0),(int((kernel_w-1)/2),int((kernel_w-1)/2))))\n", + "\n", + "res = padding(test_ones_2d)\n", + "print(res.numpy().squeeze())\n", + "\n", + "conv = keras.layers.Conv2D(filters=1,\n", + " kernel_size=kernel_size,\n", + " strides=1,\n", + " padding='valid',\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "res2 = conv(res)\n", + "print(res2.numpy().squeeze())\n", + "\n", + "\n", + "\n" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[[0. 0. 0. 0. 0. 0. 0.]\n", + " [0. 1. 1. 1. 1. 1. 0.]\n", + " [0. 1. 1. 1. 1. 1. 0.]\n", + " [0. 1. 1. 1. 1. 1. 0.]\n", + " [0. 1. 1. 1. 1. 1. 0.]\n", + " [0. 1. 1. 1. 1. 1. 0.]]\n", + "[[2. 3. 3. 3. 2.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "TFexYspQWEpo", + "colab_type": "code", + "outputId": "d0533bb8-38e8-482c-8e65-4f6522519087", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + } + }, + "source": [ + "conv.weights[0].numpy().squeeze()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "array([[1., 1., 1.],\n", + " [1., 1., 1.]], dtype=float32)" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 16 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "AfyRyUmTNYZ8", + "colab_type": "code", + "colab": {} + }, + "source": [ + "def build_test_croppedv_stack_2d(input_shape=(5, 5, 1), kernel_size=3):\n", + " inputs = tf.keras.layers.Input(shape=input_shape)\n", + " \n", + " x = VerticalCroppedConv2d(\n", + " filters=1,\n", + " kernel_size=kernel_size, \n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')(inputs)\n", + "\n", + " stack = tf.keras.Model(inputs=inputs, outputs=x)\n", + " stack.compile(optimizer='adam', loss='mse')\n", + " return stack" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Nskk-zJwN3Em", + "colab_type": "text" + }, + "source": [ + "###Tests with kernel_size 3" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DdcDFcMbxwpZ", + "colab_type": "text" + }, + "source": [ + "#### Vertical stack" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "jc06sHDoNzx8", + "colab_type": "code", + "outputId": "750bf826-6cf4-400b-f693-51ad6350fd15", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 187 + } + }, + "source": [ + "kernel_size=(2, 3)\n", + "kernel_h, kernel_w = kernel_size\n", + "\n", + "padding2 = keras.layers.ZeroPadding2D(padding=((kernel_h-1, 0),(int((kernel_w-1)/2),int((kernel_w-1)/2))))\n", + "conv = keras.layers.Conv2D(filters=1,\n", + " kernel_size=kernel_size,\n", + " strides=1,\n", + " padding='valid',\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "x = padding2(test_ones_2d)\n", + "result_v = conv(x)\n", + "\n", + "print('KERNEL')\n", + "print(conv.weights[0].numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result_v.numpy().squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "KERNEL\n", + "[[1. 1. 1.]\n", + " [1. 1. 1.]]\n", + "\n", + "OUTPUT\n", + "[[2. 3. 3. 3. 2.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SvWpzQFGEGGm", + "colab_type": "text" + }, + "source": [ + "#### Feeding horizontal" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "5jLEZhYtOZgi", + "colab_type": "code", + "outputId": "d293ab6f-7fa6-4aec-cc68-9ab3ae7185dd", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 119 + } + }, + "source": [ + "padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n", + "cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", + "\n", + "x = padding(result_v)\n", + "result = cropping(x)\n", + "\n", + "print('OUTPUT')\n", + "print(result.numpy().squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "OUTPUT\n", + "[[0. 0. 0. 0. 0.]\n", + " [2. 3. 3. 3. 2.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]\n", + " [4. 6. 6. 6. 4.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MQLekDEaEUUT", + "colab_type": "text" + }, + "source": [ + "#### Horizontal stack A" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bHiwKZniEk5A", + "colab_type": "code", + "outputId": "ebd659c5-d899-4d6f-9c81-5c821cf4ea61", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 170 + } + }, + "source": [ + "kernel_size=(1, 1)\n", + "kernel_h, kernel_w = kernel_size\n", + "\n", + "conv = keras.layers.Conv2D(filters=1,\n", + " kernel_size=kernel_size,\n", + " strides=1,\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "padding = keras.layers.ZeroPadding2D(padding=(0,(1,0)))\n", + "cropping = keras.layers.Cropping2D(cropping=(0, (0, 1)))\n", + "\n", + "x = conv(test_ones_2d)\n", + "x = padding(x)\n", + "result = cropping(x)\n", + "\n", + "print('KERNEL')\n", + "print(conv.weights[0].numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result.numpy().squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "KERNEL\n", + "1.0\n", + "\n", + "OUTPUT\n", + "[[0. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1.]\n", + " [0. 1. 1. 1. 1.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IvmGrDziEadf", + "colab_type": "text" + }, + "source": [ + "#### Horizontal stack B" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "pRKJFE4TFx4I", + "colab_type": "code", + "outputId": "75d34ade-0983-49a5-f157-b98975c22560", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 170 + } + }, + "source": [ + "kernel_size=(1, 2)\n", + "kernel_h, kernel_w = kernel_size\n", + "\n", + "padding2 = keras.layers.ZeroPadding2D(padding=((int((kernel_h-1)/2),int((kernel_h-1)/2)), (kernel_w-1, 0)))\n", + "conv = keras.layers.Conv2D(filters=1,\n", + " kernel_size=kernel_size,\n", + " strides=1,\n", + " padding='valid',\n", + " kernel_initializer='ones', \n", + " bias_initializer='zeros')\n", + "\n", + "\n", + "x = padding2(test_ones_2d)\n", + "result = conv(x)\n", + "\n", + "print('KERNEL')\n", + "print(conv.weights[0].numpy().squeeze())\n", + "print('')\n", + "print('OUTPUT')\n", + "print(result.numpy().squeeze())" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "KERNEL\n", + "[1. 1.]\n", + "\n", + "OUTPUT\n", + "[[1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]\n", + " [1. 2. 2. 2. 2.]]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "swYJ4XMofUWv", + "colab_type": "text" + }, + "source": [ + "REFERENCES\n", + "\n", + "https://wiki.math.uwaterloo.ca/statwiki/index.php?title=STAT946F17/Conditional_Image_Generation_with_PixelCNN_Decoders#Gated_PixelCNN\n", + "\n", + "https://www.slideshare.net/suga93/conditional-image-generation-with-pixelcnn-decoders\n", + "\n", + "https://www.youtube.com/watch?v=1BURwCCYNEI" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "jTI9ts7i7Wch", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file