From 9f8e4605130918d698f8ed2fb8bf2634dd2215c7 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 8 May 2020 18:33:15 +0100 Subject: [PATCH] WIP - Gated pixelcnn --- .../pixelCNN.py | 5 +- README.md | 9 +- .../pixelcnn_receptive_field.ipynb | 414 +++++++++++++++ .../Receptive_fields.ipynb | 251 ---------- .../gated_pixelCNN.py | 172 ++++--- .../gated_pixelcnn.ipynb | 0 .../gated_pixelcnn_receptive_field.ipynb | 472 ++++++++++++++++++ .../Gated_Receptive_fields.ipynb | 321 ------------ WIP/rascunho.py | 4 +- 9 files changed, 988 insertions(+), 660 deletions(-) create mode 100644 WIP/3 - PixelCNNs blind spot in the receptive field/pixelcnn_receptive_field.ipynb delete mode 100644 WIP/3 -PixelCNNs blind spot in the receptive field/Receptive_fields.ipynb rename WIP/{4 - Gated_PixelCNN => 4 - Gated PixelCNN}/gated_pixelCNN.py (76%) rename WIP/{4 - Gated_PixelCNN => 4 - Gated PixelCNN}/gated_pixelcnn.ipynb (100%) create mode 100644 WIP/4 - Gated PixelCNN/gated_pixelcnn_receptive_field.ipynb delete mode 100644 WIP/4 - Gated_PixelCNN/Gated_Receptive_fields.ipynb diff --git a/1 - Autoregressive Models - PixelCNN/pixelCNN.py b/1 - Autoregressive Models - PixelCNN/pixelCNN.py index 980a31a..7b63707 100644 --- a/1 - Autoregressive Models - PixelCNN/pixelCNN.py +++ b/1 - Autoregressive Models - PixelCNN/pixelCNN.py @@ -197,7 +197,7 @@ def train_step(batch_x, batch_y): with tf.GradientTape() as ae_tape: logits = pixelcnn(batch_x, training=True) - loss = compute_loss(tf.one_hot(batch_y, q_levels), logits) + loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits) gradients = ae_tape.gradient(loss, pixelcnn.trainable_variables) gradients, _ = tf.clip_by_global_norm(gradients, 1.0) @@ -218,6 +218,7 @@ def train_step(batch_x, batch_y): loss = train_step(batch_x, batch_y) progbar.add(1, values=[('loss', loss)]) + # ------------------------------------------------------------------------------------ # Test set performance test_loss = [] @@ -225,7 +226,7 @@ def train_step(batch_x, batch_y): logits = pixelcnn(batch_x, training=False) # Calculate cross-entropy (= negative log-likelihood) - loss = compute_loss(tf.one_hot(batch_y, q_levels), logits) + loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits) test_loss.append(loss) print('nll : {:} nats'.format(np.array(test_loss).mean())) diff --git a/README.md b/README.md index f2dfdf2..81fe10e 100644 --- a/README.md +++ b/README.md @@ -15,12 +15,15 @@ Clone the git repository : Python 3 with [TensorFlow 2.0+](https://www.tensorflow.org/) are the primary requirements. Install virtualenv and creat a new virtual environment: - sudo apt-get install -y python3-venv - python3 -m venv ./venv + sudo apt update + sudo apt install python3-dev python3-pip + sudo pip3 install -U virtualenv # system-wide install + virtualenv --system-site-packages -p python3 ./venv Then, install requirements - source venv/bin/activate + source ./venv/bin/activate + pip3 install --upgrade pip pip3 install -r requirements.txt

1. Autoregressive Models — PixelCNN

