Skip to content

Commit

Permalink
Add gated pixelcnn new notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
Warvito authored and Warvito committed Apr 12, 2020
1 parent 8663539 commit af09a1e
Show file tree
Hide file tree
Showing 4 changed files with 2,331 additions and 0 deletions.
251 changes: 251 additions & 0 deletions WIP/3 -/Receptive_fields.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Receptive fields.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "YqTKIYLooHsq",
"colab_type": "text"
},
"source": [
"# Comparing receptive fields"
]
},
{
"cell_type": "code",
"metadata": {
"id": "wM-m3Z8CiLXU",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "gf5wwqP3ozaN",
"colab_type": "code",
"colab": {}
},
"source": [
"import random as rn\n",
"\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow import nn\n",
"from tensorflow.keras import initializers\n",
"from tensorflow.keras.utils import Progbar"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jEkll1yno2Vb",
"colab_type": "code",
"colab": {}
},
"source": [
"# Defining random seeds\n",
"random_seed = 42\n",
"tf.random.set_seed(random_seed)\n",
"np.random.seed(random_seed)\n",
"rn.seed(random_seed)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "yJ_JlzWco7ci",
"colab_type": "code",
"colab": {}
},
"source": [
"class MaskedConv2D(keras.layers.Layer):\n",
" \"\"\"Convolutional layers with masks.\n",
"\n",
" Convolutional layers with simple implementation of masks type A and B for\n",
" autoregressive models.\n",
"\n",
" Arguments:\n",
" mask_type: one of `\"A\"` or `\"B\".`\n",
" filters: Integer, the dimensionality of the output space\n",
" (i.e. the number of output filters in the convolution).\n",
" kernel_size: An integer or tuple/list of 2 integers, specifying the\n",
" height and width of the 2D convolution window.\n",
" Can be a single integer to specify the same value for\n",
" all spatial dimensions.\n",
" strides: An integer or tuple/list of 2 integers,\n",
" specifying the strides of the convolution along the height and width.\n",
" Can be a single integer to specify the same value for\n",
" all spatial dimensions.\n",
" Specifying any stride value != 1 is incompatible with specifying\n",
" any `dilation_rate` value != 1.\n",
" padding: one of `\"valid\"` or `\"same\"` (case-insensitive).\n",
" kernel_initializer: Initializer for the `kernel` weights matrix.\n",
" bias_initializer: Initializer for the bias vector.\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" mask_type,\n",
" filters,\n",
" kernel_size,\n",
" strides=1,\n",
" padding='same',\n",
" kernel_initializer='glorot_uniform',\n",
" bias_initializer='zeros'):\n",
" super(MaskedConv2D, self).__init__()\n",
"\n",
" assert mask_type in {'A', 'B'}\n",
" self.mask_type = mask_type\n",
"\n",
" self.filters = filters\n",
" self.kernel_size = kernel_size\n",
" self.strides = strides\n",
" self.padding = padding.upper()\n",
" self.kernel_initializer = initializers.get(kernel_initializer)\n",
" self.bias_initializer = initializers.get(bias_initializer)\n",
"\n",
" def build(self, input_shape):\n",
" self.kernel = self.add_weight('kernel',\n",
" shape=(self.kernel_size,\n",
" self.kernel_size,\n",
" int(input_shape[-1]),\n",
" self.filters),\n",
" initializer=self.kernel_initializer,\n",
" trainable=True)\n",
"\n",
" self.bias = self.add_weight('bias',\n",
" shape=(self.filters,),\n",
" initializer=self.bias_initializer,\n",
" trainable=True)\n",
"\n",
" center = self.kernel_size // 2\n",
"\n",
" mask = np.ones(self.kernel.shape, dtype=np.float32)\n",
" mask[center, center + (self.mask_type == 'B'):, :, :] = 0.\n",
" mask[center + 1:, :, :, :] = 0.\n",
"\n",
" self.mask = tf.constant(mask, dtype=tf.float32, name='mask')\n",
"\n",
" def call(self, input):\n",
" masked_kernel = tf.math.multiply(self.mask, self.kernel)\n",
" x = nn.conv2d(input,\n",
" masked_kernel,\n",
" strides=[1, self.strides, self.strides, 1],\n",
" padding=self.padding)\n",
" x = nn.bias_add(x, self.bias)\n",
" return x"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jxCLMYc-FxdJ",
"colab_type": "code",
"colab": {}
},
"source": [
"def plot_receptive_field(model, data):\n",
" out = model(data)\n",
"\n",
" with tf.GradientTape() as tape:\n",
" tape.watch(data)\n",
" prediction = model(data)\n",
" loss = prediction[:,5,5,0]\n",
"\n",
" gradients = tape.gradient(loss, data)\n",
"\n",
" gradients = np.abs(gradients.numpy().squeeze())\n",
" gradients = (gradients > 1e-8).astype('float32')\n",
" gradients[5, 5] = 0.5\n",
"\n",
" plt.figure()\n",
" plt.imshow(gradients)\n",
" plt.title(f'Receptive field from pixel layers')\n",
" plt.show()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0qpDtNuvo9NL",
"colab_type": "code",
"outputId": "926434e2-44bf-40d3-a39f-6c80ebc880b6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
}
},
"source": [
"height = 10\n",
"width = 10\n",
"n_channel = 1\n",
"inputs = keras.layers.Input(shape=(height, width, n_channel))\n",
"x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n",
"x = MaskedConv2D(mask_type='B', filters=1, kernel_size=3, strides=1)(x)\n",
"x = MaskedConv2D(mask_type='B', filters=1, kernel_size=3, strides=1)(x)\n",
"\n",
"model = keras.Model(inputs=inputs, outputs=x)\n",
"data = tf.random.normal((1,10,10,1))\n",
"\n",
"plot_receptive_field(model, data)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAARBElEQVR4nO3dfZBV9X3H8fdHWEGQKkbzwEOQiUYltoLdxKdqjWI1PmaaptVEU50axjaKWic+dVqcxDTTiSY6k2iCD9gIVVPiRGuNmgjGaCsjIBkFtCWIgkDEKgo2Aazf/nF+K5dl796zy72cuz8/r5mdufeec8/5nofP/n7nt/eeVURgZvnYpeoCzKy5HGqzzDjUZplxqM0y41CbZcahNsuMQ12SpKsl3dqC5X5I0uOSNki6vi/rkfSYpPPrTNtXUkgaXGf6AZIWpfVO3ZFtaDZJX5T0SBOWs0LS5DrT7pB07Y6uox31eMB3NkkrgA8B/wdsBB4CLoyIjRXVcywwMyLGdL0WEf/YotVNAV4Dfi927ocGLgfmRsTEnbjOUiJiFjCr6joGqnZqqU+LiN2BicAk4KqK69lZxgFLdnKgu9a7uN5ESYN2Yi3vC/V6Tc3WTqEGICLWAg9ThBsASYdL+g9J6yX9KrWkXdP2kjRD0mpJb0j6Sc20U1MXc316/x/UTFsh6SpJS9L7ZkgaKmk48FNglKSN6WeUpGskzUzv/amkC2vrTnX9aXp8oKSfSXpd0guS/rynbZV0B/CXwOVpPZNr19No27sta5Ck6yS9Jmk5cEq9fSxpDvBp4LtpvR9P3dGbJT0o6W3g05IOSl389ZIWSzq9tnZJN6V9sVHSk5I+LOmGtD+flzSplxpC0lRJy1PN35K0S5p2rqQn0uMj0/Sx6fkhafkHpud1j3FZkkZKekDSurTsBySNSdM+L2lBt/n/VtJ96fGQtN9flvQbSd+XtFuadqykVZKukLQWmCFp77T89en8+GXXdjdNRFT+A6wAJqfHY4BngRvT89HA/wAnU/wSOiE93ydN/3fgHmAk0AH8cXp9EvAqcBgwiCI8K4AhNet8DhgL7AU8CVybph0LrOpW4zUUXXKALwFP1kybAKwHhgDDgZXAeRSXN5MoutcT6mz7HV3r7WE9jbb9MeD89PgC4Pma7ZkLBDC4znrfe29NHW8CR6V1jQCWAVcDuwLHARuAA2rmfw34Q2AoMAd4Me2bQcC1FN37esc8Uo17AR8F/qtmW84FnqiZ9xtp+bulc+PCPhzjyY32O/AB4HPAsLTd/wr8JE0bArwOHFTz3meAz6XH3wHuT9sxAvg34Js159E7wD+l5ewGfBP4PsW52gEcDaipeao60DU7f2M6aQJ4FNgzTbsCuLPb/A+nA/gR4F1gZA/LvBn4erfXXmBr6FcAF9RMOxn4dclQjwDeBsbVnHS3p8d/Afyy23t/AEzrR6jrbnsPoZ7TbXv+hL6H+oc1z48G1gK71Lx2F3BNzfy31Ey7CFha8/z3gfW9HPMATqp5/jfAo+nxuWwb6g5gAUWgH+oKQclj3DDUPUybCLzR7Vz6Rnr8CeANipAqnQcfq5n3CODFmvNoMzC0ZvrXgPuA/VqVp3bqfn82IkZQ7IgDgb3T6+OAz6fuynpJ64E/ogj0WOD1iHijh+WNAy7r9r6xwKiaeVbWPH6p27S6ImIDRQ/hzPTSWWwd2BkHHNZtvV8EPlxm2T1sQ71t724U229PX9W+fxSwMiLe7bbM0TXPf1Pz+Lc9PN+9D+uru/8jYgtFCA8Gro+UDsod44YkDZP0A0kvSXoLeBzYU1vHFf4Z+IIkAecAP4qITcA+FK37gpr1P5Re77IuIn5X8/xbFD2gR9Klx5V9qbWMthj9rhURv0jXmtcBn6U48HdGxJe7zyvpI8BekvaMiPXdJq+k+O36jV5WN7bm8UeB1V1llCj1LmCapMcpup9za9b7i4g4ocQyGqm77T1Yw/bb01e1270aGCtpl5pgd3WTm2UsWwfravf/NiSNBqYBM4DrJX0yharMMS7jMuAA4LCIWCtpIkUXWwAR8ZSkzRS9ly+kHyguP34LfCIiXqmz7G3OpdQgXEbxy+hgYI6kpyPi0R3chve0U0td6wbgBEmHADOB0ySdmAaDhqYBiDERsYZiUOumNNjRIemYtIxbgAskHabCcEmnSBpRs56vSBojaS/g7yiuzaFocT4gaY9eanyQoqX4GnBPzYn/APBxSeekejokfVLSQf3YD3W3vYd5fwRMTdszEtjRFmAe8L8Ug3gdKgboTgPu3sHl1vpqOm5jgYvZuv/fk1rHO4DbgL+i+OX19TS5zDEuYwRFONenc2FaD/P8EPgusCUingBIx/wW4DuSPpjqHS3pxHorSgN7+6XtepPiz7jv1pu/P9oy1BGxjmIn/kNErATOoBiwWUfx2/mrbK39HGALxSDRq8AlaRnzgS9THIg3KLo853Zb1b8AjwDLgV9TDO4QEc9TtMTLU7dqu+5cainuBSan5XS9voHievZMipZnLVsHSvq6Hxpte61bKK63fwUsTLX1W0RspgjxZyhapJuAL6V90yz3UVwrL6K4nLmth3mmAh8E/j51u88DzpN0dMljXMYNFINYrwFPUXShu7uTovs/s9vrV6T1PpW67j+naPXr2T/NsxH4T+CmiJjby/x91jXg8L6j4gMv50fEz6uu5f1IUgD7R8SyqmspI/2Z6lXg0Ij476rr6U1bttRmbeivgafbPdDQhgNlZu0m9epEMXDb9t633W+zXLn7bZaZlnS/d9WQGMrwVizazIDf8TabY5N6mtaSUA9lOIfp+FYs2syAeb18VsXdb7PMONRmmXGozTLjUJtlxqE2y4xDbZaZUqGWdJKKe20ta8WXus2seRqGOt394XsUX8GbAJwlaUKrCzOz/inTUn8KWBYRy9N3bO+m+I6vmbWhMqEezbb3klrFtvepAkDSFEnzJc3fwqZm1WdmfdS0gbKImB4RnRHR2dH3m3yYWZOUCfUrbHtDuzHpNTNrQ2VC/TSwv6TxknaluPfW/a0ty8z6q+G3tCLiHRX/YuZhiv+CcHtE1P0fTGZWrVJfvYyIByluiWtmbc6fKDPLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzpf6XVu4eXr2o6hIsYyeOmrhT1+eW2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLTMNQSxoraa6kJZIWS7p4ZxRmZv1T5sMn7wCXRcRCSSOABZJ+FhFLWlybmfVDw5Y6ItZExML0eAOwFBjd6sLMrH/69DFRSfsCk4B5PUybAkwBGMqwJpRmZv1ReqBM0u7Aj4FLIuKt7tMjYnpEdEZEZwdDmlmjmfVBqVBL6qAI9KyIuLe1JZnZjigz+i3gNmBpRHy79SWZ2Y4o01IfBZwDHCdpUfo5ucV1mVk/NRwoi4gnAO2EWsysCfyJMrPMONRmmXGozTLjUJtlxjceNEt29g0CW8UttVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcZ3E7UBKZc7f7aCW2qzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDOlQy1pkKRnJD3QyoLMbMf0paW+GFjaqkLMrDlKhVrSGOAU4NbWlmNmO6psS30DcDnwbr0ZJE2RNF/S/C1sakpxZtZ3DUMt6VTg1YhY0Nt8ETE9IjojorODIU0r0Mz6pkxLfRRwuqQVwN3AcZJmtrQqM+u3hqGOiKsiYkxE7AucCcyJiLNbXpmZ9Yv/Tm2WmT59nzoiHgMea0klZtYUbqnNMuNQm2XGoTbLjENtlhmH2iwzvpuoAfCxey5oyXL3u/SplizX6nNLbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmfHdRIETR02suoTK7Yfv+pkLt9RmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWZKhVrSnpJmS3pe0lJJR7S6MDPrn7IfPrkReCgi/kzSrsCwFtZkZjugYagl7QEcA5wLEBGbgc2tLcvM+qtM93s8sA6YIekZSbdKGt59JklTJM2XNH8Lm5peqJmVUybUg4FDgZsjYhLwNnBl95kiYnpEdEZEZwdDmlymmZVVJtSrgFURMS89n00RcjNrQw1DHRFrgZWSDkgvHQ8saWlVZtZvZUe/LwJmpZHv5cB5rSvJzHZEqVBHxCKgs8W1mFkT+BNlZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0yUyrUki6VtFjSc5LukjS01YWZWf80DLWk0cBUoDMiDgYGAWe2ujAz65+y3e/BwG6SBgPDgNWtK8nMdkTDUEfEK8B1wMvAGuDNiHik+3ySpkiaL2n+FjY1v1IzK6VM93skcAYwHhgFDJd0dvf5ImJ6RHRGRGcHQ5pfqZmVUqb7PRl4MSLWRcQW4F7gyNaWZWb9VSbULwOHSxomScDxwNLWlmVm/VXmmnoeMBtYCDyb3jO9xXWZWT8NLjNTREwDprW4FjNrAn+izCwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8woIpq/UGkd8FKJWfcGXmt6Aa0zkOodSLXCwKq3HWodFxH79DShJaEuS9L8iOisrIA+Gkj1DqRaYWDV2+61uvttlhmH2iwzVYd6oP3z+oFU70CqFQZWvW1da6XX1GbWfFW31GbWZA61WWYqC7WkkyS9IGmZpCurqqMRSWMlzZW0RNJiSRdXXVMZkgZJekbSA1XX0htJe0qaLel5SUslHVF1Tb2RdGk6D56TdJekoVXX1F0loZY0CPge8BlgAnCWpAlV1FLCO8BlETEBOBz4ShvXWutiYGnVRZRwI/BQRBwIHEIb1yxpNDAV6IyIg4FBwJnVVrW9qlrqTwHLImJ5RGwG7gbOqKiWXkXEmohYmB5voDjpRldbVe8kjQFOAW6tupbeSNoDOAa4DSAiNkfE+mqramgwsJukwcAwYHXF9WynqlCPBlbWPF9FmwcFQNK+wCRgXrWVNHQDcDnwbtWFNDAeWAfMSJcKt0oaXnVR9UTEK8B1wMvAGuDNiHik2qq254GykiTtDvwYuCQi3qq6nnoknQq8GhELqq6lhMHAocDNETEJeBto5/GVkRQ9yvHAKGC4pLOrrWp7VYX6FWBszfMx6bW2JKmDItCzIuLequtp4CjgdEkrKC5rjpM0s9qS6loFrIqIrp7PbIqQt6vJwIsRsS4itgD3AkdWXNN2qgr108D+ksZL2pVisOH+imrplSRRXPMtjYhvV11PIxFxVUSMiYh9KfbrnIhou9YEICLWAislHZBeOh5YUmFJjbwMHC5pWDovjqcNB/YGV7HSiHhH0oXAwxQjiLdHxOIqainhKOAc4FlJi9JrV0fEgxXWlJOLgFnpl/ty4LyK66krIuZJmg0spPiryDO04UdG/TFRs8x4oMwsMw61WWYcarPMONRmmXGozTLjUJtlxqE2y8z/A1DIMnrGMqKcAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "I2APaCzDGeqP",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Loading

0 comments on commit af09a1e

Please sign in to comment.