diff --git a/WIP/3 - PixelCNNs blind spot in the receptive field/pixelcnn_receptive_field.ipynb b/WIP/3 - PixelCNNs blind spot in the receptive field/pixelcnn_receptive_field.ipynb new file mode 100644 index 0000000..1ff689f --- /dev/null +++ b/WIP/3 - PixelCNNs blind spot in the receptive field/pixelcnn_receptive_field.ipynb @@ -0,0 +1,414 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "YqTKIYLooHsq" + }, + "source": [ + "# PixelCNN blind spots" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "wM-m3Z8CiLXU" + }, + "source": [ + "*Note: Here we are using float64 to get more precise values of the gradients and avoid false values." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "gf5wwqP3ozaN" + }, + "outputs": [], + "source": [ + "import random as rn\n", + "\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.ticker import FixedLocator\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "from tensorflow import nn\n", + "from tensorflow.keras import initializers\n", + "from tensorflow.keras.utils import Progbar\n", + "\n", + "tf.keras.backend.set_floatx('float64')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "yJ_JlzWco7ci" + }, + "outputs": [], + "source": [ + "class MaskedConv2D(keras.layers.Layer):\n", + " \"\"\"Convolutional layers with masks.\n", + "\n", + " Convolutional layers with simple implementation of masks type A and B for\n", + " autoregressive models.\n", + "\n", + " Arguments:\n", + " mask_type: one of `\"A\"` or `\"B\".`\n", + " filters: Integer, the dimensionality of the output space\n", + " (i.e. the number of output filters in the convolution).\n", + " kernel_size: An integer or tuple/list of 2 integers, specifying the\n", + " height and width of the 2D convolution window.\n", + " Can be a single integer to specify the same value for\n", + " all spatial dimensions.\n", + " strides: An integer or tuple/list of 2 integers,\n", + " specifying the strides of the convolution along the height and width.\n", + " Can be a single integer to specify the same value for\n", + " all spatial dimensions.\n", + " Specifying any stride value != 1 is incompatible with specifying\n", + " any `dilation_rate` value != 1.\n", + " padding: one of `\"valid\"` or `\"same\"` (case-insensitive).\n", + " kernel_initializer: Initializer for the `kernel` weights matrix.\n", + " bias_initializer: Initializer for the bias vector.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " mask_type,\n", + " filters,\n", + " kernel_size,\n", + " strides=1,\n", + " padding='same',\n", + " kernel_initializer='glorot_uniform',\n", + " bias_initializer='zeros'):\n", + " super(MaskedConv2D, self).__init__()\n", + "\n", + " assert mask_type in {'A', 'B'}\n", + " self.mask_type = mask_type\n", + "\n", + " self.filters = filters\n", + " self.kernel_size = kernel_size\n", + " self.strides = strides\n", + " self.padding = padding.upper()\n", + " self.kernel_initializer = initializers.get(kernel_initializer)\n", + " self.bias_initializer = initializers.get(bias_initializer)\n", + "\n", + " def build(self, input_shape):\n", + " self.kernel = self.add_weight('kernel',\n", + " shape=(self.kernel_size,\n", + " self.kernel_size,\n", + " int(input_shape[-1]),\n", + " self.filters),\n", + " initializer=self.kernel_initializer,\n", + " trainable=True)\n", + "\n", + " self.bias = self.add_weight('bias',\n", + " shape=(self.filters,),\n", + " initializer=self.bias_initializer,\n", + " trainable=True)\n", + "\n", + " center = self.kernel_size // 2\n", + "\n", + " mask = np.ones(self.kernel.shape, dtype=np.float64)\n", + " mask[center, center + (self.mask_type == 'B'):, :, :] = 0.\n", + " mask[center + 1:, :, :, :] = 0.\n", + "\n", + " self.mask = tf.constant(mask, dtype=tf.float64, name='mask')\n", + "\n", + " def call(self, input):\n", + " masked_kernel = tf.math.multiply(self.mask, self.kernel)\n", + " x = nn.conv2d(input,\n", + " masked_kernel,\n", + " strides=[1, self.strides, self.strides, 1],\n", + " padding=self.padding)\n", + " x = nn.bias_add(x, self.bias)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now, we define the residual block.\n", + "\n", + "*Note: Here we removed the ReLU activations to not mess with the gradients while we are investigating them." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class ResidualBlock(keras.Model):\n", + " \"\"\"Residual blocks that compose pixelCNN\n", + "\n", + " Blocks of layers with 3 convolutional layers and one residual connection.\n", + " Based on Figure 5 from [1] where h indicates number of filters.\n", + "\n", + " Refs:\n", + " [1] - Oord, A. V. D., Kalchbrenner, N., & Kavukcuoglu, K. (2016). Pixel recurrent\n", + " neural networks. arXiv preprint arXiv:1601.06759.\n", + " \"\"\"\n", + "\n", + " def __init__(self, h):\n", + " super(ResidualBlock, self).__init__(name='')\n", + "\n", + " self.conv2a = keras.layers.Conv2D(filters=h, kernel_size=1, strides=1)\n", + " self.conv2b = MaskedConv2D(mask_type='B', filters=h, kernel_size=3, strides=1)\n", + " self.conv2c = keras.layers.Conv2D(filters=2 * h, kernel_size=1, strides=1)\n", + "\n", + " def call(self, input_tensor):\n", + "# x = nn.relu(input_tensor)\n", + "# x = self.conv2a(x)\n", + " x = self.conv2a(input_tensor)\n", + "\n", + "# x = nn.relu(x)\n", + " x = self.conv2b(x)\n", + "\n", + "# x = nn.relu(x)\n", + " x = self.conv2c(x)\n", + "\n", + " x += input_tensor\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "height = 10\n", + "width = 10\n", + "n_channel = 1\n", + "\n", + "data = tf.random.normal((1, height, width, n_channel))\n", + "\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", + "model = tf.keras.Model(inputs=inputs, outputs=x)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "jxCLMYc-FxdJ" + }, + "outputs": [], + "source": [ + "def plot_receptive_field(model, data):\n", + " with tf.GradientTape() as tape:\n", + " tape.watch(data)\n", + " prediction = model(data)\n", + " loss = prediction[:,5,5,0]\n", + "\n", + " gradients = tape.gradient(loss, data)\n", + "\n", + " gradients = np.abs(gradients.numpy().squeeze())\n", + " gradients = (gradients > 0).astype('float64')\n", + " gradients[5, 5] = 0.5\n", + "\n", + " fig = plt.figure()\n", + " ax = fig.add_subplot(1, 1, 1)\n", + "\n", + " plt.xticks(np.arange(0, 10, step=1))\n", + " plt.yticks(np.arange(0, 10, step=1))\n", + " ax.xaxis.set_minor_locator(FixedLocator(np.arange(0.5, 10.5, step=1)))\n", + " ax.yaxis.set_minor_locator(FixedLocator(np.arange(0.5, 10.5, step=1)))\n", + " plt.grid(which=\"minor\")\n", + " plt.imshow(gradients, vmin=0, vmax=1)\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 + }, + "colab_type": "code", + "id": "0qpDtNuvo9NL", + "outputId": "926434e2-44bf-40d3-a39f-6c80ebc880b6" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOHUlEQVR4nO3db8yddX3H8fdn/WNpQUAQgpQNFhkZIREYYSjKNhADipgsywYJJprN7oE68M+MzgfGB1u2zBh9sJg0gDORP8EKiRpEmKLEZKuWUmah4KAgtCLFqPybo4DfPTinS2GlvXrOdfW+r1/er+ROz33fp59807ufc65z3ed8T6oKSe34rYUeQFK/LLXUGEstNcZSS42x1FJjlg4RujyvqhWs6jXzsGMO4VePPd1rprnDZZo7XCbA//AsO+u57Ol7g5R6Bav4w5zXa+aff+RCbvjbb/aaae5wmeYOlwmwvr79it/z8FtqjKWWGmOppcZYaqkxllpqjKWWGtOp1EkuSHJ/kgeSfHzooSTNbp+lTrIE+BfgQuBk4NIkJw89mKTZdLmnPhN4oKq2VtVO4HrgXcOOJWlWXUp9LPDobp9vm37tJZKsSbIhyYbnea6v+STtp95OlFXV2qo6o6rOWMar+oqVtJ+6lHo7cNxun6+efk3SItSl1D8ETkxyQpLlwCXA14YdS9Ks9vkqrap6IckHgG8BS4Crq+qewSeTNJNOL72sqpuBmweeRVIPfEaZ1BhLLTXGUkuNsdRSYyy11Jj0+V5aSd4JvPPIw1/7vn/45D/2lgtw+OpD+eW2J3vNNHe4THOHywT4yEc/ylP1iz1uE+211Lu8Oq+p3reJ/vNAGyTNHdWsY8sdatb19e1XLLWH31JjLLXUGEstNcZSS42x1FJjLLXUmC6LB69OsiPJ5gMxkKT5dLmn/lfggoHnkNSTfZa6qu4AfnEAZpHUg97enzrJGmANwApW9hUraT+5TVRqjGe/pcZYaqkxXX6ldR3w78BJSbYl+cvhx5I0qy4rgi89EINI6oeH31JjLLXUGEstNcZSS42x1FJj3CZq7qhmHVuu20T3YkwbJMeWO6ZZx5brNlFJc7PUUmMstdQYSy01xlJLjbHUUmO6vPTyuCS3J7k3yT1JLj8Qg0maTZcdZS8AH6mqjUkOAe5McltV3TvwbJJm0GWb6GNVtXF6+WlgC3Ds0INJms1+bRNNcjxwGrB+D99zm6i0CHQ+UZbkYOCrwBVV9dTLv+82UWlx6FTqJMuYFPqaqrpx2JEkzaPL2e8AVwFbquqzw48kaR5d7qnPBt4NnJtk0/Tj7QPPJWlGXbaJfh/Y40u8JC0+PqNMaoyllhpjqaXGWGqpMZZaaozbRM0d1axjy3Wb6F6MaYPk2HLHNOvYct0mKmlullpqjKWWGmOppcZYaqkxXV56uSLJD5LcPV08+OkDMZik2XRZZ/QccG5VPTNdlvD9JN+sqv8YeDZJM+jy0ssCnpl+umz60f8vtyX1ous6oyVJNgE7gNuqao+LB5NsSLLheZ7re05JHXUqdVW9WFWnAquBM5OcsofruHhQWgT26+x3Vf0KuB24YJhxJM2ry9nv1yY5bHr5IOB84L6hB5M0my5nv48BvpRkCZMbgRuq6hvDjiVpVl3Ofv8nk3flkDQCPqNMaoyllhpjqaXGWGqpMZZaaoyLB0eWe8zvHcLBBz3ea+Yzvz6698yhcx/78dO957p4cC9cPDhc7idv/SPOOeXzvWbesfny3jOHzv37t32v91wXD0palCy11BhLLTXGUkuNsdRSYyy11JjOpZ6uNLoriS+7lBax/bmnvhzYMtQgkvrRdfHgauAdwJXDjiNpXl3vqT8HfAz4zStdwW2i0uLQZUfZRcCOqrpzb9dzm6i0OHS5pz4buDjJw8D1wLlJvjzoVJJmts9SV9Unqmp1VR0PXAJ8p6ouG3wySTPx99RSY7qsCP4/VfVd4LuDTCKpF95TS42x1FJjLLXUGEstNcZSS41xm+jIcofaJvrQr1/xGcAzO3rZcp7c+svec8f0M3Ob6F6Mbevn2LaJvvee/p+v/+FjVvO1S9b1njumn5nbRCXNzVJLjbHUUmMstdQYSy01xlJLjen0Kq3pgoSngReBF6rqjCGHkjS7/Xnp5Z9U1c8Hm0RSLzz8lhrTtdQF3JrkziRr9nQFt4lKi0PXw+83V9X2JEcBtyW5r6ru2P0KVbUWWAuTp4n2PKekjjrdU1fV9umfO4CbgDOHHErS7Lrs/V6V5JBdl4G3AZuHHkzSbLocfh8N3JRk1/WvrapbBp1K0sz2Weqq2gq84QDMIqkH/kpLaoyllhpjqaXGWGqpMZZaaozbRM0d1axjy3Wb6F6MaYPk2HLHNOvYct0mKmlullpqjKWWGmOppcZYaqkxllpqTKdSJzksybok9yXZkuSNQw8maTZd1xl9Hrilqv4syXJg5YAzSZrDPkud5FDgHOA9AFW1E9g57FiSZtXl8PsE4Angi0nuSnLldK3RS7hNVFocupR6KXA68IWqOg14Fvj4y69UVWur6oyqOmMZr+p5TElddSn1NmBbVa2ffr6OScklLUL7LHVV/Qx4NMlJ0y+dB9w76FSSZtb17PcHgWumZ763Au8dbiRJ8+hU6qraBPhOl9II+IwyqTGWWmqMpZYaY6mlxlhqqTFuEzV3VLOOLddtonsxpg2SY8sd06xjy3WbqKS5WWqpMZZaaoyllhpjqaXG7LPUSU5Ksmm3j6eSXHEghpO0//b5Kq2quh84FSDJEmA7cNPAc0ma0f4efp8HPFhVPxliGEnz67okYZdLgOv29I0ka4A1ACvcICwtmM731NOtJxcDX9nT9108KC0O+3P4fSGwsaoeH2oYSfPbn1JfyiscektaPLq+l9Yq4HzgxmHHkTSvrosHnwWOGHgWST3wGWVSYyy11BhLLTXGUkuNsdRSY1w8aO6oZh1brosH92JMy+bGljumWceW6+JBSXOz1FJjLLXUGEstNcZSS42x1FJjur708kNJ7kmyOcl1SVYMPZik2XRZEXws8DfAGVV1CrCEya4ySYtQ18PvpcBBSZYCK4GfDjeSpHnss9RVtR34DPAI8BjwZFXd+vLrJVmTZEOSDc/zXP+TSuqky+H34cC7gBOA1wGrklz28uu5TVRaHLocfr8VeKiqnqiq55nsKXvTsGNJmlWXUj8CnJVkZZIweZeOLcOOJWlWXR5TrwfWARuBH03/ztqB55I0o67bRD8FfGrgWST1wGeUSY2x1FJjLLXUGEstNcZSS41xm6i5o5p1bLluE92LMW2QHFvumGYdW67bRCXNzVJLjbHUUmMstdQYSy01xlJLjem6TfTy6SbRe5JcMfRQkmbXZZ3RKcD7gDOBNwAXJXn90INJmk2Xe+rfB9ZX1X9X1QvA94A/HXYsSbPqUurNwFuSHJFkJfB24LiXX8ltotLisM/NJ1W1Jck/AbcCzwKbgBf3cL21TNccvTqv6f+5p5I66XSirKquqqo/qKpzgF8CPx52LEmz6rSjLMlRVbUjyW8zeTx91rBjSZpVp1IDX01yBPA88P6q+tWAM0maQ9dtom8ZehBJ/fAZZVJjLLXUGEstNcZSS42x1FJjBtkmCvwF8F8d/sqRwM87xh8KdF3LaO64Zh1b7mKY9cSqOnSP36mqBfsANuzHddea2z13TLOOLXexzzqmw++vmztY7phmHVvuAZ91NKWuqkH+ccwd16xjy12IWRe61EO9eb2545p1bLmLetZB3qFD0sJZ6HtqST2z1FJjFqzUSS5Icn+SB5J8vKfMq5PsSLK5j7xp5nFJbk9y73Sb6uU95a5I8oMkd09zP91H7m75S5LcleQbPWY+nORHSTYl2dBT5mFJ1iW5L8mWJG/sIfOk6Yy7Pp7qawtukg9Nf16bk1yXZEVPuf1t7O36e7E+P4AlwIPA7wLLgbuBk3vIPQc4Hdjc46zHAKdPLx/CZOtLH7MGOHh6eRmwHjirx7k/DFwLfKPHzIeBI3v+v/Al4K+ml5cDhw3wf+1nwO/0kHUs8BBw0PTzG4D39JB7CpNdgCuZvBz634DXz5q3UPfUZwIPVNXWqtoJXA+8a97QqroD+MW8OS/LfKyqNk4vPw1sYfLDnTe3quqZ6afLph+9nLVMshp4B3BlH3lDSXIokxviqwCqamf1v4DjPODBqvpJT3lLgYOSLGVSwp/2kNnrxt6FKvWxwKO7fb6NHooytCTHA6cxuVftI29Jkk3ADuC2quolF/gc8DHgNz3l7VLArUnuTLKmh7wTgCeAL04fKlyZZFUPubu7BLiuj6Cq2g58BngEeAx4sqpu7SG608berjxR1lGSg4GvAldU1VN9ZFbVi1V1KrAaOHP6xglzSXIRsKOq7px7wP/vzVV1OnAh8P4k58yZt5TJw6UvVNVpTLbV9nJ+BSDJcuBi4Cs95R3O5IjyBOB1wKokl82bW1VbgF0be2/hFTb2drVQpd7OS2+JVk+/tiglWcak0NdU1Y19508POW8HLugh7mzg4iQPM3lYc26SL/eQu+ueiqraAdzE5GHUPLYB23Y7QlnHpOR9uRDYWFWP95T3VuChqnqiqp4HbgTe1Edw9bixd6FK/UPgxCQnTG9NLwG+tkCz7FWSMHnMt6WqPttj7muTHDa9fBBwPnDfvLlV9YmqWl1VxzP5d/1OVc19b5JkVZJDdl0G3sbksHGeWX8GPJrkpOmXzgPunWvQl7qUng69px4Bzkqycvr/4jwm51jmluSo6Z+7NvZeO2tW122ivaqqF5J8APgWk7OTV1fVPfPmJrkO+GPgyCTbgE9V1VVzxp4NvBv40fTxL8DfVdXNc+YeA3wpyRImN643VFVvv34awNHATZP/yywFrq2qW3rI/SBwzfTGfSvw3h4yd93wnA/8dR95AFW1Psk6YCPwAnAX/T1ltLeNvT5NVGqMJ8qkxlhqqTGWWmqMpZYaY6mlxlhqqTGWWmrM/wJ0im1A4OfnGwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_receptive_field(model, data)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "I2APaCzDGeqP" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOVElEQVR4nO3df6zV9X3H8edrF5AfUrD+IJTrJksdmSHxxwiztaWbVAOt1WRZNkxsUrOV/dF22No1uv5h+seWLWua9o+lCRE7kyLGoiTOWMW1tqTJRguIE7jYKVq9iF6dPwDr+GHf++N87wLswv3ccz4f7vl+8nokN5x77+GVd+C87vme7z3nfRQRmFk9fmuyBzCzvFxqs8q41GaVcanNKuNSm1VmSonQaTonpjMra+bc+bN5+8ChrJnOLZfp3HKZAP/DuxyNIxrre0VKPZ1Z/KGWZ838s9tX8sDf/DBrpnPLZTq3XCbA1vjRab/nw2+zyrjUZpVxqc0q41KbVcalNquMS21WmaRSS1oh6VlJz0m6o/RQZta9cUstaQD4Z2AlcBlws6TLSg9mZt1JuadeCjwXEfsi4ihwP3BT2bHMrFsppV4AvHzC58PN104iabWkbZK2HeNIrvnMbIKynSiLiLURsSQilkzlnFyxZjZBKaXeD1x8wueDzdfMrA+llPoXwKWSFkqaBqwCHi47lpl1a9xXaUXEcUlfBB4HBoB7ImJ38cnMrCtJL72MiEeBRwvPYmYZ+BllZpVxqc0q41KbVcalNquMS21WGeV8Ly1JnwE+c8F5F37+77/+D9lyAc4bnMNbw+9kzXRuuUznlssEuP2rX+VgvDnmNtGspR71AX0wsm8T/adCGySd26pZ25Zbatat8aPTltqH32aVcanNKuNSm1XGpTarjEttVhmX2qwyKYsH75E0ImnX2RjIzHqTck/9L8CKwnOYWSbjljoitgBvnoVZzCyDbO9PLWk1sBpgOjNzxZrZBHmbqFllfPbbrDIutVllUn6ltQH4d2CRpGFJf1F+LDPrVsqK4JvPxiBmlocPv80q41KbVcalNquMS21WGZfarDLeJurcVs3atlxvEz2DNm2QbFtum2ZtW663iZpZz1xqs8q41GaVcanNKuNSm1XGpTarTMpLLy+W9KSkPZJ2S1pzNgYzs+6k7Cg7DtweETskzQa2S3oiIvYUns3MupCyTfRAROxoLh8ChoAFpQczs+5MaJuopEuAK4GtY3zP20TN+kDyiTJJ5wIPArdFxMFTv+9tomb9IanUkqbSKfT6iHio7Ehm1ouUs98C1gFDEfGt8iOZWS9S7qmvAT4LXCtpZ/PxqcJzmVmXUraJ/gwY8yVeZtZ//Iwys8q41GaVcanNKuNSm1Um2/tT28kuvfzXPP7Kzuy5W3Z9Intuicw25q5bvzJ75mTwNtFCufN/bzbnzngte+7h9+Zlzy2R2cbcN/57vreJno63icLXN3+CZYu/kz13y6412XNLZLYxd936O71N1Mz6j0ttVhmX2qwyLrVZZVxqs8qkvPRyuqSfS3q6WTz4jbMxmJl1J+XJJ0eAayPicLMs4WeSfhgR/1F4NjPrQspLLwM43Hw6tfnI/8ttM8sidZ3RgKSdwAjwRESMuXhQ0jZJ245xJPecZpYoqdQR8X5EXAEMAkslLR7jOl48aNYHJnT2OyLeBp4EVpQZx8x6lXL2+0JJc5vLM4DrgL2lBzOz7qSc/Z4P3CtpgM4PgQci4pGyY5lZt1LOfv8nnXflMLMW8DPKzCrjUptVxqU2q4xLbVYZl9qsMl486MWDrVsQePi9eRz45aHsuSVuC148eAZePNjhxYOd3L+7/qfZc0vcFrx40Mx65lKbVcalNquMS21WGZfarDIutVllkkvdrDR6SpJfdmnWxyZyT70GGCo1iJnlkbp4cBD4NHB32XHMrFep99TfBr4G/OZ0V/A2UbP+kLKj7AZgJCK2n+l63iZq1h9S7qmvAW6U9CJwP3CtpO8XncrMujZuqSPizogYjIhLgFXAjyPiluKTmVlX/Htqs8qkrAj+PxHxE+AnRSYxsyx8T21WGZfarDIutVllXGqzyrjUZpXxNlFvE+Xwe/N44b3TPgO4a/OmTuOdfW9lz23TbcHbRM/A20Q7Sm0TvXV3/ufrf2X+IA+v2pg9t023BW8TNbOeudRmlXGpzSrjUptVxqU2q4xLbVaZpFdpNQsSDgHvA8cjYknJocysexN56eUfR8QbxSYxsyx8+G1WmdRSB7BZ0nZJq8e6greJmvWH1MPvj0XEfkkXAU9I2hsRW068QkSsBdZC52mimec0s0RJ99QRsb/5cwTYBCwtOZSZdS9l7/csSbNHLwPXA7tKD2Zm3Uk5/J4HbJI0ev37IuKxolOZWdfGLXVE7AMuPwuzmFkG/pWWWWVcarPKuNRmlXGpzSrjUptVxttEnduqWduW622iZ9CmDZJty23TrG3L9TZRM+uZS21WGZfarDIutVllXGqzyrjUZpVJKrWkuZI2StoraUjSR0oPZmbdSV1n9B3gsYj4U0nTgJkFZzKzHoxbaklzgGXA5wAi4ihwtOxYZtatlMPvhcDrwPckPSXp7mat0Um8TdSsP6SUegpwFfDdiLgSeBe449QrRcTaiFgSEUumck7mMc0sVUqph4HhiNjafL6RTsnNrA+NW+qIeBV4WdKi5kvLgT1FpzKzrqWe/f4SsL45870PuLXcSGbWi6RSR8ROwO90adYCfkaZWWVcarPKuNRmlXGpzSrjUptVxttEnduqWduW622iZ9CmDZJty23TrG3L9TZRM+uZS21WGZfarDIutVllXGqzyoxbakmLJO084eOgpNvOxnBmNnHjvkorIp4FrgCQNADsBzYVnsvMujTRw+/lwPMR8asSw5hZ71KXJIxaBWwY6xuSVgOrAaZ7g7DZpEm+p262ntwI/GCs73vxoFl/mMjh90pgR0S8VmoYM+vdREp9M6c59Daz/pH6XlqzgOuAh8qOY2a9Sl08+C5wfuFZzCwDP6PMrDIutVllXGqzyrjUZpVxqc0q48WDzm3VrG3L9eLBM2jTsrm25bZp1rblevGgmfXMpTarjEttVhmX2qwyLrVZZVxqs8qkvvTyy5J2S9olaYOk6aUHM7PupKwIXgD8NbAkIhYDA3R2lZlZH0o9/J4CzJA0BZgJvFJuJDPrxbiljoj9wDeBl4ADwDsRsfnU60laLWmbpG3HOJJ/UjNLknL4fR5wE7AQ+BAwS9Itp17P20TN+kPK4fcngRci4vWIOEZnT9lHy45lZt1KKfVLwNWSZkoSnXfpGCo7lpl1K+Ux9VZgI7ADeKb5O2sLz2VmXUrdJnoXcFfhWcwsAz+jzKwyLrVZZVxqs8q41GaVcanNKuNtos5t1axty/U20TNo0wbJtuW2ada25XqbqJn1zKU2q4xLbVYZl9qsMi61WWVcarPKpG4TXdNsEt0t6bbSQ5lZ91LWGS0GPg8sBS4HbpD04dKDmVl3Uu6pfx/YGhG/jojjwE+BPyk7lpl1K6XUu4CPSzpf0kzgU8DFp17J20TN+sO4m08iYkjSPwKbgXeBncD7Y1xvLc2aow/og/mfe2pmSZJOlEXEuoj4g4hYBrwF/LLsWGbWraQdZZIuiogRSb9N5/H01WXHMrNuJZUaeFDS+cAx4AsR8XbBmcysB6nbRD9eehAzy8PPKDOrjEttVhmX2qwyLrVZZVxqs8oU2SYK/DnwXwl/5QLgjcT4OUDqWkbntmvWtuX2w6yXRsScMb8TEZP2AWybwHXXOjc9t02zti2332dt0+H3vzq3WG6bZm1b7lmftTWljogi/zjObdesbcudjFknu9Sl3rzeue2atW25fT1rkXfoMLPJM9n31GaWmUttVplJK7WkFZKelfScpDsyZd4jaUTSrhx5TebFkp6UtKfZpromU+50ST+X9HST+40cuSfkD0h6StIjGTNflPSMpJ2StmXKnCtpo6S9koYkfSRD5qJmxtGPg7m24Er6cvP/tUvSBknTM+Xm29ib+nuxnB/AAPA88LvANOBp4LIMucuAq4BdGWedD1zVXJ5NZ+tLjlkFnNtcngpsBa7OOPdXgPuARzJmvghckPm2cC/wl83lacDcAre1V4HfyZC1AHgBmNF8/gDwuQy5i+nsApxJ5+XQ/wZ8uNu8ybqnXgo8FxH7IuIocD9wU6+hEbEFeLPXnFMyD0TEjubyIWCIzn9ur7kREYebT6c2H1nOWkoaBD4N3J0jrxRJc+j8IF4HEBFHI/8CjuXA8xHxq0x5U4AZkqbQKeErGTKzbuydrFIvAF4+4fNhMhSlNEmXAFfSuVfNkTcgaScwAjwREVlygW8DXwN+kylvVACbJW2XtDpD3kLgdeB7zUOFuyXNypB7olXAhhxBEbEf+CbwEnAAeCciNmeITtrYm8onyhJJOhd4ELgtIg7myIyI9yPiCmAQWNq8cUJPJN0AjETE9p4H/P8+FhFXASuBL0ha1mPeFDoPl74bEVfS2Vab5fwKgKRpwI3ADzLlnUfniHIh8CFglqRbes2NiCFgdGPvY5xmY2+qySr1fk7+STTYfK0vSZpKp9DrI+Kh3PnNIeeTwIoMcdcAN0p6kc7DmmslfT9D7ug9FRExAmyi8zCqF8PA8AlHKBvplDyXlcCOiHgtU94ngRci4vWIOAY8BHw0R3Bk3Ng7WaX+BXCppIXNT9NVwMOTNMsZSRKdx3xDEfGtjLkXSprbXJ4BXAfs7TU3Iu6MiMGIuITOv+uPI6LnexNJsyTNHr0MXE/nsLGXWV8FXpa0qPnScmBPT4Oe7GYyHXo3XgKuljSzuV0sp3OOpWeSLmr+HN3Ye1+3WanbRLOKiOOSvgg8Tufs5D0RsbvXXEkbgD8CLpA0DNwVEet6jL0G+CzwTPP4F+BvI+LRHnPnA/dKGqDzw/WBiMj266cC5gGbOrdlpgD3RcRjGXK/BKxvfrjvA27NkDn6g+c64K9y5AFExFZJG4EdwHHgKfI9ZTTbxl4/TdSsMj5RZlYZl9qsMi61WWVcarPKuNRmlXGpzSrjUptV5n8BFgeDlsYBZaUAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", + "x = ResidualBlock(h=1)(x)\n", + "\n", + "model = tf.keras.Model(inputs=inputs, outputs=x)\n", + "\n", + "plot_receptive_field(model, data)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOaElEQVR4nO3df6zV9X3H8edrF5AfUrCihHLdZKkhMyRVR5itLdmkEmitJsvSYWaTmq7sj7bD1q7R7Q/TP9yPtGnaP5YmROxMihiLkjhjLa61kiYbLSDOCxetolUois4fgHP8sO/9cb53QXbhfu45nw/3fD99PZIbzr338Mo7cF73fM/3nvM+igjMrB6/M9EDmFleLrVZZVxqs8q41GaVcanNKjOpROgUnRNTmZE1c/a8mbx54HDWTOeWy3RuuUyA/+FtjsVRjfa9IqWeygz+SMuyZn76lpXc9zc/zJrp3HKZzi2XCbA1fnza7/nw26wyLrVZZVxqs8q41GaVcanNKuNSm1UmqdSSVkh6WtKzkm4tPZSZdW/MUksaAP4ZWAlcCtwg6dLSg5lZd1LuqZcAz0bE3og4BtwLXF92LDPrVkqp5wMvnfT5vuZr7yFptaRtkrYd52iu+cxsnLKdKIuItRGxOCIWT+acXLFmNk4ppd4PXHTS54PN18ysD6WU+hfAJZIWSJoCrAIeLDuWmXVrzFdpRcQJSV8EfgQMAHdFxK7ik5lZV5JeehkRDwMPF57FzDLwM8rMKuNSm1XGpTarjEttVhmX2qwyyvleWpI+BXxqznkXfP7v/+4fs+UCnDc4izf2vZU107nlMp1bLhPglq9+lUPx+qjbRLOWesT79P7Ivk30G4U2SDq3VbO2LbfUrFvjx6cttQ+/zSrjUptVxqU2q4xLbVYZl9qsMi61WWVSFg/eJemgpKGzMZCZ9SblnvpfgBWF5zCzTMYsdURsAV4/C7OYWQbZ3p9a0mpgNcBUpueKNbNx8jZRs8r47LdZZVxqs8qk/EprA/DvwEJJ+yR9rvxYZtatlBXBN5yNQcwsDx9+m1XGpTarjEttVhmX2qwyLrVZZbxNdHAWc84/kD33yDtzOXfaK63IbdOsbcs98s5cDjxzOGsmeJvomLmf+4t/yJ67ZWgNSxd9pxW5bZq1bblbhtZwx/LHs2aCt4ma/VZxqc0q41KbVcalNquMS21WGZfarDIpL728SNJjknZL2iVpzdkYzMy6k7Kj7ARwS0TskDQT2C7p0YjYXXg2M+tCyjbRAxGxo7l8GBgG5pcezMy6M65topIuBi4Hto7yPW8TNesDySfKJJ0L3A/cHBGHTv2+t4ma9YekUkuaTKfQ6yPigbIjmVkvUs5+C1gHDEfEt8qPZGa9SLmnvgr4DHC1pJ3NxycKz2VmXUrZJvozYNSXeJlZ//Ezyswq41KbVcalNquMS21WmWzvT23Wdr98cjp3LL8sa+anv3H2n13pbaLeJtqqWUvmvvZf87Lfxkrdbr1NdIxcbxNtz6wlc9etvy37bazU7dbbRM1+i7jUZpVxqc0q41KbVcalNqtMyksvp0r6uaQnm8WDXz8bg5lZd1KefHIUuDoijjTLEn4m6YcR8R+FZzOzLqS89DKAI82nk5uP/L/cNrMsUtcZDUjaCRwEHo2IURcPStomadtxjuae08wSJZU6It6NiMuAQWCJpEWjXMeLB836wLjOfkfEm8BjwIoy45hZr1LOfl8gaXZzeRpwDbCn9GBm1p2Us9/zgLslDdD5IXBfRDxUdiwz61bK2e//pPOuHGbWAn5GmVllXGqzyrjUZpVxqc0q41KbVcaLB714sFWzjuQeeOZw9twStzEvHjwDLx4sl9umWUdy71j+ePbcErcxLx40s5651GaVcanNKuNSm1XGpTarjEttVpnkUjcrjZ6Q5JddmvWx8dxTrwGGSw1iZnmkLh4cBD4J3Fl2HDPrVeo99beBrwG/Od0VvE3UrD+k7Ci7FjgYEdvPdD1vEzXrDyn31FcB10l6AbgXuFrS94tOZWZdG7PUEXFbRAxGxMXAKuAnEXFj8cnMrCv+PbVZZVJWBP+fiPgp8NMik5hZFr6nNquMS21WGZfarDIutVllXGqzynibqLeJcuSduTz/zmmfAdy1uZOn8NbeN7LnlrwteJvoaXibaLs2dG4ZWsNNu/I/X/8r8wZ5cNXG7LklbwveJmpmfcelNquMS21WGZfarDIutVllXGqzyiS9SqtZkHAYeBc4ERGLSw5lZt0bz0sv/yQiXis2iZll4cNvs8qkljqAzZK2S1o92hW8TdSsP6Qefn80IvZLuhB4VNKeiNhy8hUiYi2wFjpPE808p5klSrqnjoj9zZ8HgU3AkpJDmVn3UvZ+z5A0c+QysBwYKj2YmXUn5fB7LrBJ0sj174mIR4pOZWZdG7PUEbEX+NBZmMXMMvCvtMwq41KbVcalNquMS21WGZfarDLeJurcVs3atlxvEz2DNm2QbFtum2ZtW663iZpZz1xqs8q41GaVcanNKuNSm1XGpTarTFKpJc2WtFHSHknDkj5cejAz607qOqPvAI9ExJ9JmgJMLziTmfVgzFJLmgUsBT4LEBHHgGNlxzKzbqUcfi8AXgW+J+kJSXc2a43ew9tEzfpDSqknAVcA342Iy4G3gVtPvVJErI2IxRGxeDLnZB7TzFKllHofsC8itjafb6RTcjPrQ2OWOiJeBl6StLD50jJgd9GpzKxrqWe/vwSsb8587wVuKjeSmfUiqdQRsRPwO12atYCfUWZWGZfarDIutVllXGqzyrjUZpXxNlHntmrWtuV6m+gZtGmDZNty2zRr23K9TdTMeuZSm1XGpTarjEttVhmX2qwyY5Za0kJJO0/6OCTp5rMxnJmN35iv0oqIp4HLACQNAPuBTYXnMrMujffwexnwXET8qsQwZta71CUJI1YBG0b7hqTVwGqAqd4gbDZhku+pm60n1wE/GO37Xjxo1h/Gc/i9EtgREa+UGsbMejeeUt/AaQ69zax/pL6X1gzgGuCBsuOYWa9SFw++DZxfeBYzy8DPKDOrjEttVhmX2qwyLrVZZVxqs8p48aBzWzVr23K9ePAM2rRsrm25bZq1bblePGhmPXOpzSrjUptVxqU2q4xLbVYZl9qsMqkvvfyypF2ShiRtkDS19GBm1p2UFcHzgb8GFkfEImCAzq4yM+tDqYffk4BpkiYB04FflxvJzHoxZqkjYj/wTeBF4ADwVkRsPvV6klZL2iZp23GO5p/UzJKkHH6fB1wPLAA+AMyQdOOp1/M2UbP+kHL4/XHg+Yh4NSKO09lT9pGyY5lZt1JK/SJwpaTpkkTnXTqGy45lZt1KeUy9FdgI7ACeav7O2sJzmVmXUreJ3g7cXngWM8vAzygzq4xLbVYZl9qsMi61WWVcarPKeJuoc1s1a9tyvU30DNq0QbJtuW2atW253iZqZj1zqc0q41KbVcalNquMS21WGZfarDKp20TXNJtEd0m6ufRQZta9lHVGi4DPA0uADwHXSvpg6cHMrDsp99R/AGyNiP+OiBPA48Cflh3LzLqVUuoh4GOSzpc0HfgEcNGpV/I2UbP+MObmk4gYlvRPwGbgbWAn8O4o11tLs+bofXp//ueemlmSpBNlEbEuIv4wIpYCbwDPlB3LzLqVtKNM0oURcVDS79J5PH1l2bHMrFtJpQbul3Q+cBz4QkS8WXAmM+tB6jbRj5UexMzy8DPKzCrjUptVxqU2q4xLbVYZl9qsMkW2iQJ/Dvwy4a/MAV5LjJ8FpK5ldG67Zm1bbj/MeklEzBr1OxExYR/AtnFcd61z03PbNGvbcvt91jYdfv+rc4vltmnWtuWe9VlbU+qIKPKP49x2zdq23ImYdaJLXerN653brlnbltvXsxZ5hw4zmzgTfU9tZpm51GaVmbBSS1oh6WlJz0q6NVPmXZIOShrKkddkXiTpMUm7m22qazLlTpX0c0lPNrlfz5F7Uv6ApCckPZQx8wVJT0naKWlbpszZkjZK2iNpWNKHM2QubGYc+TiUawuupC83/19DkjZImpopN9/G3tTfi+X8AAaA54DfB6YATwKXZshdClwBDGWcdR5wRXN5Jp2tLzlmFXBuc3kysBW4MuPcXwHuAR7KmPkCMCfzbeFu4C+by1OA2QVuay8Dv5chaz7wPDCt+fw+4LMZchfR2QU4nc7Lof8N+GC3eRN1T70EeDYi9kbEMeBe4PpeQyNiC/B6rzmnZB6IiB3N5cPAMJ3/3F5zIyKONJ9Obj6ynLWUNAh8ErgzR14pkmbR+UG8DiAijkX+BRzLgOci4leZ8iYB0yRNolPCX2fIzLqxd6JKPR946aTP95GhKKVJuhi4nM69ao68AUk7gYPAoxGRJRf4NvA14DeZ8kYEsFnSdkmrM+QtAF4Fvtc8VLhT0owMuSdbBWzIERQR+4FvAi8CB4C3ImJzhuikjb2pfKIskaRzgfuBmyPiUI7MiHg3Ii4DBoElzRsn9ETStcDBiNje84D/30cj4gpgJfAFSUt7zJtE5+HSdyPicjrbarOcXwGQNAW4DvhBprzz6BxRLgA+AMyQdGOvuRExDIxs7H2E02zsTTVRpd7Pe38SDTZf60uSJtMp9PqIeCB3fnPI+RiwIkPcVcB1kl6g87Dmaknfz5A7ck9FRBwENtF5GNWLfcC+k45QNtIpeS4rgR0R8UqmvI8Dz0fEqxFxHHgA+EiO4Mi4sXeiSv0L4BJJC5qfpquABydoljOSJDqP+YYj4lsZcy+QNLu5PA24BtjTa25E3BYRgxFxMZ1/159ERM/3JpJmSJo5chlYTuewsZdZXwZekrSw+dIyYHdPg77XDWQ69G68CFwpaXpzu1hG5xxLzyRd2Pw5srH3nm6zUreJZhURJyR9EfgRnbOTd0XErl5zJW0A/hiYI2kfcHtErOsx9irgM8BTzeNfgL+NiId7zJ0H3C1pgM4P1/siItuvnwqYC2zq3JaZBNwTEY9kyP0SsL754b4XuClD5sgPnmuAv8qRBxARWyVtBHYAJ4AnyPeU0Wwbe/00UbPK+ESZWWVcarPKuNRmlXGpzSrjUptVxqU2q4xLbVaZ/wWVXHDe9eNHQgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", + "x = ResidualBlock(h=1)(x)\n", + "x = ResidualBlock(h=1)(x)\n", + "\n", + "model = tf.keras.Model(inputs=inputs, outputs=x)\n", + "\n", + "plot_receptive_field(model, data)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOj0lEQVR4nO3dbYzV5Z3G8e+1A8iDFKwooYy7sqkxa0yqLmFtbcmuVBZaq0nTuJi1SU237Iu2i6220d0Xpi/ch9g07YtNEyJ2TYoYipJYoxa3tZImu7SIsPKkVbQCRdH1AXAtD/a3L85/NsgOzD3n3Ddz/rfXJ5lwZuZw5ReY65z/+c85v6OIwMzq8QdjPYCZ5eVSm1XGpTarjEttVhmX2qwy40qETtAZMZEpWTOnz5rKm/sOZs10brlM55bLBPgdb3MkDmu47xUp9USm8GdakDXzupsXs/obj2TNdG65TOeWywTYED896fd8+G1WGZfarDIutVllXGqzyrjUZpVxqc0qk1RqSYskPSPpOUm3lh7KzLo3YqklDQD/CiwGLgKul3RR6cHMrDsp99TzgOciYldEHAHuA64tO5aZdSul1LOB3cd9vqf52ntIWippo6SNRzmcaz4zG6VsJ8oiYnlEzI2IueM5I1esmY1SSqn3Aucd9/lg8zUz60Mppf4VcIGkOZImAEuAB8uOZWbdGvFVWhFxTNJXgJ8AA8DdEbGt+GRm1pWkl15GxMPAw4VnMbMM/Iwys8q41GaVcanNKuNSm1XGpTarjHK+l5akzwCfmXHWOV/6x3/452y5AGcNTmPG2fuyZgIcemcmZ0565X2d26ZZ25ZbatZbbr6FjVt+N+w20aylHvIBfTCybxO9czFf/Ot/ypoJsH7rMuZf/L33dW6bZm1bbqlZ5/3l7pOW2offZpVxqc0q41KbVcalNquMS21WGZfarDIpiwfvlrRf0tbTMZCZ9SblnvrfgEWF5zCzTEYsdUSsB14/DbOYWQbZ3p9a0lJgKcBEJueKNbNR8jZRs8r47LdZZVxqs8qk/EprFfAfwIWS9kj6YvmxzKxbKSuCrz8dg5hZHj78NquMS21WGZfarDIutVllXGqzynibaIs2U5bKbdOsbcs99M5M9j17MGsmwM233MKBeN3bRIfTps2UpXLbNGvbctdvXcYdC5/ImgmwIX560lL78NusMi61WWVcarPKuNRmlXGpzSrjUptVJuWll+dJelzSdknbJC07HYOZWXdSdpQdA26OiE2SpgJPSnosIrYXns3MupCyTXRfRGxqLh8EdgCzSw9mZt0Z1TZRSecDlwIbhvmet4ma9YHkE2WSzgTuB26KiAMnft/bRM36Q1KpJY2nU+iVEfFA2ZHMrBcpZ78FrAB2RMR3yo9kZr1Iuae+Avg8cKWkzc3HpwrPZWZdStkm+gtg2Jd4mVn/8TPKzCrjUptVxqU2q4xLbVaZbO9PbdZ2v94ymTsWXpI187o7T/+zK71NtEWbKUvltmnWkrmv/fcs3tjzVtbMswanZc8EbxM9pTZtpiyV26ZZS+auWHkbq7/xSNbM6+5cnD0TvE3U7H3FpTarjEttVhmX2qwyLrVZZVJeejlR0i8lbWkWD37rdAxmZt1JefLJYeDKiDjULEv4haRHIuI/C89mZl1IeellAIeaT8c3H/l/uW1mWaSuMxqQtBnYDzwWEcMuHpS0UdLGoxzOPaeZJUoqdUS8GxGXAIPAPEkXD3MdLx406wOjOvsdEW8CjwOLyoxjZr1KOft9jqTpzeVJwFXAztKDmVl3Us5+zwLukTRA50ZgdUQ8VHYsM+tWytnv/6Lzrhxm1gJ+RplZZVxqs8q41GaVcanNKuNSm1XGiwdbthzPiwc7ufuePZg9t8SSQC8ePAUvHiyX26ZZh3LvWPhE9twSSwK9eNDMeuZSm1XGpTarjEttVhmX2qwyLrVZZZJL3aw0ekqSX3Zp1sdGc0+9DNhRahAzyyN18eAg8GngrrLjmFmvUu+pvwt8E/j9ya7gbaJm/SFlR9nVwP6IePJU1/M2UbP+kHJPfQVwjaQXgfuAKyX9sOhUZta1EUsdEbdFxGBEnA8sAX4WETcUn8zMuuLfU5tVJmVF8P+JiJ8DPy8yiZll4Xtqs8q41GaVcanNKuNSm1XGpTarjLeJtnCTZoltoi+8c9JnAHdt5vgJvLXrjey5pTZ0epvoKXibaLty129dxo3b8j9f/+uzBnlwyZrsuaU2dHqbqJn1JZfarDIutVllXGqzyrjUZpVxqc0qk/QqrWZBwkHgXeBYRMwtOZSZdW80L738i4h4rdgkZpaFD7/NKpNa6gDWSXpS0tLhruBtomb9IfXw++MRsVfSucBjknZGxPrjrxARy4Hl0HmaaOY5zSxR0j11ROxt/twPrAXmlRzKzLqXsvd7iqSpQ5eBhcDW0oOZWXdSDr9nAmslDV3/3oh4tOhUZta1EUsdEbuAj5yGWcwsA/9Ky6wyLrVZZVxqs8q41GaVcanNKtOqbaJt2SDZttw2zdq2XG8TPYU2bZBsW26bZm1brreJmlnPXGqzyrjUZpVxqc0q41KbVcalNqtMUqklTZe0RtJOSTskfbT0YGbWndR1Rt8DHo2Iz0maAEwuOJOZ9WDEUkuaBswHvgAQEUeAI2XHMrNupRx+zwFeBX4g6SlJdzVrjd7D20TN+kNKqccBlwHfj4hLgbeBW0+8UkQsj4i5ETF3PGdkHtPMUqWUeg+wJyI2NJ+voVNyM+tDI5Y6Il4Gdku6sPnSAmB70anMrGupZ7+/CqxsznzvAm4sN5KZ9SKp1BGxGfA7XZq1gJ9RZlYZl9qsMi61WWVcarPKuNRmlfE2Uee2ata25Xqb6Cm0aYNk23LbNGvbcr1N1Mx65lKbVcalNquMS21WGZfarDIjllrShZI2H/dxQNJNp2M4Mxu9EV+lFRHPAJcASBoA9gJrC89lZl0a7eH3AuD5iPhNiWHMrHepSxKGLAFWDfcNSUuBpQATvUHYbMwk31M3W0+uAX403Pe9eNCsP4zm8HsxsCkiXik1jJn1bjSlvp6THHqbWf9IfS+tKcBVwANlxzGzXqUuHnwbOLvwLGaWgZ9RZlYZl9qsMi61WWVcarPKuNRmlfHiQee2ata25Xrx4Cm0adlc23LbNGvbcr140Mx65lKbVcalNquMS21WGZfarDIutVllUl96+TVJ2yRtlbRK0sTSg5lZd1JWBM8G/g6YGxEXAwN0dpWZWR9KPfweB0ySNA6YDPy23Ehm1osRSx0Re4FvAy8B+4C3ImLdideTtFTSRkkbj3I4/6RmliTl8Pss4FpgDvAhYIqkG068nreJmvWHlMPvTwIvRMSrEXGUzp6yj5Udy8y6lVLql4DLJU2WJDrv0rGj7Fhm1q2Ux9QbgDXAJuDp5u8sLzyXmXUpdZvo7cDthWcxswz8jDKzyrjUZpVxqc0q41KbVcalNquMt4k6t1Wzti3X20RPoU0bJNuW26ZZ25brbaJm1jOX2qwyLrVZZVxqs8q41GaVcanNKpO6TXRZs0l0m6SbSg9lZt1LWWd0MfAlYB7wEeBqSR8uPZiZdSflnvpPgA0R8T8RcQx4Avhs2bHMrFsppd4KfELS2ZImA58CzjvxSt4matYfRtx8EhE7JP0LsA54G9gMvDvM9ZbTrDn6gD6Y/7mnZpYk6URZRKyIiD+NiPnAG8CzZccys24l7SiTdG5E7Jf0h3QeT19ediwz61ZSqYH7JZ0NHAW+HBFvFpzJzHqQuk30E6UHMbM8/Iwys8q41GaVcanNKuNSm1XGpTarTJFtosBfAb9O+CszgNcS46cBqWsZnduuWduW2w+zXhAR04b9TkSM2QewcRTXXe7c9Nw2zdq23H6ftU2H3z92brHcNs3attzTPmtrSh0RRf5xnNuuWduWOxazjnWpS715vXPbNWvbcvt61iLv0GFmY2es76nNLDOX2qwyY1ZqSYskPSPpOUm3Zsq8W9J+SVtz5DWZ50l6XNL2Zpvqsky5EyX9UtKWJvdbOXKPyx+Q9JSkhzJmvijpaUmbJW3MlDld0hpJOyXtkPTRDJkXNjMOfRzItQVX0tea/6+tklZJmpgpN9/G3tTfi+X8AAaA54E/BiYAW4CLMuTOBy4DtmacdRZwWXN5Kp2tLzlmFXBmc3k8sAG4POPcXwfuBR7KmPkiMCPzz8I9wN80lycA0wv8rL0M/FGGrNnAC8Ck5vPVwBcy5F5MZxfgZDovh/534MPd5o3VPfU84LmI2BURR4D7gGt7DY2I9cDrveackLkvIjY1lw8CO+j85/aaGxFxqPl0fPOR5aylpEHg08BdOfJKkTSNzg3xCoCIOBL5F3AsAJ6PiN9kyhsHTJI0jk4Jf5shM+vG3rEq9Wxg93Gf7yFDUUqTdD5wKZ171Rx5A5I2A/uBxyIiSy7wXeCbwO8z5Q0JYJ2kJyUtzZA3B3gV+EHzUOEuSVMy5B5vCbAqR1BE7AW+DbwE7APeioh1GaKTNvam8omyRJLOBO4HboqIAzkyI+LdiLgEGATmNW+c0BNJVwP7I+LJngf8/z4eEZcBi4EvS5rfY944Og+Xvh8Rl9LZVpvl/AqApAnANcCPMuWdReeIcg7wIWCKpBt6zY2IHcDQxt5HOcnG3lRjVeq9vPeWaLD5Wl+SNJ5OoVdGxAO585tDzseBRRnirgCukfQinYc1V0r6YYbcoXsqImI/sJbOw6he7AH2HHeEsoZOyXNZDGyKiFcy5X0SeCEiXo2Io8ADwMdyBEfGjb1jVepfARdImtPcmi4BHhyjWU5Jkug85tsREd/JmHuOpOnN5UnAVcDOXnMj4raIGIyI8+n8u/4sInq+N5E0RdLUocvAQjqHjb3M+jKwW9KFzZcWANt7GvS9rifToXfjJeBySZObn4sFdM6x9EzSuc2fQxt77+02K3WbaFYRcUzSV4Cf0Dk7eXdEbOs1V9Iq4M+BGZL2ALdHxIoeY68APg883Tz+Bfj7iHi4x9xZwD2SBujcuK6OiGy/fipgJrC287PMOODeiHg0Q+5XgZXNjfsu4MYMmUM3PFcBf5sjDyAiNkhaA2wCjgFPke8po9k29vppomaV8Ykys8q41GaVcanNKuNSm1XGpTarjEttVhmX2qwy/wufCGzew3+2hgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", + "x = ResidualBlock(h=1)(x)\n", + "x = ResidualBlock(h=1)(x)\n", + "x = ResidualBlock(h=1)(x)\n", + "\n", + "model = tf.keras.Model(inputs=inputs, outputs=x)\n", + "\n", + "plot_receptive_field(model, data)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOaElEQVR4nO3df6zV9X3H8edrFxi/LCgoQS4bLDVkhsQfJczWlmxSGbRWk2XpMLFJzSb9oz+w1Ta6/eH6h8uWNk37x9KEiJ2JiLEoiTNWcauVNNmwgDAvXLSKVC5FgWoFnJUffe+P870bsgv3c8/5fO89309fj+SGc+89vvKO8Lrne773e95HEYGZleP3xnoAM8vLpTYrjEttVhiX2qwwLrVZYcbVETrzop6YN3d81szjv5nJ1IlHsmY6t75M59aXCbBv/0mOvHVaQ32vllLPmzue55+emzVzc99qliz8XtZM59aX6dz6MgEW//n+c37Ph99mhXGpzQrjUpsVxqU2K4xLbVYYl9qsMEmllrRc0kuSXpF0V91DmVn7hi21pB7gn4EVwOXAzZIur3swM2tPyiP1YuCViNgbESeAh4Gb6h3LzNqVUuo5wJmXrwxUX/sASaskbZW09fCvTueaz8xGKNuJsohYExGLImLRxTN6csWa2QillPoAcOaF3L3V18ysC6WU+mfAZZLmS5oArAQer3csM2vXsK/SiohTkr4EPA30APdHxK7aJzOztiS99DIingSerHkWM8vAV5SZFcalNiuMS21WGJfarDAutVlhlPO9tCR9BvjMnEtn3PbgA3+fLRfg+HuzmDrpzayZzq0v07n1ZQLcecedbN35myG3iWYt9aBFV0wMbxNtTm6TZm1abp3bRM9Vah9+mxXGpTYrjEttVhiX2qwwLrVZYVxqs8KkLB68X9IhSX2jMZCZdSblkfpfgOU1z2FmmQxb6ojYDLw1CrOYWQbZnlN7m6hZd/A2UbPC+Oy3WWFcarPCpPxKaz3wH8ACSQOS/rr+scysXSkrgm8ejUHMLA8ffpsVxqU2K4xLbVYYl9qsMC61WWG8TdS5jZq1abnH35vFwZePZc0EuOPOOzkab3mbqHNHL9O5/5d577LnsmYCbIl/P2epffhtVhiX2qwwLrVZYVxqs8K41GaFcanNCpPy0su5kp6VtFvSLkmrR2MwM2vPsC+9BE4Bd0TEdkkXANskPRMRu2uezczakLJN9GBEbK9uHwP6gTl1D2Zm7RnRc2pJ84CrgC1DfM/bRM26QHKpJU0FHgVuj4ijZ3/f20TNukNSqSWNp1XodRHxWL0jmVknUs5+C1gL9EfEd+ofycw6kfJIfS3wOeA6STuqj0/VPJeZtSllm+hPgSFf4mVm3cdXlJkVxqU2K4xLbVYYl9qsMCnXfpv9Tvj5zsncu+zKrJmf/dbkrHkpvE3UuY2atc7cI7+azdsD72TNvLB3WvZM8DZR545BZhNz1667m0e+/qOsmZ/91orsmeBtoma/U1xqs8K41GaFcanNCuNSmxUm5aWXEyU9L2lntXjwm6MxmJm1J+Xik/eB6yLieLUs4aeSfhQR/1nzbGbWhpSXXgZwvPp0fPWR/5fbZpZF6jqjHkk7gEPAMxHhxYNmXSqp1BFxOiKuBHqBxZIWDnEfLx406wIjOvsdEb8GngWW1zOOmXUq5ez3xZKmV7cnAdcDe+oezMzak3L2ezbwgKQeWj8EHomIJ+ody8zalXL2+79ovSuHmTWArygzK4xLbVYYl9qsMC61WWFcarPCePGgcxs162DuwZePZc+tY0mgFw+eR9OW2DUpt0mzDubeu+y57Ll1LAn04kEz65hLbVYYl9qsMC61WWFcarPCuNRmhUkudbXS6AVJftmlWRcbySP1aqC/rkHMLI/UxYO9wKeB++odx8w6lfpI/V3gG8Bvz3UHbxM16w4pO8puAA5FxLbz3c/bRM26Q8oj9bXAjZL2AQ8D10l6sNapzKxtw5Y6Iu6OiN6ImAesBH4cEbfUPpmZtcW/pzYrTMqK4P8VET8BflLLJGaWhR+pzQrjUpsVxqU2K4xLbVYYl9qsMN4m6lyOvzeL19475xXAbZs1fgLv7H07e25dGzq9TfQ8vE20Wbmb+1Zz6673s2YCfG12L4+v3JA9t64Nnd4mamZdyaU2K4xLbVYYl9qsMC61WWFcarPCJL1Kq1qQcAw4DZyKiEV1DmVm7RvJSy//LCKO1DaJmWXhw2+zwqSWOoBNkrZJWjXUHbxN1Kw7pB5+fzwiDki6BHhG0p6I2HzmHSJiDbAGWpeJZp7TzBIlPVJHxIHqz0PARmBxnUOZWftS9n5PkXTB4G1gGdBX92Bm1p6Uw+9ZwEZJg/d/KCKeqnUqM2vbsKWOiL3AFaMwi5ll4F9pmRXGpTYrjEttVhiX2qwwLrVZYWrZJjrzwotv+4e/+8dsudCsDZJNy23SrE3LLWab6Id0UfyJlmbNbNIGyablNmnWpuV6m6iZdcylNiuMS21WGJfarDAutVlhXGqzwiSVWtJ0SRsk7ZHUL+mjdQ9mZu1JXWf0PeCpiPhLSROAyTXOZGYdGLbUkqYBS4DPA0TECeBEvWOZWbtSDr/nA4eBH0h6QdJ91VqjDzhzm+hJ8r/XsZmlSSn1OOBq4PsRcRXwLnDX2XeKiDURsSgiFo3n9zOPaWapUko9AAxExJbq8w20Sm5mXWjYUkfEG8B+SQuqLy0Fdtc6lZm1LfXs95eBddWZ773ArfWNZGadSCp1ROwA/E6XZg3gK8rMCuNSmxXGpTYrjEttVhiX2qww3ibq3EbN2rRcbxM9jyZtkGxabpNmbVqut4maWcdcarPCuNRmhXGpzQrjUpsVZthSS1ogaccZH0cl3T4aw5nZyA37Kq2IeAm4EkBSD3AA2FjzXGbWppEefi8FXo2IX9QxjJl1LnVJwqCVwPqhviFpFbAKYKI3CJuNmeRH6mrryY3AD4f6vhcPmnWHkRx+rwC2R8SbdQ1jZp0bSalv5hyH3mbWPVLfS2sKcD3wWL3jmFmnUhcPvgvMqHkWM8vAV5SZFcalNiuMS21WGJfarDAutVlhvHjQuY2atWm5Xjx4Hk1aNte03CbN2rRcLx40s4651GaFcanNCuNSmxXGpTYrjEttVpjUl15+VdIuSX2S1kuaWPdgZtaelBXBc4CvAIsiYiHQQ2tXmZl1odTD73HAJEnjgMnAL+sbycw6MWypI+IA8G3gdeAg8E5EbDr7fpJWSdoqaetJ3s8/qZklSTn8vhC4CZgPXApMkXTL2ffzNlGz7pBy+P1J4LWIOBwRJ2ntKftYvWOZWbtSSv06cI2kyZJE6106+usdy8zalfKceguwAdgOvFj9N2tqnsvM2pS6TfQe4J6aZzGzDHxFmVlhXGqzwrjUZoVxqc0K41KbFcbbRJ3bqFmbluttoufRpA2STctt0qxNy/U2UTPrmEttVhiX2qwwLrVZYVxqs8K41GaFSd0murraJLpL0u11D2Vm7UtZZ7QQuA1YDFwB3CDpw3UPZmbtSXmk/mNgS0T8d0ScAp4D/qLescysXSml7gM+IWmGpMnAp4C5Z9/J20TNusOwm08iol/SPwGbgHeBHcDpIe63hmrN0Yd0Uf5rT80sSdKJsohYGxEfiYglwNvAy/WOZWbtStpRJumSiDgk6Q9oPZ++pt6xzKxdSaUGHpU0AzgJfDEifl3jTGbWgdRtop+oexAzy8NXlJkVxqU2K4xLbVYYl9qsMC61WWFq2SYK/BXw84T/ZCZwJDF+GpC6ltG5zZq1abndMOtlETFtyO9ExJh9AFtHcN81zk3PbdKsTcvt9lmbdPj9r86tLbdJszYtd9RnbUypI6KW/znObdasTcsdi1nHutR1vXm9c5s1a9Nyu3rWWt6hw8zGzlg/UptZZi61WWHGrNSSlkt6SdIrku7KlHm/pEOS+nLkVZlzJT0raXe1TXV1ptyJkp6XtLPK/WaO3DPyeyS9IOmJjJn7JL0oaYekrZkyp0vaIGmPpH5JH82QuaCacfDjaK4tuJK+Wv199UlaL2liptx8G3tTfy+W8wPoAV4F/giYAOwELs+QuwS4GujLOOts4Orq9gW0tr7kmFXA1Or2eGALcE3Gub8GPAQ8kTFzHzAz87+FB4C/qW5PAKbX8G/tDeAPM2TNAV4DJlWfPwJ8PkPuQlq7ACfTejn0vwEfbjdvrB6pFwOvRMTeiDgBPAzc1GloRGwG3uo056zMgxGxvbp9DOin9ZfbaW5ExPHq0/HVR5azlpJ6gU8D9+XIq4ukabR+EK8FiIgTkX8Bx1Lg1Yj4Raa8ccAkSeNolfCXGTKzbuwdq1LPAfaf8fkAGYpSN0nzgKtoParmyOuRtAM4BDwTEVlyge8C3wB+mylvUACbJG2TtCpD3nzgMPCD6qnCfZKmZMg900pgfY6giDgAfBt4HTgIvBMRmzJEJ23sTeUTZYkkTQUeBW6PiKM5MiPidERcCfQCi6s3TuiIpBuAQxGxreMB/7+PR8TVwArgi5KWdJg3jtbTpe9HxFW0ttVmOb8CIGkCcCPww0x5F9I6opwPXApMkXRLp7kR0Q8Mbux9inNs7E01VqU+wAd/EvVWX+tKksbTKvS6iHgsd351yPkssDxD3LXAjZL20Xpac52kBzPkDj5SERGHgI20nkZ1YgAYOOMIZQOtkueyAtgeEW9myvsk8FpEHI6Ik8BjwMdyBEfGjb1jVeqfAZdJml/9NF0JPD5Gs5yXJNF6ztcfEd/JmHuxpOnV7UnA9cCeTnMj4u6I6I2IebT+v/44Ijp+NJE0RdIFg7eBZbQOGzuZ9Q1gv6QF1ZeWArs7GvSDbibToXfldeAaSZOrfxdLaZ1j6ZikS6o/Bzf2PtRuVuo20awi4pSkLwFP0zo7eX9E7Oo0V9J64E+BmZIGgHsiYm2HsdcCnwNerJ7/AvxtRDzZYe5s4AFJPbR+uD4SEdl+/VSDWcDG1r9lxgEPRcRTGXK/DKyrfrjvBW7NkDn4g+d64As58gAiYoukDcB24BTwAvkuGc22sdeXiZoVxifKzArjUpsVxqU2K4xLbVYYl9qsMC61WWFcarPC/A9lao9AALc59QAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", + "x = ResidualBlock(h=1)(x)\n", + "x = ResidualBlock(h=1)(x)\n", + "x = ResidualBlock(h=1)(x)\n", + "x = ResidualBlock(h=1)(x)\n", + "\n", + "model = tf.keras.Model(inputs=inputs, outputs=x)\n", + "\n", + "plot_receptive_field(model, data)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "Receptive fields.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/WIP/3 -PixelCNNs blind spot in the receptive field/Receptive_fields.ipynb b/WIP/3 -PixelCNNs blind spot in the receptive field/Receptive_fields.ipynb deleted file mode 100644 index 237f045..0000000 --- a/WIP/3 -PixelCNNs blind spot in the receptive field/Receptive_fields.ipynb +++ /dev/null @@ -1,251 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Receptive fields.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "YqTKIYLooHsq", - "colab_type": "text" - }, - "source": [ - "# Comparing receptive fields" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "wM-m3Z8CiLXU", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "gf5wwqP3ozaN", - "colab_type": "code", - "colab": {} - }, - "source": [ - "import random as rn\n", - "\n", - "import matplotlib\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "from tensorflow import nn\n", - "from tensorflow.keras import initializers\n", - "from tensorflow.keras.utils import Progbar" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "jEkll1yno2Vb", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Defining random seeds\n", - "random_seed = 42\n", - "tf.random.set_seed(random_seed)\n", - "np.random.seed(random_seed)\n", - "rn.seed(random_seed)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "yJ_JlzWco7ci", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class MaskedConv2D(keras.layers.Layer):\n", - " \"\"\"Convolutional layers with masks.\n", - "\n", - " Convolutional layers with simple implementation of masks type A and B for\n", - " autoregressive models.\n", - "\n", - " Arguments:\n", - " mask_type: one of `\"A\"` or `\"B\".`\n", - " filters: Integer, the dimensionality of the output space\n", - " (i.e. the number of output filters in the convolution).\n", - " kernel_size: An integer or tuple/list of 2 integers, specifying the\n", - " height and width of the 2D convolution window.\n", - " Can be a single integer to specify the same value for\n", - " all spatial dimensions.\n", - " strides: An integer or tuple/list of 2 integers,\n", - " specifying the strides of the convolution along the height and width.\n", - " Can be a single integer to specify the same value for\n", - " all spatial dimensions.\n", - " Specifying any stride value != 1 is incompatible with specifying\n", - " any `dilation_rate` value != 1.\n", - " padding: one of `\"valid\"` or `\"same\"` (case-insensitive).\n", - " kernel_initializer: Initializer for the `kernel` weights matrix.\n", - " bias_initializer: Initializer for the bias vector.\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " mask_type,\n", - " filters,\n", - " kernel_size,\n", - " strides=1,\n", - " padding='same',\n", - " kernel_initializer='glorot_uniform',\n", - " bias_initializer='zeros'):\n", - " super(MaskedConv2D, self).__init__()\n", - "\n", - " assert mask_type in {'A', 'B'}\n", - " self.mask_type = mask_type\n", - "\n", - " self.filters = filters\n", - " self.kernel_size = kernel_size\n", - " self.strides = strides\n", - " self.padding = padding.upper()\n", - " self.kernel_initializer = initializers.get(kernel_initializer)\n", - " self.bias_initializer = initializers.get(bias_initializer)\n", - "\n", - " def build(self, input_shape):\n", - " self.kernel = self.add_weight('kernel',\n", - " shape=(self.kernel_size,\n", - " self.kernel_size,\n", - " int(input_shape[-1]),\n", - " self.filters),\n", - " initializer=self.kernel_initializer,\n", - " trainable=True)\n", - "\n", - " self.bias = self.add_weight('bias',\n", - " shape=(self.filters,),\n", - " initializer=self.bias_initializer,\n", - " trainable=True)\n", - "\n", - " center = self.kernel_size // 2\n", - "\n", - " mask = np.ones(self.kernel.shape, dtype=np.float32)\n", - " mask[center, center + (self.mask_type == 'B'):, :, :] = 0.\n", - " mask[center + 1:, :, :, :] = 0.\n", - "\n", - " self.mask = tf.constant(mask, dtype=tf.float32, name='mask')\n", - "\n", - " def call(self, input):\n", - " masked_kernel = tf.math.multiply(self.mask, self.kernel)\n", - " x = nn.conv2d(input,\n", - " masked_kernel,\n", - " strides=[1, self.strides, self.strides, 1],\n", - " padding=self.padding)\n", - " x = nn.bias_add(x, self.bias)\n", - " return x" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "jxCLMYc-FxdJ", - "colab_type": "code", - "colab": {} - }, - "source": [ - "def plot_receptive_field(model, data):\n", - " out = model(data)\n", - "\n", - " with tf.GradientTape() as tape:\n", - " tape.watch(data)\n", - " prediction = model(data)\n", - " loss = prediction[:,5,5,0]\n", - "\n", - " gradients = tape.gradient(loss, data)\n", - "\n", - " gradients = np.abs(gradients.numpy().squeeze())\n", - " gradients = (gradients > 1e-8).astype('float32')\n", - " gradients[5, 5] = 0.5\n", - "\n", - " plt.figure()\n", - " plt.imshow(gradients)\n", - " plt.title(f'Receptive field from pixel layers')\n", - " plt.show()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "0qpDtNuvo9NL", - "colab_type": "code", - "outputId": "926434e2-44bf-40d3-a39f-6c80ebc880b6", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 281 - } - }, - "source": [ - "height = 10\n", - "width = 10\n", - "n_channel = 1\n", - "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", - "x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", - "x = MaskedConv2D(mask_type='B', filters=1, kernel_size=3, strides=1)(x)\n", - "x = MaskedConv2D(mask_type='B', filters=1, kernel_size=3, strides=1)(x)\n", - "\n", - "model = keras.Model(inputs=inputs, outputs=x)\n", - "data = tf.random.normal((1,10,10,1))\n", - "\n", - "plot_receptive_field(model, data)" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAARBElEQVR4nO3dfZBV9X3H8fdHWEGQKkbzwEOQiUYltoLdxKdqjWI1PmaaptVEU50axjaKWic+dVqcxDTTiSY6k2iCD9gIVVPiRGuNmgjGaCsjIBkFtCWIgkDEKgo2Aazf/nF+K5dl796zy72cuz8/r5mdufeec8/5nofP/n7nt/eeVURgZvnYpeoCzKy5HGqzzDjUZplxqM0y41CbZcahNsuMQ12SpKsl3dqC5X5I0uOSNki6vi/rkfSYpPPrTNtXUkgaXGf6AZIWpfVO3ZFtaDZJX5T0SBOWs0LS5DrT7pB07Y6uox31eMB3NkkrgA8B/wdsBB4CLoyIjRXVcywwMyLGdL0WEf/YotVNAV4Dfi927ocGLgfmRsTEnbjOUiJiFjCr6joGqnZqqU+LiN2BicAk4KqK69lZxgFLdnKgu9a7uN5ESYN2Yi3vC/V6Tc3WTqEGICLWAg9ThBsASYdL+g9J6yX9KrWkXdP2kjRD0mpJb0j6Sc20U1MXc316/x/UTFsh6SpJS9L7ZkgaKmk48FNglKSN6WeUpGskzUzv/amkC2vrTnX9aXp8oKSfSXpd0guS/rynbZV0B/CXwOVpPZNr19No27sta5Ck6yS9Jmk5cEq9fSxpDvBp4LtpvR9P3dGbJT0o6W3g05IOSl389ZIWSzq9tnZJN6V9sVHSk5I+LOmGtD+flzSplxpC0lRJy1PN35K0S5p2rqQn0uMj0/Sx6fkhafkHpud1j3FZkkZKekDSurTsBySNSdM+L2lBt/n/VtJ96fGQtN9flvQbSd+XtFuadqykVZKukLQWmCFp77T89en8+GXXdjdNRFT+A6wAJqfHY4BngRvT89HA/wAnU/wSOiE93ydN/3fgHmAk0AH8cXp9EvAqcBgwiCI8K4AhNet8DhgL7AU8CVybph0LrOpW4zUUXXKALwFP1kybAKwHhgDDgZXAeRSXN5MoutcT6mz7HV3r7WE9jbb9MeD89PgC4Pma7ZkLBDC4znrfe29NHW8CR6V1jQCWAVcDuwLHARuAA2rmfw34Q2AoMAd4Me2bQcC1FN37esc8Uo17AR8F/qtmW84FnqiZ9xtp+bulc+PCPhzjyY32O/AB4HPAsLTd/wr8JE0bArwOHFTz3meAz6XH3wHuT9sxAvg34Js159E7wD+l5ewGfBP4PsW52gEcDaipeao60DU7f2M6aQJ4FNgzTbsCuLPb/A+nA/gR4F1gZA/LvBn4erfXXmBr6FcAF9RMOxn4dclQjwDeBsbVnHS3p8d/Afyy23t/AEzrR6jrbnsPoZ7TbXv+hL6H+oc1z48G1gK71Lx2F3BNzfy31Ey7CFha8/z3gfW9HPMATqp5/jfAo+nxuWwb6g5gAUWgH+oKQclj3DDUPUybCLzR7Vz6Rnr8CeANipAqnQcfq5n3CODFmvNoMzC0ZvrXgPuA/VqVp3bqfn82IkZQ7IgDgb3T6+OAz6fuynpJ64E/ogj0WOD1iHijh+WNAy7r9r6xwKiaeVbWPH6p27S6ImIDRQ/hzPTSWWwd2BkHHNZtvV8EPlxm2T1sQ71t724U229PX9W+fxSwMiLe7bbM0TXPf1Pz+Lc9PN+9D+uru/8jYgtFCA8Gro+UDsod44YkDZP0A0kvSXoLeBzYU1vHFf4Z+IIkAecAP4qITcA+FK37gpr1P5Re77IuIn5X8/xbFD2gR9Klx5V9qbWMthj9rhURv0jXmtcBn6U48HdGxJe7zyvpI8BekvaMiPXdJq+k+O36jV5WN7bm8UeB1V1llCj1LmCapMcpup9za9b7i4g4ocQyGqm77T1Yw/bb01e1270aGCtpl5pgd3WTm2UsWwfravf/NiSNBqYBM4DrJX0yharMMS7jMuAA4LCIWCtpIkUXWwAR8ZSkzRS9ly+kHyguP34LfCIiXqmz7G3OpdQgXEbxy+hgYI6kpyPi0R3chve0U0td6wbgBEmHADOB0ySdmAaDhqYBiDERsYZiUOumNNjRIemYtIxbgAskHabCcEmnSBpRs56vSBojaS/g7yiuzaFocT4gaY9eanyQoqX4GnBPzYn/APBxSeekejokfVLSQf3YD3W3vYd5fwRMTdszEtjRFmAe8L8Ug3gdKgboTgPu3sHl1vpqOm5jgYvZuv/fk1rHO4DbgL+i+OX19TS5zDEuYwRFONenc2FaD/P8EPgusCUingBIx/wW4DuSPpjqHS3pxHorSgN7+6XtepPiz7jv1pu/P9oy1BGxjmIn/kNErATOoBiwWUfx2/mrbK39HGALxSDRq8AlaRnzgS9THIg3KLo853Zb1b8AjwDLgV9TDO4QEc9TtMTLU7dqu+5cainuBSan5XS9voHievZMipZnLVsHSvq6Hxpte61bKK63fwUsTLX1W0RspgjxZyhapJuAL6V90yz3UVwrL6K4nLmth3mmAh8E/j51u88DzpN0dMljXMYNFINYrwFPUXShu7uTovs/s9vrV6T1PpW67j+naPXr2T/NsxH4T+CmiJjby/x91jXg8L6j4gMv50fEz6uu5f1IUgD7R8SyqmspI/2Z6lXg0Ij476rr6U1bttRmbeivgafbPdDQhgNlZu0m9epEMXDb9t633W+zXLn7bZaZlnS/d9WQGMrwVizazIDf8TabY5N6mtaSUA9lOIfp+FYs2syAeb18VsXdb7PMONRmmXGozTLjUJtlxqE2y4xDbZaZUqGWdJKKe20ta8WXus2seRqGOt394XsUX8GbAJwlaUKrCzOz/inTUn8KWBYRy9N3bO+m+I6vmbWhMqEezbb3klrFtvepAkDSFEnzJc3fwqZm1WdmfdS0gbKImB4RnRHR2dH3m3yYWZOUCfUrbHtDuzHpNTNrQ2VC/TSwv6TxknaluPfW/a0ty8z6q+G3tCLiHRX/YuZhiv+CcHtE1P0fTGZWrVJfvYyIByluiWtmbc6fKDPLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzpf6XVu4eXr2o6hIsYyeOmrhT1+eW2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLTMNQSxoraa6kJZIWS7p4ZxRmZv1T5sMn7wCXRcRCSSOABZJ+FhFLWlybmfVDw5Y6ItZExML0eAOwFBjd6sLMrH/69DFRSfsCk4B5PUybAkwBGMqwJpRmZv1ReqBM0u7Aj4FLIuKt7tMjYnpEdEZEZwdDmlmjmfVBqVBL6qAI9KyIuLe1JZnZjigz+i3gNmBpRHy79SWZ2Y4o01IfBZwDHCdpUfo5ucV1mVk/NRwoi4gnAO2EWsysCfyJMrPMONRmmXGozTLjUJtlxjceNEt29g0CW8UttVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcZ3E7UBKZc7f7aCW2qzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDOlQy1pkKRnJD3QyoLMbMf0paW+GFjaqkLMrDlKhVrSGOAU4NbWlmNmO6psS30DcDnwbr0ZJE2RNF/S/C1sakpxZtZ3DUMt6VTg1YhY0Nt8ETE9IjojorODIU0r0Mz6pkxLfRRwuqQVwN3AcZJmtrQqM+u3hqGOiKsiYkxE7AucCcyJiLNbXpmZ9Yv/Tm2WmT59nzoiHgMea0klZtYUbqnNMuNQm2XGoTbLjENtlhmH2iwzvpuoAfCxey5oyXL3u/SplizX6nNLbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmfHdRIETR02suoTK7Yfv+pkLt9RmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWZKhVrSnpJmS3pe0lJJR7S6MDPrn7IfPrkReCgi/kzSrsCwFtZkZjugYagl7QEcA5wLEBGbgc2tLcvM+qtM93s8sA6YIekZSbdKGt59JklTJM2XNH8Lm5peqJmVUybUg4FDgZsjYhLwNnBl95kiYnpEdEZEZwdDmlymmZVVJtSrgFURMS89n00RcjNrQw1DHRFrgZWSDkgvHQ8saWlVZtZvZUe/LwJmpZHv5cB5rSvJzHZEqVBHxCKgs8W1mFkT+BNlZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0yUyrUki6VtFjSc5LukjS01YWZWf80DLWk0cBUoDMiDgYGAWe2ujAz65+y3e/BwG6SBgPDgNWtK8nMdkTDUEfEK8B1wMvAGuDNiHik+3ySpkiaL2n+FjY1v1IzK6VM93skcAYwHhgFDJd0dvf5ImJ6RHRGRGcHQ5pfqZmVUqb7PRl4MSLWRcQW4F7gyNaWZWb9VSbULwOHSxomScDxwNLWlmVm/VXmmnoeMBtYCDyb3jO9xXWZWT8NLjNTREwDprW4FjNrAn+izCwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8woIpq/UGkd8FKJWfcGXmt6Aa0zkOodSLXCwKq3HWodFxH79DShJaEuS9L8iOisrIA+Gkj1DqRaYWDV2+61uvttlhmH2iwzVYd6oP3z+oFU70CqFQZWvW1da6XX1GbWfFW31GbWZA61WWYqC7WkkyS9IGmZpCurqqMRSWMlzZW0RNJiSRdXXVMZkgZJekbSA1XX0htJe0qaLel5SUslHVF1Tb2RdGk6D56TdJekoVXX1F0loZY0CPge8BlgAnCWpAlV1FLCO8BlETEBOBz4ShvXWutiYGnVRZRwI/BQRBwIHEIb1yxpNDAV6IyIg4FBwJnVVrW9qlrqTwHLImJ5RGwG7gbOqKiWXkXEmohYmB5voDjpRldbVe8kjQFOAW6tupbeSNoDOAa4DSAiNkfE+mqramgwsJukwcAwYHXF9WynqlCPBlbWPF9FmwcFQNK+wCRgXrWVNHQDcDnwbtWFNDAeWAfMSJcKt0oaXnVR9UTEK8B1wMvAGuDNiHik2qq254GykiTtDvwYuCQi3qq6nnoknQq8GhELqq6lhMHAocDNETEJeBto5/GVkRQ9yvHAKGC4pLOrrWp7VYX6FWBszfMx6bW2JKmDItCzIuLequtp4CjgdEkrKC5rjpM0s9qS6loFrIqIrp7PbIqQt6vJwIsRsS4itgD3AkdWXNN2qgr108D+ksZL2pVisOH+imrplSRRXPMtjYhvV11PIxFxVUSMiYh9KfbrnIhou9YEICLWAislHZBeOh5YUmFJjbwMHC5pWDovjqcNB/YGV7HSiHhH0oXAwxQjiLdHxOIqainhKOAc4FlJi9JrV0fEgxXWlJOLgFnpl/ty4LyK66krIuZJmg0spPiryDO04UdG/TFRs8x4oMwsMw61WWYcarPMONRmmXGozTLjUJtlxqE2y8z/A1DIMnrGMqKcAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [], - "needs_background": "light" - } - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "I2APaCzDGeqP", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": 0, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/WIP/4 - Gated_PixelCNN/gated_pixelCNN.py b/WIP/4 - Gated PixelCNN/gated_pixelCNN.py similarity index 76% rename from WIP/4 - Gated_PixelCNN/gated_pixelCNN.py rename to WIP/4 - Gated PixelCNN/gated_pixelCNN.py index dd0a87c..386ece8 100644 --- a/WIP/4 - Gated_PixelCNN/gated_pixelCNN.py +++ b/WIP/4 - Gated PixelCNN/gated_pixelCNN.py @@ -1,85 +1,35 @@ -import tensorflow as tf - -gpu_devices = tf.config.experimental.list_physical_devices('GPU') -for device in gpu_devices: tf.config.experimental.set_memory_growth(device, True) - +"""Script to train Gated pixelCNN on the MNIST dataset.""" import random as rn import matplotlib import matplotlib.pyplot as plt import numpy as np import tensorflow as tf -from tensorflow.keras.utils import Progbar from tensorflow import keras -from tensorflow.keras import initializers from tensorflow import nn - -# Defining random seeds -random_seed = 42 -tf.random.set_seed(random_seed) -np.random.seed(random_seed) -rn.seed(random_seed) - -# Loading data -(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() - -height = 28 -width = 28 -n_channel = 1 - -x_train = x_train.astype('float32') / 255. -x_test = x_test.astype('float32') / 255. - -x_train = x_train.reshape(x_train.shape[0], height, width, n_channel) -x_test = x_test.reshape(x_test.shape[0], height, width, n_channel) - - -def quantise(images, q_levels): - """Quantise image into q levels""" - return (np.digitize(images, np.arange(q_levels) / q_levels) - 1).astype('float32') - - -# Quantise the input data in q levels -q_levels = 2 -x_train_quantised = quantise(x_train, q_levels) -x_test_quantised = quantise(x_test, q_levels) - -# Creating input stream using tf.data API -batch_size = 192 -train_buf = 10000 - -train_dataset = tf.data.Dataset.from_tensor_slices((x_train_quantised / (q_levels - 1), - x_train_quantised.astype('int32'))) -train_dataset = train_dataset.shuffle(buffer_size=train_buf) -train_dataset = train_dataset.batch(batch_size) - -test_dataset = tf.data.Dataset.from_tensor_slices((x_test_quantised / (q_levels - 1), - x_test_quantised.astype('int32'))) -test_dataset = test_dataset.batch(batch_size) +from tensorflow.keras import initializers +from tensorflow.keras.utils import Progbar class MaskedConv2D(keras.layers.Layer): - """Convolutional layers with masks for Gated PixelCNN. - - Masked convolutional layers used to implement Vertical and Horizontal - stacks of the Gated PixelCNN. + """Convolutional layers with masks extended to work with Gated PixelCNN. - Note: This implementation is different from the normal PixelCNN. + Convolutional layers with simple implementation of masks type A and B for + autoregressive models. Extended version to work with the verticala and horizontal + stacks from the Gated PixelCNN model. Arguments: mask_type: one of `"V"`, `"A"` or `"B".` - filters: Integer, the dimensionality of the output space - (i.e. the number of output filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. + filters: Integer, the dimensionality of the output space (i.e. the number of output + filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the height and width + of the 2D convolution window. + Can be a single integer to specify the same value for all spatial dimensions. + strides: An integer or tuple/list of 2 integers, specifying the strides of the + convolution along the height and width. + Can be a single integer to specify the same value for all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying any + `dilation_rate` value != 1. padding: one of `"valid"` or `"same"` (case-insensitive). kernel_initializer: Initializer for the `kernel` weights matrix. bias_initializer: Initializer for the bias vector. @@ -127,6 +77,7 @@ def build(self, input_shape): mask = np.ones(self.kernel.shape, dtype=np.float32) + # Get centre of the filter for even or odd dimensions if kernel_h % 2 != 0: center_h = kernel_h // 2 else: @@ -156,8 +107,8 @@ def call(self, input): return x -class GatedBlock(tf.keras.Model): - """ Gated block of the Gated PixelCNN.""" +class GatedBlock(keras.Model): + """ Gated block that compose Gated PixelCNN.""" def __init__(self, mask_type, filters, kernel_size): super(GatedBlock, self).__init__(name='') @@ -169,7 +120,7 @@ def __init__(self, mask_type, filters, kernel_size): self.horizontal_conv = MaskedConv2D(mask_type=mask_type, filters=2 * filters, - kernel_size=kernel_size) + kernel_size=(1, kernel_size)) self.padding = keras.layers.ZeroPadding2D(padding=((1, 0), 0)) self.cropping = keras.layers.Cropping2D(cropping=((0, 1), 0)) @@ -180,20 +131,21 @@ def __init__(self, mask_type, filters, kernel_size): def _gate(self, x): tanh_preactivation, sigmoid_preactivation = tf.split(x, 2, axis=-1) - return tf.nn.tanh(tanh_preactivation) * tf.nn.sigmoid(sigmoid_preactivation) + return nn.tanh(tanh_preactivation) * nn.sigmoid(sigmoid_preactivation) def call(self, input_tensor): v = input_tensor[0] h = input_tensor[1] - vertical_preactivation = self.vertical_conv(v) # NxN + vertical_preactivation = self.vertical_conv(v) - # Shifting feature map down to ensure causality + # Shifting vertical stack feature map down before feed into horizontal stack to + # ensure causality v_to_h = self.padding(vertical_preactivation) v_to_h = self.cropping(v_to_h) - v_to_h = self.v_to_h_conv(v_to_h) # 1x1 + v_to_h = self.v_to_h_conv(v_to_h) - horizontal_preactivation = self.horizontal_conv(h) # 1xN + horizontal_preactivation = self.horizontal_conv(h) v_out = self._gate(vertical_preactivation) @@ -209,29 +161,80 @@ def call(self, input_tensor): return v_out, h_out +def quantise(images, q_levels): + """Quantise image into q levels""" + return (np.digitize(images, np.arange(q_levels) / q_levels) - 1).astype('float32') + + +# def main(): +# ------------------------------------------------------------------------------------ +# Defining random seeds +random_seed = 42 +tf.random.set_seed(random_seed) +np.random.seed(random_seed) +rn.seed(random_seed) + +# ------------------------------------------------------------------------------------ +# Loading data +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + +height = 28 +width = 28 +n_channel = 1 + +x_train = x_train.astype('float32') / 255. +x_test = x_test.astype('float32') / 255. + +x_train = x_train.reshape(x_train.shape[0], height, width, n_channel) +x_test = x_test.reshape(x_test.shape[0], height, width, n_channel) + +# ------------------------------------------------------------------------------------ +# Quantise the input data in q levels +q_levels = 2 +x_train_quantised = quantise(x_train, q_levels) +x_test_quantised = quantise(x_test, q_levels) + +# ------------------------------------------------------------------------------------ +# Creating input stream using tf.data API +batch_size = 256 +train_buf = 60000 + +train_dataset = tf.data.Dataset.from_tensor_slices( + (x_train_quantised / (q_levels - 1), + x_train_quantised.astype('int32')) +) +train_dataset = train_dataset.shuffle(buffer_size=train_buf) +train_dataset = train_dataset.batch(batch_size) + +test_dataset = tf.data.Dataset.from_tensor_slices((x_test_quantised / (q_levels - 1), + x_test_quantised.astype('int32'))) +test_dataset = test_dataset.batch(batch_size) + +# ------------------------------------------------------------------------------------ # Create Gated PixelCNN model inputs = keras.layers.Input(shape=(height, width, n_channel)) v, h = GatedBlock(mask_type='A', filters=64, kernel_size=3)([inputs, inputs]) -for i in range(7): +for i in range(10): v, h = GatedBlock(mask_type='B', filters=64, kernel_size=3)([v, h]) x = keras.layers.Activation(activation='relu')(h) x = keras.layers.Conv2D(filters=128, kernel_size=1, strides=1)(x) - x = keras.layers.Activation(activation='relu')(x) x = keras.layers.Conv2D(filters=q_levels, kernel_size=1, strides=1)(x) -gated_pixelcnn = tf.keras.Model(inputs=inputs, outputs=x) +gated_pixelcnn = keras.Model(inputs=inputs, outputs=x) +# ------------------------------------------------------------------------------------ # Prepare optimizer and loss function -lr_decay = 0.999995 +lr_decay = 0.999 learning_rate = 1e-3 optimizer = keras.optimizers.Adam(lr=learning_rate) compute_loss = keras.losses.CategoricalCrossentropy(from_logits=True) +# ------------------------------------------------------------------------------------ @tf.function def train_step(batch_x, batch_y): with tf.GradientTape() as ae_tape: @@ -246,8 +249,9 @@ def train_step(batch_x, batch_y): return loss +# ------------------------------------------------------------------------------------ # Training loop -n_epochs = 50 +n_epochs = 20 n_iter = int(np.ceil(x_train_quantised.shape[0] / batch_size)) for epoch in range(n_epochs): progbar = Progbar(n_iter) @@ -259,18 +263,20 @@ def train_step(batch_x, batch_y): progbar.add(1, values=[('loss', loss)]) +# ------------------------------------------------------------------------------------ # Test set performance test_loss = [] for batch_x, batch_y in test_dataset: logits = gated_pixelcnn(batch_x, training=False) # Calculate cross-entropy (= negative log-likelihood) - loss = compute_loss(tf.one_hot(batch_y, q_levels), logits) + loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits) test_loss.append(loss) print('nll : {:} nats'.format(np.array(test_loss).mean())) print('bits/dim : {:}'.format(np.array(test_loss).mean() / np.log(2))) +# ------------------------------------------------------------------------------------ # Generating new images samples = np.zeros((100, height, width, n_channel), dtype='float32') for i in range(height): @@ -287,6 +293,7 @@ def train_step(batch_x, batch_y): plt.yticks(np.array([])) plt.show() +# ------------------------------------------------------------------------------------ # Filling occluded images occlude_start_row = 14 num_generated_images = 10 @@ -316,3 +323,6 @@ def train_step(batch_x, batch_y): plt.xticks(np.array([])) plt.yticks(np.array([])) plt.show() + +# if __name__ == '__main__': +# main() diff --git a/WIP/4 - Gated_PixelCNN/gated_pixelcnn.ipynb b/WIP/4 - Gated PixelCNN/gated_pixelcnn.ipynb similarity index 100% rename from WIP/4 - Gated_PixelCNN/gated_pixelcnn.ipynb rename to WIP/4 - Gated PixelCNN/gated_pixelcnn.ipynb diff --git a/WIP/4 - Gated PixelCNN/gated_pixelcnn_receptive_field.ipynb b/WIP/4 - Gated PixelCNN/gated_pixelcnn_receptive_field.ipynb new file mode 100644 index 0000000..4a319c3 --- /dev/null +++ b/WIP/4 - Gated PixelCNN/gated_pixelcnn_receptive_field.ipynb @@ -0,0 +1,472 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "YqTKIYLooHsq" + }, + "source": [ + "# Gated PixelCNN receptive fields\n", + "\n", + "Hi everybody!\n", + "In this notebook, we will analyse the Gated PixelCNN's block receptive field. Diferent of the original PixelCNN, we expect that the blocks of the Gated PixelCNN do not create blind spots that limit the information flow of the previous pixel in order to model the density probability function.\n", + "\n", + "Let's start!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we define the masked convolutions involved in the Gated PixelCNN as presented in the post.\n", + "\n", + "*Note: Here we are using float64 to get more precise values of the gradients and avoid false values." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "gf5wwqP3ozaN" + }, + "outputs": [], + "source": [ + "import random as rn\n", + "\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.ticker import FixedLocator\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "from tensorflow import nn\n", + "from tensorflow.keras import initializers\n", + "from tensorflow.keras.utils import Progbar\n", + "\n", + "tf.keras.backend.set_floatx('float64')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "yJ_JlzWco7ci" + }, + "outputs": [], + "source": [ + "class MaskedConv2D(keras.layers.Layer):\n", + " \"\"\"Convolutional layers with masks extended to work with Gated PixelCNN.\n", + "\n", + " Convolutional layers with simple implementation of masks type A and B for\n", + " autoregressive models. Extended version to work with the verticala and horizontal\n", + " stacks from the Gated PixelCNN model.\n", + "\n", + " Arguments:\n", + " mask_type: one of `\"V\"`, `\"A\"` or `\"B\".`\n", + " filters: Integer, the dimensionality of the output space (i.e. the number of output\n", + " filters in the convolution).\n", + " kernel_size: An integer or tuple/list of 2 integers, specifying the height and width\n", + " of the 2D convolution window.\n", + " Can be a single integer to specify the same value for all spatial dimensions.\n", + " strides: An integer or tuple/list of 2 integers, specifying the strides of the\n", + " convolution along the height and width.\n", + " Can be a single integer to specify the same value for all spatial dimensions.\n", + " Specifying any stride value != 1 is incompatible with specifying any\n", + " `dilation_rate` value != 1.\n", + " padding: one of `\"valid\"` or `\"same\"` (case-insensitive).\n", + " kernel_initializer: Initializer for the `kernel` weights matrix.\n", + " bias_initializer: Initializer for the bias vector.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " mask_type,\n", + " filters,\n", + " kernel_size,\n", + " strides=1,\n", + " padding='same',\n", + " kernel_initializer='glorot_uniform',\n", + " bias_initializer='zeros'):\n", + " super(MaskedConv2D, self).__init__()\n", + "\n", + " assert mask_type in {'A', 'B', 'V'}\n", + " self.mask_type = mask_type\n", + "\n", + " self.filters = filters\n", + "\n", + " if isinstance(kernel_size, int):\n", + " kernel_size = (kernel_size, kernel_size)\n", + " self.kernel_size = kernel_size\n", + "\n", + " self.strides = strides\n", + " self.padding = padding.upper()\n", + " self.kernel_initializer = initializers.get(kernel_initializer)\n", + " self.bias_initializer = initializers.get(bias_initializer)\n", + "\n", + " def build(self, input_shape):\n", + " kernel_h, kernel_w = self.kernel_size\n", + "\n", + " self.kernel = self.add_weight('kernel',\n", + " shape=(kernel_h,\n", + " kernel_w,\n", + " int(input_shape[-1]),\n", + " self.filters),\n", + " initializer=self.kernel_initializer,\n", + " trainable=True)\n", + "\n", + " self.bias = self.add_weight('bias',\n", + " shape=(self.filters,),\n", + " initializer=self.bias_initializer,\n", + " trainable=True)\n", + "\n", + " mask = np.ones(self.kernel.shape, dtype=np.float64)\n", + "\n", + " # Get centre of the filter for even or odd dimensions\n", + " if kernel_h % 2 != 0:\n", + " center_h = kernel_h // 2\n", + " else:\n", + " center_h = (kernel_h - 1) // 2\n", + "\n", + " if kernel_w % 2 != 0:\n", + " center_w = kernel_w // 2\n", + " else:\n", + " center_w = (kernel_w - 1) // 2\n", + "\n", + " if self.mask_type == 'V':\n", + " mask[center_h + 1:, :, :, :] = 0.\n", + " else:\n", + " mask[:center_h, :, :] = 0.\n", + " mask[center_h, center_w + (self.mask_type == 'B'):, :, :] = 0.\n", + " mask[center_h + 1:, :, :] = 0.\n", + "\n", + " self.mask = tf.constant(mask, dtype=tf.float64, name='mask')\n", + "\n", + " def call(self, input):\n", + " masked_kernel = tf.math.multiply(self.mask, self.kernel)\n", + " x = nn.conv2d(input,\n", + " masked_kernel,\n", + " strides=[1, self.strides, self.strides, 1],\n", + " padding=self.padding)\n", + " x = nn.bias_add(x, self.bias)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, we define th eblock implementation." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "K5topys_HW7-" + }, + "outputs": [], + "source": [ + "class GatedBlock(tf.keras.Model):\n", + " \"\"\" Gated block that compose Gated PixelCNN.\"\"\"\n", + "\n", + " def __init__(self, mask_type, filters, kernel_size):\n", + " super(GatedBlock, self).__init__(name='')\n", + "\n", + " self.mask_type = mask_type\n", + " self.vertical_conv = MaskedConv2D(mask_type='V',\n", + " filters=2 * filters,\n", + " kernel_size=kernel_size)\n", + "\n", + " self.horizontal_conv = MaskedConv2D(mask_type=mask_type,\n", + " filters=2 * filters,\n", + " kernel_size=(1, kernel_size))\n", + "\n", + " self.padding = keras.layers.ZeroPadding2D(padding=((1, 0), 0))\n", + " self.cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", + "\n", + " self.v_to_h_conv = keras.layers.Conv2D(filters=2 * filters, kernel_size=1)\n", + "\n", + " self.horizontal_output = keras.layers.Conv2D(filters=filters, kernel_size=1)\n", + "\n", + " def _gate(self, x):\n", + " tanh_preactivation, sigmoid_preactivation = tf.split(x, 2, axis=-1)\n", + " return tf.nn.tanh(tanh_preactivation) * tf.nn.sigmoid(sigmoid_preactivation)\n", + "\n", + " def call(self, input_tensor):\n", + " v = input_tensor[0]\n", + " h = input_tensor[1]\n", + "\n", + " vertical_preactivation = self.vertical_conv(v)\n", + "\n", + " # Shifting vertical stack feature map down before feed into horizontal stack to\n", + " # ensure causality\n", + " v_to_h = self.padding(vertical_preactivation)\n", + " v_to_h = self.cropping(v_to_h)\n", + " v_to_h = self.v_to_h_conv(v_to_h)\n", + "\n", + " horizontal_preactivation = self.horizontal_conv(h)\n", + "\n", + " v_out = self._gate(vertical_preactivation)\n", + "\n", + " horizontal_preactivation = horizontal_preactivation + v_to_h\n", + " h_activated = self._gate(horizontal_preactivation)\n", + " h_activated = self.horizontal_output(h_activated)\n", + "\n", + " if self.mask_type == 'A':\n", + " h_out = h_activated\n", + " elif self.mask_type == 'B':\n", + " h_out = h + h_activated\n", + "\n", + " return v_out, h_out" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to analyse grow the receptive field grows along the layers, we will start analysing 1 block." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "height = 10\n", + "width = 10\n", + "n_channel = 1\n", + "\n", + "data = tf.random.normal((1, height, width, n_channel))\n", + "\n", + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs])\n", + "model = tf.keras.Model(inputs=inputs, outputs=h)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "jxCLMYc-FxdJ" + }, + "outputs": [], + "source": [ + "def plot_receptive_field(model, data):\n", + " with tf.GradientTape() as tape:\n", + " tape.watch(data)\n", + " prediction = model(data)\n", + " loss = prediction[:,5,5,0]\n", + "\n", + " gradients = tape.gradient(loss, data)\n", + "\n", + " gradients = np.abs(gradients.numpy().squeeze())\n", + " gradients = (gradients > 0).astype('float64')\n", + " gradients[5, 5] = 0.5\n", + "\n", + " fig = plt.figure()\n", + " ax = fig.add_subplot(1, 1, 1)\n", + "\n", + " plt.xticks(np.arange(0, 10, step=1))\n", + " plt.yticks(np.arange(0, 10, step=1))\n", + " ax.xaxis.set_minor_locator(FixedLocator(np.arange(0.5, 10.5, step=1)))\n", + " ax.yaxis.set_minor_locator(FixedLocator(np.arange(0.5, 10.5, step=1)))\n", + " plt.grid(which=\"minor\")\n", + " plt.imshow(gradients, vmin=0, vmax=1)\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOMElEQVR4nO3da6xldX3G8e/TuTjMgICABBlaaKSkhoRLCUVRbEEIeMGkaVpINJG0Tl+oBcUarS+ML9q0qTH6ojGZANZELsEBEmsUoYoSk3Z0GMYyMGgREGZEB8MdLTd/fbH3NAOdy5q915pz1r/fT3Iy+5yz58kvc+bZe+119v7tVBWS2vFbCz2ApH5ZaqkxllpqjKWWGmOppcYsHSJ0eV5VK1jVa+YhRx3EE4883WumucNlmjtcJsB/8yzP13PZ1fcGKfUKVvGHOafXzD+7/AKu/5tv9Jpp7nCZ5g6XCbC+vrXb73n4LTXGUkuNsdRSYyy11BhLLTXGUkuN6VTqJOcn+VGS+5J8fOihJM1ur6VOsgT4Z+AC4A3AxUneMPRgkmbT5Z76dOC+qrq/qp4HrgPePexYkmbVpdRHAw/v9PnW6ddeJsmaJBuSbHiB5/qaT9I+6u1EWVWtrarTquq0Zbyqr1hJ+6hLqbcBx+z0+erp1yQtQl1K/QPg+CTHJVkOXAR8ddixJM1qr6/SqqoXk3wQ+CawBLiqqu4efDJJM+n00suq+jrw9YFnkdQDn1EmNcZSS42x1FJjLLXUGEstNSZ9vpdWkncB7zr80CPe//ef/IfecgEOXX0wj299stdMc4fLNHe4TIDLP/pRnqrHdrlNtNdS7/DqvKZ63yb6TwNtkDR3VLOOLXeoWdfXt3Zbag+/pcZYaqkxllpqjKWWGmOppcZYaqkxXRYPXpVke5LN+2MgSfPpck/9L8D5A88hqSd7LXVV3Q48th9mkdSD3t6fOskaYA3AClb2FStpH7lNVGqMZ7+lxlhqqTFdfqV1LfDvwAlJtib5i+HHkjSrLiuCL94fg0jqh4ffUmMstdQYSy01xlJLjbHUUmPcJmruqGYdW67bRPdgTBskx5Y7plnHlus2UUlzs9RSYyy11BhLLTXGUkuNsdRSY7q89PKYJLcluSfJ3Uku3R+DSZpNlx1lLwKXV9XGJAcBdyS5taruGXg2STPosk30karaOL38NLAFOHrowSTNZp+2iSY5FjgFWL+L77lNVFoEOp8oS3IgcANwWVU99crvu01UWhw6lTrJMiaFvrqqbhx2JEnz6HL2O8CVwJaq+uzwI0maR5d76jOB9wJnJ9k0/Xj7wHNJmlGXbaLfA3b5Ei9Ji4/PKJMaY6mlxlhqqTGWWmpMb+9Prf3j+JN+xTd/tqnXzNs3v7X3zKFztXtuEx1Z7lG/dxAHHvCLXjOf+fWRvWcOnfvIj5/uPddtonvgNtHhcj95y1s568TP95p5++ZLe88cOvfvzvtu77luE5W0KFlqqTGWWmqMpZYaY6mlxnR56eWKJN9P8sPp4sFP74/BJM2my5NPngPOrqpnpssSvpfkG1X1HwPPJmkGXV56WcAz00+XTT/6/+W2pF50XWe0JMkmYDtwa1XtcvFgkg1JNrzAc33PKamjTqWuqpeq6mRgNXB6khN3cR0XD0qLwD6d/a6qJ4DbgPOHGUfSvLqc/T4iySHTywcA5wL3Dj2YpNl0Oft9FPClJEuY3AhcX1VfG3YsSbPqcvb7P5m8K4ekEfAZZVJjLLXUGEstNcZSS42x1FJjXDw4slwXD7p4EFw82FSuiwddPAguHpT+X7HUUmMstdQYSy01xlJLjbHUUmM6l3q60ujOJL7sUlrE9uWe+lJgy1CDSOpH18WDq4F3AFcMO46keXW9p/4c8DHgN7u7gttEpcWhy46ydwLbq+qOPV3PbaLS4tDlnvpM4MIkDwLXAWcn+fKgU0ma2V5LXVWfqKrVVXUscBHw7ap6z+CTSZqJv6eWGtNlRfD/qqrvAN8ZZBJJvfCeWmqMpZYaY6mlxlhqqTGWWmqM20RHljvUNtEHfr3bZwDP7Mhly3ny/sd7zx3Tz8xtonswtq2fY9smesnd/T9f/yNHrearF63rPXdMPzO3iUqam6WWGmOppcZYaqkxllpqjKWWGtPpVVrTBQlPAy8BL1bVaUMOJWl2+/LSyz+uql8ONomkXnj4LTWma6kLuCXJHUnW7OoKbhOVFoeuh99vrqptSV4L3Jrk3qq6fecrVNVaYC1Mniba85ySOup0T11V26Z/bgduAk4fcihJs+uy93tVkoN2XAbOAzYPPZik2XQ5/D4SuCnJjutfU1U3DzqVpJnttdRVdT9w0n6YRVIP/JWW1BhLLTXGUkuNsdRSYyy11Bi3iZo7qlnHlus20T0Y0wbJseWOadax5bpNVNLcLLXUGEstNcZSS42x1FJjLLXUmE6lTnJIknVJ7k2yJckbhx5M0my6rjP6PHBzVf1pkuXAygFnkjSHvZY6ycHAWcD7AKrqeeD5YceSNKsuh9/HAY8CX0xyZ5IrpmuNXsZtotLi0KXUS4FTgS9U1SnAs8DHX3mlqlpbVadV1WnLeFXPY0rqqkuptwJbq2r99PN1TEouaRHaa6mr6ufAw0lOmH7pHOCeQaeSNLOuZ78/BFw9PfN9P3DJcCNJmkenUlfVJsB3upRGwGeUSY2x1FJjLLXUGEstNcZSS41xm6i5o5p1bLluE92DMW2QHFvumGYdW67bRCXNzVJLjbHUUmMstdQYSy01Zq+lTnJCkk07fTyV5LL9MZykfbfXV2lV1Y+AkwGSLAG2ATcNPJekGe3r4fc5wE+q6qdDDCNpfl2XJOxwEXDtrr6RZA2wBmCFG4SlBdP5nnq69eRC4Cu7+r6LB6XFYV8Ovy8ANlbVL4YaRtL89qXUF7ObQ29Ji0fX99JaBZwL3DjsOJLm1XXx4LPAYQPPIqkHPqNMaoyllhpjqaXGWGqpMZZaaoyLB80d1axjy3Xx4B6Madnc2HLHNOvYcl08KGlullpqjKWWGmOppcZYaqkxllpqTNeXXn44yd1JNie5NsmKoQeTNJsuK4KPBv4aOK2qTgSWMNlVJmkR6nr4vRQ4IMlSYCXws+FGkjSPvZa6qrYBnwEeAh4BnqyqW155vSRrkmxIsuEFnut/UkmddDn8PhR4N3Ac8DpgVZL3vPJ6bhOVFocuh99vAx6oqker6gUme8reNOxYkmbVpdQPAWckWZkkTN6lY8uwY0maVZfH1OuBdcBG4K7p31k78FySZtR1m+ingE8NPIukHviMMqkxllpqjKWWGmOppcZYaqkxbhM1d1Szji3XbaJ7MKYNkmPLHdOsY8t1m6ikuVlqqTGWWmqMpZYaY6mlxlhqqTFdt4leOt0keneSy4YeStLsuqwzOhF4P3A6cBLwziSvH3owSbPpck/9+8D6qvpVVb0IfBf4k2HHkjSrLqXeDLwlyWFJVgJvB4555ZXcJiotDnvdfFJVW5L8I3AL8CywCXhpF9dby3TN0avzmv6feyqpk04nyqrqyqr6g6o6C3gc+PGwY0maVacdZUleW1Xbk/w2k8fTZww7lqRZdSo1cEOSw4AXgA9U1RMDziRpDl23ib5l6EEk9cNnlEmNsdRSYyy11BhLLTXGUkuNGWSbKPDnwH91+CuHA7/sGH8w0HUto7njmnVsuYth1uOr6uBdfqeqFuwD2LAP111rbvfcMc06ttzFPuuYDr//1dzBcsc069hy9/usoyl1VQ3yj2PuuGYdW+5CzLrQpR7qzevNHdesY8td1LMO8g4dkhbOQt9TS+qZpZYas2ClTnJ+kh8luS/Jx3vKvCrJ9iSb+8ibZh6T5LYk90y3qV7aU+6KJN9P8sNp7qf7yN0pf0mSO5N8rcfMB5PclWRTkg09ZR6SZF2Se5NsSfLGHjJPmM644+OpvrbgJvnw9Oe1Ocm1SVb0lNvfxt6uvxfr8wNYAvwE+F1gOfBD4A095J4FnAps7nHWo4BTp5cPYrL1pY9ZAxw4vbwMWA+c0ePcHwGuAb7WY+aDwOE9/1/4EvCX08vLgUMG+L/2c+B3esg6GngAOGD6+fXA+3rIPZHJLsCVTF4O/W/A62fNW6h76tOB+6rq/qp6HrgOePe8oVV1O/DYvDmvyHykqjZOLz8NbGHyw503t6rqmemny6YfvZy1TLIaeAdwRR95Q0lyMJMb4isBqur56n8BxznAT6rqpz3lLQUOSLKUSQl/1kNmrxt7F6rURwMP7/T5VnooytCSHAucwuRetY+8JUk2AduBW6uql1zgc8DHgN/0lLdDAbckuSPJmh7yjgMeBb44fahwRZJVPeTu7CLg2j6Cqmob8BngIeAR4MmquqWH6E4be7vyRFlHSQ4EbgAuq6qn+sisqpeq6mRgNXD69I0T5pLkncD2qrpj7gH/rzdX1anABcAHkpw1Z95SJg+XvlBVpzDZVtvL+RWAJMuBC4Gv9JR3KJMjyuOA1wGrkrxn3tyq2gLs2Nh7M7vZ2NvVQpV6Gy+/JVo9/dqilGQZk0JfXVU39p0/PeS8DTi/h7gzgQuTPMjkYc3ZSb7cQ+6OeyqqajtwE5OHUfPYCmzd6QhlHZOS9+UCYGNV/aKnvLcBD1TVo1X1AnAj8KY+gqvHjb0LVeofAMcnOW56a3oR8NUFmmWPkoTJY74tVfXZHnOPSHLI9PIBwLnAvfPmVtUnqmp1VR3L5N/121U1971JklVJDtpxGTiPyWHjPLP+HHg4yQnTL50D3DPXoC93MT0dek89BJyRZOX0/8U5TM6xzC3Ja6d/7tjYe82sWV23ifaqql5M8kHgm0zOTl5VVXfPm5vkWuCPgMOTbAU+VVVXzhl7JvBe4K7p41+Av62qr8+ZexTwpSRLmNy4Xl9Vvf36aQBHAjdN/i+zFLimqm7uIfdDwNXTG/f7gUt6yNxxw3Mu8Fd95AFU1fok64CNwIvAnfT3lNHeNvb6NFGpMZ4okxpjqaXGWGqpMZZaaoyllhpjqaXGWGqpMf8DsdB7WCt2dpcAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_receptive_field(model, data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "I2APaCzDGeqP" + }, + "source": [ + "Excellent! Like we expected the block considered all the previous blocks in the same row of the analyssed pixel, and the two rows over it.\n", + "\n", + "Note that this receptive field is different from the original PixelCNN. In the original PixelCNN only one row over the analysed pixel influenced in its prediction (when using one masked convolution). In the Gated PixelCNN, the authors used a vertical stack with effective area of 2x3 per vertical convolution. This is not a problem, since the considered pixels still being the ones in past positions. We believe the main coice for this format is to implement an efficient way to apply the masked convolutions without using masking (which we will discuss in future posts).\n", + "\n", + "For the next step, we wll verify a model with 2, 3, 4, and 5 layers" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOFklEQVR4nO3dbczddX3H8fdnvbG0VIrcpVI2WHRkhkRgDUNRtoEY8AaTZdlKgolks3ugDhRnZD4gPlm2zBh9sJg0gDMRMFghUYJYpjhislVLKbOlsHEntALFyf0cd3734JwuhfXm33P+/17X+eX9Sq70XNd1+sk37fU553/+1znfk6pCUjt+Y64HkNQvSy01xlJLjbHUUmMstdSYhUOELs4bagnLes1csXI5Tz/2XK+Z5g6Xae5wmQD/wwu8VC9mb98bpNRLWMbv59xeM//08gu44a+/22umucNlmjtcJsDG+v4+v+fht9QYSy01xlJLjbHUUmMstdQYSy01plOpk5yf5L4k9yf57NBDSZrcAUudZAHwj8AFwNuAi5K8bejBJE2myz31GcD9VfVgVb0EfAP40LBjSZpUl1IfDzy6x+c7xl97jSRrk2xKsullXuxrPkkHqbcTZVW1rqpWV9XqRbyhr1hJB6lLqXcCJ+zx+arx1yTNQ11K/RPgrUlOSrIYWAN8e9ixJE3qgK/SqqpXknwc+B6wALimqrYNPpmkiXR66WVV3QLcMvAsknrgM8qkxlhqqTGWWmqMpZYaY6mlxqTP99JK8kHgg0cfecxH//Zzf9dbLsCRq47gqR3P9Jpp7nCZ5g6XCXD5pz/Ns/XLvW4T7bXUu70xb6ret4n+w0AbJM2dqVlnLXeoWTfW9/dZag+/pcZYaqkxllpqjKWWGmOppcZYaqkxXRYPXpNkV5Kth2IgSdPpck/9T8D5A88hqScHLHVV3QH88hDMIqkHvb0/dZK1wFqAJSztK1bSQXKbqNQYz35LjbHUUmO6/ErreuBfgZOT7Ejy58OPJWlSXVYEX3QoBpHUDw+/pcZYaqkxllpqjKWWGmOppca4TXSg3JW/s5zDD3ui99znf3Vc77lDZM5i7i/+a6XbRPfFbaLwuQ1/wNmnfLn33Du2Xtp77hCZs5h79bVXuE1U0vxjqaXGWGqpMZZaaoyllhpjqaXGdHnp5QlJbk9yT5JtSS49FINJmkyXHWWvAJdX1eYky4E7k9xWVfcMPJukCXTZJvpYVW0eX34O2A4cP/RgkiZzUNtEk5wInAZs3Mv33CYqzQOdT5QlORz4FnBZVT37+u+7TVSaHzqVOskiRoW+tqpuHHYkSdPocvY7wNXA9qr64vAjSZpGl3vqs4APA+ck2TL+eN/Ac0maUJdtoj8C9voSL0nzj88okxpjqaXGWGqpMZZaaoyllhrjNlG3ic7c1k+3ibpNdE5y3SY6e7luE5U0L1lqqTGWWmqMpZYaY6mlxnR56eWSJD9Ocvd48eDnD8VgkibTZZ3Ri8A5VfX8eFnCj5J8t6r+beDZJE2gy0svC3h+/Omi8Uf/v9yW1Iuu64wWJNkC7AJuq6q9Lh5MsinJppd5se85JXXUqdRV9WpVnQqsAs5IcsperuPiQWkeOKiz31X1NHA7cP4w40iaVpez38ckWTG+fBhwHnDv0INJmkyXs98rga8lWcDoRuCGqrp52LEkTarL2e9/Z/SuHJJmgM8okxpjqaXGWGqpMZZaaoyllhrj4kEXD87cgkAXD7p4cE5yXTw4e7kuHpQ0L1lqqTGWWmqMpZYaY6mlxlhqqTGdSz1eaXRXEl92Kc1jB3NPfSmwfahBJPWj6+LBVcD7gauGHUfStLreU38J+Azw631dwW2i0vzQZUfZB4BdVXXn/q7nNlFpfuhyT30WcGGSh4FvAOck+fqgU0ma2AFLXVVXVNWqqjoRWAP8oKouHnwySRPx99RSY7qsCP4/VfVD4IeDTCKpF95TS42x1FJjLLXUGEstNcZSS41xm6jbRHn+V8fx0K/2+QzgiR23aDHPPPhU77mz9LPgNtH9cJvoyFDbRC/Z1v/z9T+1chXfXrO+99xZ+llwm6ikqVlqqTGWWmqMpZYaY6mlxlhqqTGdXqU1XpDwHPAq8EpVrR5yKEmTO5iXXv5RVf1isEkk9cLDb6kxXUtdwIYkdyZZu7cruE1Umh+6Hn6/q6p2JjkWuC3JvVV1x55XqKp1wDoYPU205zklddTpnrqqdo7/3AXcBJwx5FCSJtdl7/eyJMt3XwbeC2wdejBJk+ly+H0ccFOS3de/rqpuHXQqSRM7YKmr6kHg7YdgFkk98FdaUmMstdQYSy01xlJLjbHUUmPcJmruTM06a7luE92PWdogOWu5szTrrOW6TVTS1Cy11BhLLTXGUkuNsdRSYyy11JhOpU6yIsn6JPcm2Z7kHUMPJmkyXdcZfRm4tar+JMliYOmAM0mawgFLneQI4GzgIwBV9RLw0rBjSZpUl8Pvk4Anga8muSvJVeO1Rq/hNlFpfuhS6oXA6cBXquo04AXgs6+/UlWtq6rVVbV6EW/oeUxJXXUp9Q5gR1VtHH++nlHJJc1DByx1VT0OPJrk5PGXzgXuGXQqSRPrevb7E8C14zPfDwKXDDeSpGl0KnVVbQF8p0tpBviMMqkxllpqjKWWGmOppcZYaqkxbhM1d6ZmnbVct4nuxyxtkJy13FmaddZy3SYqaWqWWmqMpZYaY6mlxlhqqTEHLHWSk5Ns2ePj2SSXHYrhJB28A75Kq6ruA04FSLIA2AncNPBckiZ0sIff5wIPVNXPhhhG0vS6LknYbQ1w/d6+kWQtsBZgiRuEpTnT+Z56vPXkQuCbe/u+iwel+eFgDr8vADZX1RNDDSNpegdT6ovYx6G3pPmj63tpLQPOA24cdhxJ0+q6ePAF4KiBZ5HUA59RJjXGUkuNsdRSYyy11BhLLTXGxYPmztSss5br4sH9mKVlc7OWO0uzzlquiwclTc1SS42x1FJjLLXUGEstNcZSS43p+tLLTybZlmRrkuuTLBl6MEmT6bIi+Hjgr4DVVXUKsIDRrjJJ81DXw++FwGFJFgJLgZ8PN5KkaRyw1FW1E/gC8AjwGPBMVW14/fWSrE2yKcmml3mx/0klddLl8PtI4EPAScCbgWVJLn799dwmKs0PXQ6/3wM8VFVPVtXLjPaUvXPYsSRNqkupHwHOTLI0SRi9S8f2YceSNKkuj6k3AuuBzcBPx39n3cBzSZpQ122iVwJXDjyLpB74jDKpMZZaaoyllhpjqaXGWGqpMW4TNXemZp21XLeJ7scsbZCctdxZmnXWct0mKmlqllpqjKWWGmOppcZYaqkxllpqTNdtopeON4luS3LZ0ENJmlyXdUanAB8FzgDeDnwgyVuGHkzSZLrcU/8usLGq/ruqXgH+BfjjYceSNKkupd4KvDvJUUmWAu8DTnj9ldwmKs0PB9x8UlXbk/w9sAF4AdgCvLqX661jvObojXlT/889ldRJpxNlVXV1Vf1eVZ0NPAX8x7BjSZpUpx1lSY6tql1JfpPR4+kzhx1L0qQ6lRr4VpKjgJeBj1XV0wPOJGkKXbeJvnvoQST1w2eUSY2x1FJjLLXUGEstNcZSS40ZZJso8GfAf3b4K0cDv+gYfwTQdS2jubM166zlzodZ31pVR+z1O1U1Zx/ApoO47jpzu+fO0qyzljvfZ52lw+/vmDtY7izNOmu5h3zWmSl1VQ3yj2PubM06a7lzMetcl3qoN683d7ZmnbXceT3rIO/QIWnuzPU9taSeWWqpMXNW6iTnJ7kvyf1JPttT5jVJdiXZ2kfeOPOEJLcnuWe8TfXSnnKXJPlxkrvHuZ/vI3eP/AVJ7kpyc4+ZDyf5aZItSTb1lLkiyfok9ybZnuQdPWSePJ5x98ezfW3BTfLJ8f/X1iTXJ1nSU25/G3u7/l6szw9gAfAA8NvAYuBu4G095J4NnA5s7XHWlcDp48vLGW196WPWAIePLy8CNgJn9jj3p4DrgJt7zHwYOLrnn4WvAX8xvrwYWDHAz9rjwG/1kHU88BBw2PjzG4CP9JB7CqNdgEsZvRz6n4G3TJo3V/fUZwD3V9WDVfUS8A3gQ9OGVtUdwC+nzXld5mNVtXl8+TlgO6P/3Glzq6qeH3+6aPzRy1nLJKuA9wNX9ZE3lCRHMLohvhqgql6q/hdwnAs8UFU/6ylvIXBYkoWMSvjzHjJ73dg7V6U+Hnh0j8930ENRhpbkROA0RveqfeQtSLIF2AXcVlW95AJfAj4D/LqnvN0K2JDkziRre8g7CXgS+Or4ocJVSZb1kLunNcD1fQRV1U7gC8AjwGPAM1W1oYfoTht7u/JEWUdJDge+BVxWVc/2kVlVr1bVqcAq4IzxGydMJckHgF1VdefUA/5/76qq04ELgI8lOXvKvIWMHi59papOY7SttpfzKwBJFgMXAt/sKe9IRkeUJwFvBpYluXja3KraDuze2Hsr+9jY29VclXonr70lWjX+2ryUZBGjQl9bVTf2nT8+5LwdOL+HuLOAC5M8zOhhzTlJvt5D7u57KqpqF3ATo4dR09gB7NjjCGU9o5L35QJgc1U90VPee4CHqurJqnoZuBF4Zx/B1ePG3rkq9U+AtyY5aXxrugb49hzNsl9Jwugx3/aq+mKPucckWTG+fBhwHnDvtLlVdUVVraqqExn9u/6gqqa+N0myLMny3ZeB9zI6bJxm1seBR5OcPP7SucA9Uw36WhfR06H32CPAmUmWjn8uzmV0jmVqSY4d/7l7Y+91k2Z13Sbaq6p6JcnHge8xOjt5TVVtmzY3yfXAHwJHJ9kBXFlVV08ZexbwYeCn48e/AH9TVbdMmbsS+FqSBYxuXG+oqt5+/TSA44CbRj/LLASuq6pbe8j9BHDt+Mb9QeCSHjJ33/CcB/xlH3kAVbUxyXpgM/AKcBf9PWW0t429Pk1UaownyqTGWGqpMZZaaoyllhpjqaXGWGqpMZZaasz/AuAwbUDSVViDAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "model = tf.keras.Model(inputs=inputs, outputs=h)\n", + "\n", + "plot_receptive_field(model, data)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOMUlEQVR4nO3dfcyddX3H8fdnfbC0ICAoQcoGi4aMkAiMMBQlGwgBHzBZlg0yTCTO7g91xYcZ3f4w/rGnaIz+sZg0gCORh2CFxBFE2ESJyVYtpcxC0fEktFaLUXka48nv/jinS2GlvXrOdfW+r9/er+ROz33fp598096fc65z3ed8T6oKSe34jYUeQFK/LLXUGEstNcZSS42x1FJjlg4RujyvqhWs6jXzsKMP4Vc7nuw109zhMs0dLhPgv3ma5+rZ7Ol7g5R6Bav4vZzTa+Yff+wCrv/Lb/Saae5wmeYOlwmwof71Fb/n4bfUGEstNcZSS42x1FJjLLXUGEstNaZTqZOcn+SHSe5P8smhh5I0u32WOskS4B+BC4ATgYuTnDj0YJJm0+We+nTg/qp6sKqeA64D3jPsWJJm1aXUxwCP7vb5tunXXiLJmiQbk2x8nmf7mk/SfurtRFlVrauq06rqtGW8qq9YSfupS6m3A8fu9vnq6dckLUJdSv194I1Jjk+yHLgI+PqwY0ma1T5fpVVVLyT5EPBNYAlwZVXdM/hkkmbS6aWXVXUzcPPAs0jqgc8okxpjqaXGWGqpMZZaaoyllhqTPt9LK8m7gXcfefhrP/C3f/33veUCHL76UH657fFeM3flHnnEjt5zn3rmKA4+6GejyB3TrGPLfeqZo9jxo/63iX7s4x/nifrFHreJ9lrqXV6d11Tv20Q/O9AGyc9ewPv/9O96z71jy1rOOumLo8gd06xjy71jy1r+5rzv9JoJk22ir1RqD7+lxlhqqTGWWmqMpZYaY6mlxlhqqTFdFg9emWRnki0HYiBJ8+lyT/1PwPkDzyGpJ/ssdVXdAfziAMwiqQe9vT91kjXAGoAVrOwrVtJ+cpuo1BjPfkuNsdRSY7r8Suta4N+AE5JsS/L+4ceSNKsuK4IvPhCDSOqHh99SYyy11BhLLTXGUkuNsdRSY9wm6jbRUc06tly3ie6F20SHyx3TrGPLdZuopLlZaqkxllpqjKWWGmOppcZYaqkxXV56eWyS25Pcm+SeJGsPxGCSZtNlR9kLwMeqalOSQ4A7k9xWVfcOPJukGXTZJrqjqjZNLz8JbAWOGXowSbPZr22iSY4DTgE27OF7bhOVFoHOJ8qSHAx8Dbisqp54+ffdJiotDp1KnWQZk0JfXVU3DDuSpHl0Ofsd4Apga1V9fviRJM2jyz31mcB7gbOTbJ5+vGPguSTNqMs20e8Ce3yJl6TFx2eUSY2x1FJjLLXUGEstNcZSS41xm6jbREc169hy3Sa6F24THS53TLOOLddtopLmZqmlxlhqqTGWWmqMpZYa0+WllyuSfC/J3dPFg585EINJmk2XdUbPAmdX1VPTZQnfTfKNqvr3gWeTNIMuL70s4Knpp8umH/3/cltSL7quM1qSZDOwE7itqva4eDDJxiQbn+fZvueU1FGnUlfVi1V1MrAaOD3JSXu4josHpUVgv85+V9WvgNuB84cZR9K8upz9fm2Sw6aXDwLOBe4bejBJs+ly9vto4KokS5jcCFxfVTcNO5akWXU5+/0fTN6VQ9II+IwyqTGWWmqMpZYaY6mlxlhqqTEuHnTx4KhmHVuuiwf3wsWDw+WOadax5bp4UNLcLLXUGEstNcZSS42x1FJjLLXUmM6lnq40uiuJL7uUFrH9uadeC2wdahBJ/ei6eHA18E7g8mHHkTSvrvfUXwA+Afz6la7gNlFpceiyo+xdwM6qunNv13ObqLQ4dLmnPhO4MMnDwHXA2Um+MuhUkma2z1JX1aeqanVVHQdcBHyrqi4ZfDJJM/H31FJjuqwI/l9V9W3g24NMIqkX3lNLjbHUUmMstdQYSy01xlJLjXGbqNtEeeqZo3jomVd8BvDMjlq2nMcf/GXvuUP+LPSdO9SsbhPdR67bRNdy6T39P1//o0ev5usXre89d8ifhb5zh5rVbaLS/yOWWmqMpZYaY6mlxlhqqTGWWmpMp1dpTRckPAm8CLxQVacNOZSk2e3PSy//oKp+Ptgkknrh4bfUmK6lLuDWJHcmWbOnK7hNVFocuh5+v7Wqtid5HXBbkvuq6o7dr1BV64B1MHmaaM9zSuqo0z11VW2f/rkTuBE4fcihJM2uy97vVUkO2XUZOA/YMvRgkmbT5fD7KODGJLuuf01V3TLoVJJmts9SV9WDwJsOwCySeuCvtKTGWGqpMZZaaoyllhpjqaXGuE3U3FHNOrZct4nuxZg2SI4td0yzji3XbaKS5mappcZYaqkxllpqjKWWGmOppcZ0KnWSw5KsT3Jfkq1J3jz0YJJm03Wd0ReBW6rqj5IsB1YOOJOkOeyz1EkOBc4C3gdQVc8Bzw07lqRZdTn8Ph54DPhykruSXD5da/QSbhOVFocupV4KnAp8qapOAZ4GPvnyK1XVuqo6rapOW8areh5TUlddSr0N2FZVG6afr2dSckmL0D5LXVU/BR5NcsL0S+cA9w46laSZdT37/WHg6umZ7weBS4cbSdI8OpW6qjYDvtOlNAI+o0xqjKWWGmOppcZYaqkxllpqjNtEzR3VrGPLdZvoXoxpg+TYcsc069hy3SYqaW6WWmqMpZYaY6mlxlhqqTH7LHWSE5Js3u3jiSSXHYjhJO2/fb5Kq6p+CJwMkGQJsB24ceC5JM1ofw+/zwEeqKofDzGMpPl1XZKwy0XAtXv6RpI1wBqAFW4QlhZM53vq6daTC4Gv7un7Lh6UFof9Ofy+ANhUVT8bahhJ89ufUl/MKxx6S1o8ur6X1irgXOCGYceRNK+uiwefBo4YeBZJPfAZZVJjLLXUGEstNcZSS42x1FJjXDxo7qhmHVuuiwf3YkzL5saWO6ZZx5br4kFJc7PUUmMstdQYSy01xlJLjbHUUmO6vvTyI0nuSbIlybVJVgw9mKTZdFkRfAzwF8BpVXUSsITJrjJJi1DXw++lwEFJlgIrgZ8MN5Kkeeyz1FW1Hfgc8AiwA3i8qm59+fWSrEmyMcnG53m2/0klddLl8Ptw4D3A8cDrgVVJLnn59dwmKi0OXQ6/3w48VFWPVdXzTPaUvWXYsSTNqkupHwHOSLIySZi8S8fWYceSNKsuj6k3AOuBTcAPpn9n3cBzSZpR122inwY+PfAsknrgM8qkxlhqqTGWWmqMpZYaY6mlxrhN1NxRzTq2XLeJ7sWYNkiOLXdMs44t122ikuZmqaXGWGqpMZZaaoyllhpjqaXGdN0muna6SfSeJJcNPZSk2XVZZ3QS8AHgdOBNwLuSvGHowSTNpss99e8AG6rqv6rqBeA7wB8OO5akWXUp9RbgbUmOSLISeAdw7Muv5DZRaXHY5+aTqtqa5B+AW4Gngc3Ai3u43jqma45endf0/9xTSZ10OlFWVVdU1e9W1VnAL4EfDTuWpFl12lGW5HVVtTPJbzJ5PH3GsGNJmlWnUgNfS3IE8Dzwwar61YAzSZpD122ibxt6EEn98BllUmMstdQYSy01xlJLjbHUUmMG2SYK/Anwnx3+ypHAzzvGHwp0Xcto7rhmHVvuYpj1jVV16B6/U1UL9gFs3I/rrjO3e+6YZh1b7mKfdUyH3/9s7mC5Y5p1bLkHfNbRlLqqBvnHMXdcs44tdyFmXehSD/Xm9eaOa9ax5S7qWQd5hw5JC2eh76kl9cxSS41ZsFInOT/JD5Pcn+STPWVemWRnki195E0zj01ye5J7p9tU1/aUuyLJ95LcPc39TB+5u+UvSXJXkpt6zHw4yQ+SbE6ysafMw5KsT3Jfkq1J3txD5gnTGXd9PNHXFtwkH5n+f21Jcm2SFT3l9rext+vvxfr8AJYADwC/DSwH7gZO7CH3LOBUYEuPsx4NnDq9fAiTrS99zBrg4OnlZcAG4Iwe5/4ocA1wU4+ZDwNH9vyzcBXwZ9PLy4HDBvhZ+ynwWz1kHQM8BBw0/fx64H095J7EZBfgSiYvh/4X4A2z5i3UPfXpwP1V9WBVPQdcB7xn3tCqugP4xbw5L8vcUVWbppefBLYy+c+dN7eq6qnpp8umH72ctUyyGngncHkfeUNJciiTG+IrAKrquep/Acc5wANV9eOe8pYCByVZyqSEP+khs9eNvQtV6mOAR3f7fBs9FGVoSY4DTmFyr9pH3pIkm4GdwG1V1Usu8AXgE8Cve8rbpYBbk9yZZE0PeccDjwFfnj5UuDzJqh5yd3cRcG0fQVW1Hfgc8AiwA3i8qm7tIbrTxt6uPFHWUZKDga8Bl1XVE31kVtWLVXUysBo4ffrGCXNJ8i5gZ1XdOfeA/9dbq+pU4ALgg0nOmjNvKZOHS1+qqlOYbKvt5fwKQJLlwIXAV3vKO5zJEeXxwOuBVUkumTe3qrYCuzb23sIrbOztaqFKvZ2X3hKtnn5tUUqyjEmhr66qG/rOnx5y3g6c30PcmcCFSR5m8rDm7CRf6SF31z0VVbUTuJHJw6h5bAO27XaEsp5JyftyAbCpqn7WU97bgYeq6rGqeh64AXhLH8HV48behSr194E3Jjl+emt6EfD1BZplr5KEyWO+rVX1+R5zX5vksOnlg4Bzgfvmza2qT1XV6qo6jsm/67eqau57kySrkhyy6zJwHpPDxnlm/SnwaJITpl86B7h3rkFf6mJ6OvSeegQ4I8nK6c/FOUzOscwtyeumf+7a2HvNrFldt4n2qqpeSPIh4JtMzk5eWVX3zJub5Frg94Ejk2wDPl1VV8wZeybwXuAH08e/AH9VVTfPmXs0cFWSJUxuXK+vqt5+/TSAo4AbJz/LLAWuqapbesj9MHD19Mb9QeDSHjJ33fCcC/x5H3kAVbUhyXpgE/ACcBf9PWW0t429Pk1UaownyqTGWGqpMZZaaoyllhpjqaXGWGqpMZZaasz/AKlGbUA+kPk+AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "model = tf.keras.Model(inputs=inputs, outputs=h)\n", + "\n", + "plot_receptive_field(model, data)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOVElEQVR4nO3db4yV5Z3G8e+1Azj8K1BRQhl3ZVND1pBUXcLa2pJdqV1orSbNpoupTWrasi/aLra6jd19Yfpm/6RN077YNCFq16SKsSiJNdbitnZJk11aQKwDg11FKlAUjVbEtfKnv31xntkgOzD3nHPfzHnuXp9kwpmZw5VfYK5znvPMOb+jiMDM6vEHkz2AmeXlUptVxqU2q4xLbVYZl9qsMlNKhE7TeTHIzKyZcxfO5sILXsyaCXD0t/OZNfjy73Vum2ZtW26pWfftP87Lr5zUWN8rUupBZvJnWpk18+O3rObTn/inrJkAW4bXsWLpt36vc9s0a9tyS826/C/3n/F7Pvw2q4xLbVYZl9qsMi61WWVcarPKuNRmlUkqtaRVkp6W9Iyk20oPZWbdG7fUkgaAfwVWA5cCN0i6tPRgZtadlHvq5cAzEbE3Io4B9wHXlx3LzLqVUupFwKlPXznQfO1tJK2VtE3StuO8lWs+M5ugbCfKImJ9RCyLiGVTOS9XrJlNUEqpDwIXnfL5UPM1M+tDKaX+OXCJpMWSpgFrgIfKjmVm3Rr3VVoRcULS54EfAgPAXRGxq/hkZtaVpJdeRsQjwCOFZzGzDPyMMrPKuNRmlXGpzSrjUptVxqU2q4xyvpeWpI8CH50/74LP/uM//HO2XIB5Q3OYf/6hrJkAR99cwKzpBbaUtii3TbO2LbfUrLfecivbnvztmNtEs5Z61Dv0zsi+TfRr3iZaKrdNs7Ytt+Q20TOV2offZpVxqc0q41KbVcalNquMS21WGZfarDIpiwfvknRY0vC5GMjMepNyT/1vwKrCc5hZJuOWOiK2AK+cg1nMLINs708taS2wFmCQGblizWyCvE3UrDI++21WGZfarDIpv9LaAPwnsETSAUmfLj+WmXUrZUXwDediEDPLw4ffZpVxqc0q41KbVcalNquMS21WGW8TbdFmylK5bZq1bbneJnoW3iZaLrdNs7Yt19tEzaxnLrVZZVxqs8q41GaVcanNKuNSm1Um5aWXF0l6XNJuSbskrTsXg5lZd1J2lJ0AbomIHZJmA9slPRYRuwvPZmZdSNkmeigidjSXXwdGgEWlBzOz7kxom6iki4HLga1jfM/bRM36QPKJMkmzgAeAmyPiyOnf9zZRs/6QVGpJU+kU+p6IeLDsSGbWi5Sz3wLuBEYi4hvlRzKzXqTcU18FfBK4WtLO5uPDhecysy6lbBP9KTDmS7zMrP/4GWVmlXGpzSrjUptVxqU2q4xLbVYZbxNt0WbKUrltmrVtud4mehbeJlout02zti3X20TNrGcutVllXGqzyrjUZpVxqc0qk/LSy0FJP5P0ZLN48KvnYjAz607KOqO3gKsj4mizLOGnkn4QEf9VeDYz60LKSy8DONp8OrX5yP/LbTPLInWd0YCkncBh4LGIGHPxoKRtkrYd563cc5pZoqRSR8TJiLgMGAKWS1o6xnW8eNCsD0zo7HdE/AZ4HFhVZhwz61XK2e8LJM1tLk8HrgH2lB7MzLqTcvZ7IXC3pAE6NwL3R8TDZccys26lnP3+BZ135TCzFvAzyswq41KbVcalNquMS21WGZfarDJePNiiJXalcts0a9tyvXjwLLx4sFxum2ZtW64XD5pZz1xqs8q41GaVcanNKuNSm1XGpTarTHKpm5VGT0jyyy7N+thE7qnXASOlBjGzPFIXDw4BHwHuKDuOmfUq9Z76m8CXgd+d6QreJmrWH1J2lF0LHI6I7We7nreJmvWHlHvqq4DrJO0D7gOulvTdolOZWdfGLXVEfCUihiLiYmAN8OOIuLH4ZGbWFf+e2qwyKSuC/09E/AT4SZFJzCwL31ObVcalNquMS21WGZfarDIutVllvE20RZspS+UefXMBz715xmcAd23B1Gm8tvfV7Lnzhubw6oHXWpFbatZbbr2VI/GKt4mOpU2bKUvlbhlex0278j9f/0sLh3hozcbsuR//2mru/7sftCK31Kxb40dnLLUPv80q41KbVcalNquMS21WGZfarDIutVllkl6l1SxIeB04CZyIiGUlhzKz7k3kpZd/EREvF5vEzLLw4bdZZVJLHcBmSdslrR3rCt4matYfUg+/3x8RByVdCDwmaU9EbDn1ChGxHlgPnaeJZp7TzBIl3VNHxMHmz8PAJmB5yaHMrHspe79nSpo9ehn4EDBcejAz607K4fcCYJOk0evfGxGPFp3KzLo2bqkjYi/wnnMwi5ll4F9pmVXGpTarjEttVhmX2qwyLrVZZVq1TbQtGyTbltumWduW622iZ9GmDZJty23TrG3L9TZRM+uZS21WGZfarDIutVllXGqzyrjUZpVJKrWkuZI2StojaUTSe0sPZmbdSV1n9C3g0Yj4K0nTgBkFZzKzHoxbaklzgBXApwAi4hhwrOxYZtatlMPvxcBLwHckPSHpjmat0dt4m6hZf0gp9RTgCuDbEXE58AZw2+lXioj1EbEsIpZN5bzMY5pZqpRSHwAORMTW5vONdEpuZn1o3FJHxAvAfklLmi+tBHYXncrMupZ69vsLwD3Nme+9wE3lRjKzXiSVOiJ2An6nS7MW8DPKzCrjUptVxqU2q4xLbVYZl9qsMt4m6txWzdq2XG8TPYs2bZBsW26bZm1brreJmlnPXGqzyrjUZpVxqc0q41KbVWbcUktaImnnKR9HJN18LoYzs4kb91VaEfE0cBmApAHgILCp8Fxm1qWJHn6vBJ6NiF+VGMbMepe6JGHUGmDDWN+QtBZYCzDoDcJmkyb5nrrZenId8L2xvu/Fg2b9YSKH36uBHRHxYqlhzKx3Eyn1DZzh0NvM+kfqe2nNBK4BHiw7jpn1KnXx4BvA+YVnMbMM/Iwys8q41GaVcanNKuNSm1XGpTarjBcPOrdVs7Yt14sHz6JNy+baltumWduW68WDZtYzl9qsMi61WWVcarPKuNRmlXGpzSqT+tLLL0raJWlY0gZJg6UHM7PupKwIXgT8LbAsIpYCA3R2lZlZH0o9/J4CTJc0BZgB/LrcSGbWi3FLHREHga8DzwOHgNciYvPp15O0VtI2SduO81b+Sc0sScrh9zzgemAx8C5gpqQbT7+et4ma9YeUw+8PAs9FxEsRcZzOnrL3lR3LzLqVUurngSslzZAkOu/SMVJ2LDPrVspj6q3ARmAH8FTzd9YXnsvMupS6TfR24PbCs5hZBn5GmVllXGqzyrjUZpVxqc0q41KbVcbbRJ3bqlnbluttomfRpg2Sbctt06xty/U2UTPrmUttVhmX2qwyLrVZZVxqs8q41GaVSd0muq7ZJLpL0s2lhzKz7qWsM1oKfBZYDrwHuFbSu0sPZmbdSbmn/hNga0T8T0ScAP4D+FjZscysWymlHgY+IOl8STOADwMXnX4lbxM16w/jbj6JiBFJ/wJsBt4AdgInx7jeepo1R+/QO/M/99TMkiSdKIuIOyPiTyNiBfAq8MuyY5lZt5J2lEm6MCIOS/pDOo+nryw7lpl1K6nUwAOSzgeOA5+LiN8UnMnMepC6TfQDpQcxszz8jDKzyrjUZpVxqc0q41KbVcalNqtMkW2iwF8D/53wV+YDLyfGzwFS1zI6t12zti23H2a9JCLmjPmdiJi0D2DbBK673rnpuW2atW25/T5rmw6/v+/cYrltmrVtued81taUOiKK/OM4t12zti13Mmad7FKXevN657Zr1rbl9vWsRd6hw8wmz2TfU5tZZi61WWUmrdSSVkl6WtIzkm7LlHmXpMOShnPkNZkXSXpc0u5mm+q6TLmDkn4m6ckm96s5ck/JH5D0hKSHM2buk/SUpJ2StmXKnCtpo6Q9kkYkvTdD5pJmxtGPI7m24Er6YvP/NSxpg6TBTLn5Nvam/l4s5wcwADwL/DEwDXgSuDRD7grgCmA446wLgSuay7PpbH3JMauAWc3lqcBW4MqMc38JuBd4OGPmPmB+5p+Fu4HPNJenAXML/Ky9APxRhqxFwHPA9Obz+4FPZchdSmcX4Aw6L4f+d+Dd3eZN1j31cuCZiNgbEceA+4Drew2NiC3AK73mnJZ5KCJ2NJdfB0bo/Of2mhsRcbT5dGrzkeWspaQh4CPAHTnySpE0h84N8Z0AEXEs8i/gWAk8GxG/ypQ3BZguaQqdEv46Q2bWjb2TVepFwP5TPj9AhqKUJuli4HI696o58gYk7QQOA49FRJZc4JvAl4HfZcobFcBmSdslrc2Qtxh4CfhO81DhDkkzM+Seag2wIUdQRBwEvg48DxwCXouIzRmikzb2pvKJskSSZgEPADdHxJEcmRFxMiIuA4aA5c0bJ/RE0rXA4YjY3vOA/9/7I+IKYDXwOUkresybQufh0rcj4nI622qznF8BkDQNuA74Xqa8eXSOKBcD7wJmSrqx19yIGAFGN/Y+yhk29qaarFIf5O23REPN1/qSpKl0Cn1PRDyYO7855HwcWJUh7irgOkn76DysuVrSdzPkjt5TERGHgU10Hkb14gBw4JQjlI10Sp7LamBHRLyYKe+DwHMR8VJEHAceBN6XIzgybuydrFL/HLhE0uLm1nQN8NAkzXJWkkTnMd9IRHwjY+4FkuY2l6cD1wB7es2NiK9ExFBEXEzn3/XHEdHzvYmkmZJmj14GPkTnsLGXWV8A9kta0nxpJbC7p0Hf7gYyHXo3ngeulDSj+blYSeccS88kXdj8Obqx995us1K3iWYVESckfR74IZ2zk3dFxK5ecyVtAP4cmC/pAHB7RNzZY+xVwCeBp5rHvwB/HxGP9Ji7ELhb0gCdG9f7IyLbr58KWABs6vwsMwW4NyIezZD7BeCe5sZ9L3BThszRG55rgL/JkQcQEVslbQR2ACeAJ8j3lNFsG3v9NFGzyvhEmVllXGqzyrjUZpVxqc0q41KbVcalNquMS21Wmf8FgMVbSvHwSOgAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "model = tf.keras.Model(inputs=inputs, outputs=h)\n", + "\n", + "plot_receptive_field(model, data)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOH0lEQVR4nO3da6xldX3G8e/TmaHDTe4QZGihkZAaEoFOpihKWhALipA0TQuJJpoW+kItqNRg+8L6pmlTY/RFYzIBLAm34ACJGuTSiiUm7egwDHVgsOUmzAgMFAWxKBd/fbH3NAOdy5q915pz1j/fT3Iy+5yz58kvM+fZe+119v7tVBWS2vFrCz2ApH5ZaqkxllpqjKWWGmOppcYsHSL08EOX1HHHLus186VfHM4By5/rNdPc4TLNHS4T4PEnX+W551/Pjr43SKmPO3YZ37vj2F4z79l4KWec9OVeM80dLtPc4TIBVv3Bkzv9noffUmMstdQYSy01xlJLjbHUUmMstdSYTqVOck6SHyZ5OMkVQw8laXa7LXWSJcA/AucCbwcuSvL2oQeTNJsu99SrgIer6tGqegW4Ebhg2LEkzapLqY8Btn/6yubp194gySVJ1iVZ9+x/v97XfJL2UG8nyqpqdVWtrKqVRxy2pK9YSXuoS6m3ANs/kXvF9GuSFqEupf4+cEKS45PsA1wIfH3YsSTNarev0qqq15J8HLgDWAJcXVUPDD6ZpJl0eullVd0G3DbwLJJ64DPKpMZYaqkxllpqjKWWGmOppcakz/fSSvJB4IPHvPWwi6+95m96ywV46eWjOGDfZ3rNNHe4THOHywS4/NOXs+7+X+xwm2ivpd5m5TuWl9tEx5M7plnHljvkNtGdldrDb6kxllpqjKWWGmOppcZYaqkxllpqTJfFg1cn2Zpk494YSNJ8utxT/xNwzsBzSOrJbktdVfcAz++FWST1oLfH1G4TlRYHt4lKjfHst9QYSy01psuvtG4A/g04McnmJH86/FiSZtVlRfBFe2MQSf3w8FtqjKWWGmOppcZYaqkxllpqjNtEzR3VrGPLdZvoLoxpg+TYcsc069hy3SYqaW6WWmqMpZYaY6mlxlhqqTGWWmpMl5deHpvk7iQPJnkgyaV7YzBJs9ntSy+B14BPV9X6JAcC9ya5q6oeHHg2STPosk30qapaP738M2ATcMzQg0mazR49pk5yHHAKsHYH33ObqLQIdC51kgOAm4HLqurFN3/fbaLS4tCp1EmWMSn0dVV1y7AjSZpHl7PfAa4CNlXVF4cfSdI8utxTnw58GDgzyYbpx/sHnkvSjLpsE/0usMOXeElafHxGmdQYSy01xlJLjbHUUmMstdQYt4maO6pZx5brNtFdGNMGybHljmnWseW6TVTS3Cy11BhLLTXGUkuNsdRSY7q89HJ5ku8luX+6ePDze2MwSbPpsnjwl8CZVfXSdFnCd5N8q6r+feDZJM2gy0svC3hp+umy6Uf/v9yW1Iuu64yWJNkAbAXuqioXD0qLVKdSV9XrVXUysAJYleSkHVzHxYPSIrBHZ7+r6qfA3cA5w4wjaV5dzn4fkeTg6eV9gbOBh4YeTNJsupz9Phq4JskSJjcCN1XVN4cdS9Ksupz9/g8m78ohaQR8RpnUGEstNcZSS42x1FJjLLXUGBcPmjuqWceW6+LBXRjTsrmx5Y5p1rHlunhQ0twstdQYSy01xlJLjbHUUmMstdSYzqWerjS6L4kvu5QWsT25p74U2DTUIJL60XXx4ArgA8CVw44jaV5d76m/BHwG+NXOruA2UWlx6LKj7Dxga1Xdu6vruU1UWhy63FOfDpyf5HHgRuDMJNcOOpWkme221FX12apaUVXHARcC366qDw0+maSZ+HtqqTFdVgT/n6r6DvCdQSaR1AvvqaXGWGqpMZZaaoyllhpjqaXGuE3UXF56+Sgee3mnzwCe2VHL9uGFR3/Se+4hKw7iJ5tfGEXuULN++vLLebGed5uouTvP/OgDv+w1E+BTR6/g6xeu6T33j//hXG76y2+NIneoWdfWv+y01B5+S42x1FJjLLXUGEstNcZSS42x1FJjOr1Ka7og4WfA68BrVbVyyKEkzW5PXnr5+1X13GCTSOqFh99SY7qWuoA7k9yb5JIdXcFtotLi0PXw+91VtSXJkcBdSR6qqnu2v0JVrQZWw+Rpoj3PKamjTvfUVbVl+udW4FZg1ZBDSZpdl73f+yc5cNtl4H3AxqEHkzSbLoffRwG3Jtl2/eur6vZBp5I0s92WuqoeBd6xF2aR1AN/pSU1xlJLjbHUUmMstdQYSy01ZpBtoocfcsTFf/vXf9dbLoxrg+TYcsc069hym9km+pYcWr+bs3rNHNMGybHljmnWseW6TVTS3Cy11BhLLTXGUkuNsdRSYyy11JhOpU5ycJI1SR5KsinJO4ceTNJsuq4z+jJwe1X9UZJ9gP0GnEnSHHZb6iQHAWcAHwGoqleAV4YdS9Ksuhx+Hw88C3w1yX1JrpyuNXqD7beJvkr/73UsqZsupV4KnAp8papOAX4OXPHmK1XV6qpaWVUrl/HrPY8pqasupd4MbK6qtdPP1zApuaRFaLelrqqngSeTnDj90lnAg4NOJWlmXc9+fwK4bnrm+1Hgo8ONJGkenUpdVRsA3+lSGgGfUSY1xlJLjbHUUmMstdQYSy01xm2i5o5q1rHluk10F8a0QXJsuWOadWy5bhOVNDdLLTXGUkuNsdRSYyy11JjdljrJiUk2bPfxYpLL9sZwkvbcbl+lVVU/BE4GSLIE2ALcOvBckma0p4ffZwGPVNWPhhhG0vy6LknY5kLghh19I8klwCUAy90gLC2YzvfU060n5wNf29H3XTwoLQ57cvh9LrC+qp4ZahhJ89uTUl/ETg69JS0eXd9La3/gbOCWYceRNK+uiwd/Dhw28CySeuAzyqTGWGqpMZZaaoyllhpjqaXGuHjQ3FHNOrZcFw/uwpiWzY0td0yzji3XxYOS5mappcZYaqkxllpqjKWWGmOppcZ0fenlJ5M8kGRjkhuSLB96MEmz6bIi+BjgL4CVVXUSsITJrjJJi1DXw++lwL5JlgL7AT8ebiRJ89htqatqC/AF4AngKeCFqrrzzddLckmSdUnWvcov+59UUiddDr8PAS4AjgfeCuyf5ENvvp7bRKXFocvh93uBx6rq2ap6lcmesncNO5akWXUp9RPAaUn2SxIm79KxadixJM2qy2PqtcAaYD3wg+nfWT3wXJJm1HWb6OeAzw08i6Qe+IwyqTGWWmqMpZYaY6mlxlhqqTFuEzV3VLOOLddtorswpg2SY8sd06xjy3WbqKS5WWqpMZZaaoyllhpjqaXGWGqpMV23iV463ST6QJLLhh5K0uy6rDM6CbgYWAW8AzgvyduGHkzSbLrcU/82sLaq/qeqXgP+FfjDYceSNKsupd4IvCfJYUn2A94PHPvmK7lNVFocdrv5pKo2Jfl74E7g58AG4PUdXG810zVHb8mh/T/3VFInnU6UVdVVVfU7VXUG8BPgP4cdS9KsOu0oS3JkVW1N8htMHk+fNuxYkmbVqdTAzUkOA14FPlZVPx1wJklz6LpN9D1DDyKpHz6jTGqMpZYaY6mlxlhqqTGWWmrMINtEgT8B/qvDXzkceK5j/EFA17WM5o5r1rHlLoZZT6iqg3b4napasA9g3R5cd7W53XPHNOvYchf7rGM6/P6GuYPljmnWseXu9VlHU+qqGuQfx9xxzTq23IWYdaFLPdSb15s7rlnHlruoZx3kHTokLZyFvqeW1DNLLTVmwUqd5JwkP0zycJIresq8OsnWJBv7yJtmHpvk7iQPTrepXtpT7vIk30ty/zT3833kbpe/JMl9Sb7ZY+bjSX6QZEOSdT1lHpxkTZKHkmxK8s4eMk+czrjt48W+tuAm+eT0/2tjkhuSLO8pt7+NvV1/L9bnB7AEeAT4LWAf4H7g7T3kngGcCmzscdajgVOnlw9ksvWlj1kDHDC9vAxYC5zW49yfAq4Hvtlj5uPA4T3/LFwD/Nn08j7AwQP8rD0N/GYPWccAjwH7Tj+/CfhID7knMdkFuB+Tl0P/M/C2WfMW6p56FfBwVT1aVa8ANwIXzBtaVfcAz8+b86bMp6pq/fTyz4BNTP5z582tqnpp+umy6UcvZy2TrAA+AFzZR95QkhzE5Ib4KoCqeqX6X8BxFvBIVf2op7ylwL5JljIp4Y97yOx1Y+9ClfoY4MntPt9MD0UZWpLjgFOY3Kv2kbckyQZgK3BXVfWSC3wJ+Azwq57ytingziT3Jrmkh7zjgWeBr04fKlyZZP8ecrd3IXBDH0FVtQX4AvAE8BTwQlXd2UN0p429XXmirKMkBwA3A5dV1Yt9ZFbV61V1MrACWDV944S5JDkP2FpV98494P/37qo6FTgX+FiSM+bMW8rk4dJXquoUJttqezm/ApBkH+B84Gs95R3C5IjyeOCtwP5JPjRvblVtArZt7L2dnWzs7WqhSr2FN94SrZh+bVFKsoxJoa+rqlv6zp8ect4NnNND3OnA+UkeZ/Kw5swk1/aQu+2eiqraCtzK5GHUPDYDm7c7QlnDpOR9ORdYX1XP9JT3XuCxqnq2ql4FbgHe1Udw9bixd6FK/X3ghCTHT29NLwS+vkCz7FKSMHnMt6mqvthj7hFJDp5e3hc4G3ho3tyq+mxVraiq45j8u367qua+N0myf5IDt10G3sfksHGeWZ8Gnkxy4vRLZwEPzjXoG11ET4feU08ApyXZb/pzcRaTcyxzS3Lk9M9tG3uvnzWr6zbRXlXVa0k+DtzB5Ozk1VX1wLy5SW4Afg84PMlm4HNVddWcsacDHwZ+MH38C/BXVXXbnLlHA9ckWcLkxvWmqurt108DOAq4dfKzzFLg+qq6vYfcTwDXTW/cHwU+2kPmthues4E/7yMPoKrWJlkDrAdeA+6jv6eM9rax16eJSo3xRJnUGEstNcZSS42x1FJjLLXUGEstNcZSS435X6ZMf6JHvKMCAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", + "v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", + "model = tf.keras.Model(inputs=inputs, outputs=h)\n", + "\n", + "plot_receptive_field(model, data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As you can notice, the Gated PixelCNN does not create blind spots when adding more and more layers." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "Gated Receptive fields.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/WIP/4 - Gated_PixelCNN/Gated_Receptive_fields.ipynb b/WIP/4 - Gated_PixelCNN/Gated_Receptive_fields.ipynb deleted file mode 100644 index c068cee..0000000 --- a/WIP/4 - Gated_PixelCNN/Gated_Receptive_fields.ipynb +++ /dev/null @@ -1,321 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Gated Receptive fields.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "YqTKIYLooHsq", - "colab_type": "text" - }, - "source": [ - "# Comparing receptive fields" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "gf5wwqP3ozaN", - "colab_type": "code", - "colab": {} - }, - "source": [ - "import random as rn\n", - "\n", - "import matplotlib\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "from tensorflow import nn\n", - "from tensorflow.keras import initializers\n", - "from tensorflow.keras.utils import Progbar" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "jEkll1yno2Vb", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Defining random seeds\n", - "random_seed = 42\n", - "tf.random.set_seed(random_seed)\n", - "np.random.seed(random_seed)\n", - "rn.seed(random_seed)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "yJ_JlzWco7ci", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class MaskedConv2D(keras.layers.Layer):\n", - " \"\"\"Convolutional layers with masks for Gated PixelCNN.\n", - "\n", - " Masked convolutional layers used to implement Vertical and Horizontal\n", - " stacks of the Gated PixelCNN.\n", - "\n", - " Note: This implementation is different from the normal PixelCNN.\n", - "\n", - " Arguments:\n", - " mask_type: one of `\"V\"`, `\"A\"` or `\"B\".`\n", - " filters: Integer, the dimensionality of the output space\n", - " (i.e. the number of output filters in the convolution).\n", - " kernel_size: An integer or tuple/list of 2 integers, specifying the\n", - " height and width of the 2D convolution window.\n", - " Can be a single integer to specify the same value for\n", - " all spatial dimensions.\n", - " strides: An integer or tuple/list of 2 integers,\n", - " specifying the strides of the convolution along the height and width.\n", - " Can be a single integer to specify the same value for\n", - " all spatial dimensions.\n", - " Specifying any stride value != 1 is incompatible with specifying\n", - " any `dilation_rate` value != 1.\n", - " padding: one of `\"valid\"` or `\"same\"` (case-insensitive).\n", - " kernel_initializer: Initializer for the `kernel` weights matrix.\n", - " bias_initializer: Initializer for the bias vector.\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " mask_type,\n", - " filters,\n", - " kernel_size,\n", - " strides=1,\n", - " padding='same',\n", - " kernel_initializer='glorot_uniform',\n", - " bias_initializer='zeros'):\n", - " super(MaskedConv2D, self).__init__()\n", - "\n", - " assert mask_type in {'A', 'B', 'V'}\n", - " self.mask_type = mask_type\n", - "\n", - " self.filters = filters\n", - "\n", - " if isinstance(kernel_size, int):\n", - " kernel_size = (kernel_size, kernel_size)\n", - " self.kernel_size = kernel_size\n", - "\n", - " self.strides = strides\n", - " self.padding = padding.upper()\n", - " self.kernel_initializer = initializers.get(kernel_initializer)\n", - " self.bias_initializer = initializers.get(bias_initializer)\n", - "\n", - " def build(self, input_shape):\n", - " kernel_h, kernel_w = self.kernel_size\n", - "\n", - " self.kernel = self.add_weight('kernel',\n", - " shape=(kernel_h,\n", - " kernel_w,\n", - " int(input_shape[-1]),\n", - " self.filters),\n", - " initializer=self.kernel_initializer,\n", - " trainable=True)\n", - "\n", - " self.bias = self.add_weight('bias',\n", - " shape=(self.filters,),\n", - " initializer=self.bias_initializer,\n", - " trainable=True)\n", - "\n", - " mask = np.ones(self.kernel.shape, dtype=np.float32)\n", - "\n", - " if kernel_h % 2 != 0: \n", - " center_h = kernel_h // 2\n", - " else:\n", - " center_h = (kernel_h - 1) // 2\n", - "\n", - " if kernel_w % 2 != 0: \n", - " center_w = kernel_w // 2\n", - " else:\n", - " center_w = (kernel_w - 1) // 2\n", - "\n", - " if self.mask_type == 'V':\n", - " mask[center_h + 1:, :, :, :] = 0.\n", - " else:\n", - " mask[:center_h, :, :] = 0.\n", - " mask[center_h, center_w + (self.mask_type == 'B'):, :, :] = 0.\n", - " mask[center_h + 1:, :, :] = 0. \n", - "\n", - " self.mask = tf.constant(mask, dtype=tf.float32, name='mask')\n", - "\n", - " def call(self, input):\n", - " masked_kernel = tf.math.multiply(self.mask, self.kernel)\n", - " x = nn.conv2d(input,\n", - " masked_kernel,\n", - " strides=[1, self.strides, self.strides, 1],\n", - " padding=self.padding)\n", - " x = nn.bias_add(x, self.bias)\n", - " return x" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "K5topys_HW7-", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class GatedBlock(tf.keras.Model):\n", - " \"\"\" Gated block of the Gated PixelCNN.\"\"\"\n", - "\n", - " def __init__(self, mask_type, filters, kernel_size):\n", - " super(GatedBlock, self).__init__(name='')\n", - "\n", - " self.mask_type = mask_type\n", - " self.vertical_conv = MaskedConv2D(mask_type='V',\n", - " filters=2 * filters,\n", - " kernel_size=kernel_size)\n", - " \n", - " self.horizontal_conv = MaskedConv2D(mask_type=mask_type,\n", - " filters=2 * filters,\n", - " kernel_size=kernel_size)\n", - "\n", - " self.padding = keras.layers.ZeroPadding2D(padding=((1,0),0))\n", - " self.cropping = keras.layers.Cropping2D(cropping=((0, 1), 0))\n", - "\n", - " self.v_to_h_conv = keras.layers.Conv2D(filters=2 * filters, kernel_size=1)\n", - "\n", - " self.horizontal_output = keras.layers.Conv2D(filters=filters, kernel_size=1)\n", - "\n", - " def _gate(self, x):\n", - " tanh_preactivation, sigmoid_preactivation = tf.split(x, 2, axis=-1)\n", - " return tf.nn.tanh(tanh_preactivation) * tf.nn.sigmoid(sigmoid_preactivation)\n", - "\n", - " def call(self, input_tensor):\n", - " v = input_tensor[0]\n", - " h = input_tensor[1]\n", - "\n", - " vertical_preactivation = self.vertical_conv(v) # NxN\n", - "\n", - " # Shifting feature map down to ensure causality\n", - " v_to_h = self.padding(vertical_preactivation)\n", - " v_to_h = self.cropping(v_to_h)\n", - " v_to_h = self.v_to_h_conv(v_to_h) # 1x1\n", - "\n", - " horizontal_preactivation = self.horizontal_conv(h) # 1xN\n", - " \n", - " v_out = self._gate(vertical_preactivation)\n", - "\n", - " horizontal_preactivation = horizontal_preactivation + v_to_h\n", - " h_activated = self._gate(horizontal_preactivation)\n", - " h_activated = self.horizontal_output(h_activated)\n", - "\n", - " if self.mask_type == 'A':\n", - " h_out = h_activated\n", - " elif self.mask_type == 'B':\n", - " h_out = h + h_activated\n", - "\n", - " return v_out, h_out" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "jxCLMYc-FxdJ", - "colab_type": "code", - "colab": {} - }, - "source": [ - "def plot_receptive_field(model, data):\n", - " out = model(data)\n", - "\n", - " with tf.GradientTape() as tape:\n", - " tape.watch(data)\n", - " prediction = model(data)\n", - " loss = prediction[:,5,5,0]\n", - "\n", - " gradients = tape.gradient(loss, data)\n", - "\n", - " gradients = np.abs(gradients.numpy().squeeze())\n", - " gradients = (gradients > 1e-8).astype('float32')\n", - " gradients[5, 5] = 0.5\n", - "\n", - " plt.figure()\n", - " plt.imshow(gradients)\n", - " plt.title(f'Receptive field from pixel layers')\n", - " plt.show()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "0qpDtNuvo9NL", - "colab_type": "code", - "outputId": "b8142536-f9db-4ce3-ed7e-03741da11a3f", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 281 - } - }, - "source": [ - "inputs = keras.layers.Input(shape=(height, width, n_channel))\n", - "v, h = GatedBlock(mask_type='A', filters=1, kernel_size=3)([inputs, inputs])\n", - "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", - "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", - "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", - "v, h = GatedBlock(mask_type='B', filters=1, kernel_size=3)([v, h])\n", - "model = tf.keras.Model(inputs=inputs, outputs=h)\n", - "\n", - "data = tf.random.normal((1,10,10,1))\n", - "\n", - "plot_receptive_field(model, data)" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAARAUlEQVR4nO3dfbBU9X3H8fdHQBAkCkoegBtgolGJrWJvfKzWKFbjY6ZpWk001alhbKOodeJTp8VJTDOdaKIziSb4gI1QNSVOtNaoiWCMtlIBzSigKUEUFCKoKNpUUL/94/yQ5XL37rn37nKWn5/XzJ3ZPb+z53zPw2fP7/zu3r2KCMwsHztUXYCZNZdDbZYZh9osMw61WWYcarPMONRmmXGoS5J0uaQbW7Dcj0h6WNJ6SVf3Zj2SHpJ0dp228ZJC0sA67XtJejKtd2p/tqHZJH1J0gNNWM5ySZPrtN0i6cr+rqMddXvAtzVJy4GPAO8CbwL3AedGxJsV1XMkMDMixm6aFhH/1KLVTQHWAh+KbfuhgYuBuRGx/zZcZykRMQuYVXUd26t2ulKfFBE7A/sDk4DLKq5nWxkHLN7Ggd603kX1GiUN2Ia1fCDU6zU1WzuFGoCIWA3cTxFuACQdLOk/Ja2T9Ot0Jd3UNlLSDEkvSXpN0k9r2k5MXcx16fV/WNO2XNJlkhan182QNETSMOBnwGhJb6af0ZKukDQzvfZnks6trTvV9Wfp8d6Sfi7pVUnPSvqL7rZV0i3AXwEXp/VMrl1Po23vsqwBkq6StFbSMuCEevtY0hzgM8D30no/mbqj10u6V9JbwGck7ZO6+OskLZJ0cm3tkq5L++JNSY9K+qika9L+fEbSpB5qCElTJS1LNX9b0g6p7UxJj6THh6b2jvR8v7T8vdPzuse4LEkjJN0jaU1a9j2Sxqa2L0ha0GX+v5N0V3o8OO33FyT9TtIPJO2U2o6UtFLSJZJWAzMk7Z6Wvy6dH7/atN1NExGV/wDLgcnp8VjgKeDa9HwM8ApwPMWb0DHp+ajU/h/AHcAIYBDwJ2n6JOBl4CBgAEV4lgODa9b5NNABjAQeBa5MbUcCK7vUeAVFlxzgy8CjNW0TgXXAYGAYsAI4i+L2ZhJF93pinW2/ZdN6u1lPo21/CDg7PT4HeKZme+YCAQyss973X1tTx+vAYWldw4GlwOXAjsBRwHpgr5r51wJ/BAwB5gDPpX0zALiSontf75hHqnEk8HHgNzXbcibwSM2830zL3ymdG+f24hhPbrTfgd2AzwND03b/G/DT1DYYeBXYp+a1TwCfT4+/C9ydtmM48O/At2rOo3eAf07L2Qn4FvADinN1EHA4oKbmqepA1+z8N9NJE8CDwK6p7RLg1i7z358O4MeA94AR3SzzeuAbXaY9y+bQLwfOqWk7HvhtyVAPB94CxtWcdDenx38J/KrLa38ITOtDqOtuezehntNle/6U3of6RzXPDwdWAzvUTLsNuKJm/htq2s4DltQ8/wNgXQ/HPIDjap7/LfBgenwmW4Z6ELCAItD3bQpByWPcMNTdtO0PvNblXPpmevwp4DWKkCqdB5+omfcQ4Lma82gDMKSm/evAXcAercpTO3W/PxcRwyl2xN7A7mn6OOALqbuyTtI64I8pAt0BvBoRr3WzvHHARV1e1wGMrplnRc3j57u01RUR6yl6CKemSaexeWBnHHBQl/V+CfhomWV3sw31tr2r0Wy9Pb1V+/rRwIqIeK/LMsfUPP9dzePfd/N8516sr+7+j4iNFCHcF7g6Ujood4wbkjRU0g8lPS/pDeBhYFdtHlf4F+CLkgScAfw4It4GRlFc3RfUrP++NH2TNRHxfzXPv03RA3og3Xpc2ptay2iL0e9aEfHLdK95FfA5igN/a0R8peu8kj4GjJS0a0Ss69K8guLd9Zs9rK6j5vHHgZc2lVGi1NuAaZIepuh+zq1Z7y8j4pgSy2ik7rZ3YxVbb09v1W73S0CHpB1qgr2pm9wsHWwerKvd/1uQNAaYBswArpb06RSqMse4jIuAvYCDImK1pP0putgCiIjHJG2g6L18Mf1Acfvxe+BTEfFinWVvcS6lC8JFFG9G+wJzJD0eEQ/2cxve105X6lrXAMdI2g+YCZwk6dg0GDQkDUCMjYhVFINa16XBjkGSjkjLuAE4R9JBKgyTdIKk4TXr+aqksZJGAn9PcW8OxRVnN0m79FDjvRRXiq8Dd9Sc+PcAn5R0RqpnkKRPS9qnD/uh7rZ3M++Pgalpe0YA/b0CzAP+l2IQb5CKAbqTgNv7udxaX0vHrQM4n837/33p6ngLcBPw1xRvXt9IzWWOcRnDKcK5Lp0L07qZ50fA94CNEfEIQDrmNwDflfThVO8YScfWW1Ea2NsjbdfrFL/Gfa/e/H3RlqGOiDUUO/EfI2IFcArFgM0ainfnr7G59jOAjRSDRC8DF6RlzAe+QnEgXqPo8pzZZVX/CjwALAN+SzG4Q0Q8Q3ElXpa6VVt159KV4k5gclrOpunrKe5nT6W48qxm80BJb/dDo22vdQPF/favgYWptj6LiA0UIf4sxRXpOuDLad80y10U98pPUtzO3NTNPFOBDwP/kLrdZwFnSTq85DEu4xqKQay1wGMUXeiubqXo/s/sMv2StN7HUtf9FxRX/Xr2TPO8CfwXcF1EzO1h/l7bNODwgaPiAy9nR8Qvqq7lg0hSAHtGxNKqaykj/ZrqZeCAiPifquvpSVteqc3a0N8Aj7d7oKENB8rM2k3q1Yli4LbtfWC732a5cvfbLDMt6X7vPnJAjO8Y1IpFmxmwfMVG1r76rrpra0mox3cM4r/v72g8o5n1yYHHrqjb5u63WWYcarPMONRmmXGozTLjUJtlxqE2y0ypUEs6TsV3bS1txR91m1nzNAx1+vaH71P8Cd5E4DRJE1tdmJn1TZkr9YHA0ohYlv7G9naKv/E1szZUJtRj2PK7pFay5fdUASBpiqT5kuaveeXdZtVnZr3UtIGyiJgeEZ0R0TlqN38PvFlVyoT6Rbb8QruxaZqZtaEyoX4c2FPSBEk7Unz31t2tLcvM+qrhX2lFxDsq/sXM/RT/BeHmiKj7P5jMrFql/vQyIu6l+EpcM2tz/kSZWWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmWkYakkdkuZKWixpkaTzt0VhZtY3A0vM8w5wUUQslDQcWCDp5xGxuMW1mVkfNLxSR8SqiFiYHq8HlgBjWl2YmfVNr+6pJY0HJgHzummbImm+pPlrXnm3OdWZWa+VDrWknYGfABdExBtd2yNiekR0RkTnqN0GNLNGM+uFUqGWNIgi0LMi4s7WlmRm/VFm9FvATcCSiPhO60sys/4oc6U+DDgDOErSk+nn+BbXZWZ91PBXWhHxCKBtUIuZNYE/UWaWGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDOlQy1pgKQnJN3TyoLMrH96c6U+H1jSqkLMrDlKhVrSWOAE4MbWlmNm/VX2Sn0NcDHwXr0ZJE2RNF/S/DWvvNuU4sys9xqGWtKJwMsRsaCn+SJiekR0RkTnqN0GNK1AM+udMlfqw4CTJS0HbgeOkjSzpVWZWZ81DHVEXBYRYyNiPHAqMCciTm95ZWbWJ/49tVlmBvZm5oh4CHioJZWYWVP4Sm2WGYfaLDMOtVlmHGqzzDjUZpnp1ei35esTd5zTkuXuceFjLVnuB91v4pW6bb5Sm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhQRTV/ohzQyDtLRTV+umRXmxYO8Ea+quzZfqc0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzJQKtaRdJc2W9IykJZIOaXVhZtY3Zf+V7bXAfRHx55J2BIa2sCYz64eGoZa0C3AEcCZARGwANrS2LDPrqzLd7wnAGmCGpCck3ShpWNeZJE2RNF/S/I283fRCzaycMqEeCBwAXB8Rk4C3gEu7zhQR0yOiMyI6BzG4yWWaWVllQr0SWBkR89Lz2RQhN7M21DDUEbEaWCFprzTpaGBxS6sysz4rO/p9HjArjXwvA85qXUlm1h+lQh0RTwKdLa7FzJrAnygzy4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZaZUqGWdKGkRZKelnSbpCGtLszM+qZhqCWNAaYCnRGxLzAAOLXVhZlZ35Ttfg8EdpI0EBgKvNS6ksysPxqGOiJeBK4CXgBWAa9HxANd55M0RdJ8SfM38nbzKzWzUsp0v0cApwATgNHAMEmnd50vIqZHRGdEdA5icPMrNbNSynS/JwPPRcSaiNgI3Akc2tqyzKyvyoT6BeBgSUMlCTgaWNLassysr8rcU88DZgMLgafSa6a3uC4z66OBZWaKiGnAtBbXYmZN4E+UmWXGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhlFRPMXKq0Bni8x6+7A2qYX0DrbU73bU62wfdXbDrWOi4hR3TW0JNRlSZofEZ2VFdBL21O921OtsH3V2+61uvttlhmH2iwzVYd6e/vn9dtTvdtTrbB91dvWtVZ6T21mzVf1ldrMmsyhNstMZaGWdJykZyUtlXRpVXU0IqlD0lxJiyUtknR+1TWVIWmApCck3VN1LT2RtKuk2ZKekbRE0iFV19QTSRem8+BpSbdJGlJ1TV1VEmpJA4DvA58FJgKnSZpYRS0lvANcFBETgYOBr7ZxrbXOB5ZUXUQJ1wL3RcTewH60cc2SxgBTgc6I2BcYAJxabVVbq+pKfSCwNCKWRcQG4HbglIpq6VFErIqIhenxeoqTbky1VfVM0ljgBODGqmvpiaRdgCOAmwAiYkNErKu2qoYGAjtJGggMBV6quJ6tVBXqMcCKmucrafOgAEgaD0wC5lVbSUPXABcD71VdSAMTgDXAjHSrcKOkYVUXVU9EvAhcBbwArAJej4gHqq1qax4oK0nSzsBPgAsi4o2q66lH0onAyxGxoOpaShgIHABcHxGTgLeAdh5fGUHRo5wAjAaGSTq92qq2VlWoXwQ6ap6PTdPakqRBFIGeFRF3Vl1PA4cBJ0taTnFbc5SkmdWWVNdKYGVEbOr5zKYIebuaDDwXEWsiYiNwJ3BoxTVtpapQPw7sKWmCpB0pBhvurqiWHkkSxT3fkoj4TtX1NBIRl0XE2IgYT7Ff50RE211NACJiNbBC0l5p0tHA4gpLauQF4GBJQ9N5cTRtOLA3sIqVRsQ7ks4F7qcYQbw5IhZVUUsJhwFnAE9JejJNuzwi7q2wppycB8xKb+7LgLMqrqeuiJgnaTawkOK3Ik/Qhh8Z9cdEzTLjgTKzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDP/D4wuQNxPJtmXAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [], - "needs_background": "light" - } - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "I2APaCzDGeqP", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": 0, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/WIP/rascunho.py b/WIP/rascunho.py index ad76edd..347881b 100644 --- a/WIP/rascunho.py +++ b/WIP/rascunho.py @@ -3,7 +3,7 @@ grad = np.abs(grad) grad = (grad != 0).astype('float64') grad[5, 5, 5] = 0.5 - +from matplotlib.ticker import FixedLocator plt.figure() for i in range(10): axes = plt.subplot(2,5,i+1) @@ -17,4 +17,4 @@ plt.ylabel("out[i,j,:]") plt.imshow(grad[i,:,:],vmin=0, vmax=1) -plt.show() \ No newline at end of file +plt.show()