-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Warvito
authored and
Warvito
committed
Apr 12, 2020
1 parent
8663539
commit af09a1e
Showing
4 changed files
with
2,331 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,251 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"name": "Receptive fields.ipynb", | ||
"provenance": [], | ||
"collapsed_sections": [] | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "YqTKIYLooHsq", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"# Comparing receptive fields" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "wM-m3Z8CiLXU", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "gf5wwqP3ozaN", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"import random as rn\n", | ||
"\n", | ||
"import matplotlib\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np\n", | ||
"import tensorflow as tf\n", | ||
"from tensorflow import keras\n", | ||
"from tensorflow import nn\n", | ||
"from tensorflow.keras import initializers\n", | ||
"from tensorflow.keras.utils import Progbar" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "jEkll1yno2Vb", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"# Defining random seeds\n", | ||
"random_seed = 42\n", | ||
"tf.random.set_seed(random_seed)\n", | ||
"np.random.seed(random_seed)\n", | ||
"rn.seed(random_seed)" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "yJ_JlzWco7ci", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"class MaskedConv2D(keras.layers.Layer):\n", | ||
" \"\"\"Convolutional layers with masks.\n", | ||
"\n", | ||
" Convolutional layers with simple implementation of masks type A and B for\n", | ||
" autoregressive models.\n", | ||
"\n", | ||
" Arguments:\n", | ||
" mask_type: one of `\"A\"` or `\"B\".`\n", | ||
" filters: Integer, the dimensionality of the output space\n", | ||
" (i.e. the number of output filters in the convolution).\n", | ||
" kernel_size: An integer or tuple/list of 2 integers, specifying the\n", | ||
" height and width of the 2D convolution window.\n", | ||
" Can be a single integer to specify the same value for\n", | ||
" all spatial dimensions.\n", | ||
" strides: An integer or tuple/list of 2 integers,\n", | ||
" specifying the strides of the convolution along the height and width.\n", | ||
" Can be a single integer to specify the same value for\n", | ||
" all spatial dimensions.\n", | ||
" Specifying any stride value != 1 is incompatible with specifying\n", | ||
" any `dilation_rate` value != 1.\n", | ||
" padding: one of `\"valid\"` or `\"same\"` (case-insensitive).\n", | ||
" kernel_initializer: Initializer for the `kernel` weights matrix.\n", | ||
" bias_initializer: Initializer for the bias vector.\n", | ||
" \"\"\"\n", | ||
"\n", | ||
" def __init__(self,\n", | ||
" mask_type,\n", | ||
" filters,\n", | ||
" kernel_size,\n", | ||
" strides=1,\n", | ||
" padding='same',\n", | ||
" kernel_initializer='glorot_uniform',\n", | ||
" bias_initializer='zeros'):\n", | ||
" super(MaskedConv2D, self).__init__()\n", | ||
"\n", | ||
" assert mask_type in {'A', 'B'}\n", | ||
" self.mask_type = mask_type\n", | ||
"\n", | ||
" self.filters = filters\n", | ||
" self.kernel_size = kernel_size\n", | ||
" self.strides = strides\n", | ||
" self.padding = padding.upper()\n", | ||
" self.kernel_initializer = initializers.get(kernel_initializer)\n", | ||
" self.bias_initializer = initializers.get(bias_initializer)\n", | ||
"\n", | ||
" def build(self, input_shape):\n", | ||
" self.kernel = self.add_weight('kernel',\n", | ||
" shape=(self.kernel_size,\n", | ||
" self.kernel_size,\n", | ||
" int(input_shape[-1]),\n", | ||
" self.filters),\n", | ||
" initializer=self.kernel_initializer,\n", | ||
" trainable=True)\n", | ||
"\n", | ||
" self.bias = self.add_weight('bias',\n", | ||
" shape=(self.filters,),\n", | ||
" initializer=self.bias_initializer,\n", | ||
" trainable=True)\n", | ||
"\n", | ||
" center = self.kernel_size // 2\n", | ||
"\n", | ||
" mask = np.ones(self.kernel.shape, dtype=np.float32)\n", | ||
" mask[center, center + (self.mask_type == 'B'):, :, :] = 0.\n", | ||
" mask[center + 1:, :, :, :] = 0.\n", | ||
"\n", | ||
" self.mask = tf.constant(mask, dtype=tf.float32, name='mask')\n", | ||
"\n", | ||
" def call(self, input):\n", | ||
" masked_kernel = tf.math.multiply(self.mask, self.kernel)\n", | ||
" x = nn.conv2d(input,\n", | ||
" masked_kernel,\n", | ||
" strides=[1, self.strides, self.strides, 1],\n", | ||
" padding=self.padding)\n", | ||
" x = nn.bias_add(x, self.bias)\n", | ||
" return x" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "jxCLMYc-FxdJ", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"def plot_receptive_field(model, data):\n", | ||
" out = model(data)\n", | ||
"\n", | ||
" with tf.GradientTape() as tape:\n", | ||
" tape.watch(data)\n", | ||
" prediction = model(data)\n", | ||
" loss = prediction[:,5,5,0]\n", | ||
"\n", | ||
" gradients = tape.gradient(loss, data)\n", | ||
"\n", | ||
" gradients = np.abs(gradients.numpy().squeeze())\n", | ||
" gradients = (gradients > 1e-8).astype('float32')\n", | ||
" gradients[5, 5] = 0.5\n", | ||
"\n", | ||
" plt.figure()\n", | ||
" plt.imshow(gradients)\n", | ||
" plt.title(f'Receptive field from pixel layers')\n", | ||
" plt.show()" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "0qpDtNuvo9NL", | ||
"colab_type": "code", | ||
"outputId": "926434e2-44bf-40d3-a39f-6c80ebc880b6", | ||
"colab": { | ||
"base_uri": "https://localhost:8080/", | ||
"height": 281 | ||
} | ||
}, | ||
"source": [ | ||
"height = 10\n", | ||
"width = 10\n", | ||
"n_channel = 1\n", | ||
"inputs = keras.layers.Input(shape=(height, width, n_channel))\n", | ||
"x = MaskedConv2D(mask_type='A', filters=1, kernel_size=3, strides=1)(inputs)\n", | ||
"x = MaskedConv2D(mask_type='B', filters=1, kernel_size=3, strides=1)(x)\n", | ||
"x = MaskedConv2D(mask_type='B', filters=1, kernel_size=3, strides=1)(x)\n", | ||
"\n", | ||
"model = keras.Model(inputs=inputs, outputs=x)\n", | ||
"data = tf.random.normal((1,10,10,1))\n", | ||
"\n", | ||
"plot_receptive_field(model, data)" | ||
], | ||
"execution_count": 0, | ||
"outputs": [ | ||
{ | ||
"output_type": "display_data", | ||
"data": { | ||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAARBElEQVR4nO3dfZBV9X3H8fdHWEGQKkbzwEOQiUYltoLdxKdqjWI1PmaaptVEU50axjaKWic+dVqcxDTTiSY6k2iCD9gIVVPiRGuNmgjGaCsjIBkFtCWIgkDEKgo2Aazf/nF+K5dl796zy72cuz8/r5mdufeec8/5nofP/n7nt/eeVURgZvnYpeoCzKy5HGqzzDjUZplxqM0y41CbZcahNsuMQ12SpKsl3dqC5X5I0uOSNki6vi/rkfSYpPPrTNtXUkgaXGf6AZIWpfVO3ZFtaDZJX5T0SBOWs0LS5DrT7pB07Y6uox31eMB3NkkrgA8B/wdsBB4CLoyIjRXVcywwMyLGdL0WEf/YotVNAV4Dfi927ocGLgfmRsTEnbjOUiJiFjCr6joGqnZqqU+LiN2BicAk4KqK69lZxgFLdnKgu9a7uN5ESYN2Yi3vC/V6Tc3WTqEGICLWAg9ThBsASYdL+g9J6yX9KrWkXdP2kjRD0mpJb0j6Sc20U1MXc316/x/UTFsh6SpJS9L7ZkgaKmk48FNglKSN6WeUpGskzUzv/amkC2vrTnX9aXp8oKSfSXpd0guS/rynbZV0B/CXwOVpPZNr19No27sta5Ck6yS9Jmk5cEq9fSxpDvBp4LtpvR9P3dGbJT0o6W3g05IOSl389ZIWSzq9tnZJN6V9sVHSk5I+LOmGtD+flzSplxpC0lRJy1PN35K0S5p2rqQn0uMj0/Sx6fkhafkHpud1j3FZkkZKekDSurTsBySNSdM+L2lBt/n/VtJ96fGQtN9flvQbSd+XtFuadqykVZKukLQWmCFp77T89en8+GXXdjdNRFT+A6wAJqfHY4BngRvT89HA/wAnU/wSOiE93ydN/3fgHmAk0AH8cXp9EvAqcBgwiCI8K4AhNet8DhgL7AU8CVybph0LrOpW4zUUXXKALwFP1kybAKwHhgDDgZXAeRSXN5MoutcT6mz7HV3r7WE9jbb9MeD89PgC4Pma7ZkLBDC4znrfe29NHW8CR6V1jQCWAVcDuwLHARuAA2rmfw34Q2AoMAd4Me2bQcC1FN37esc8Uo17AR8F/qtmW84FnqiZ9xtp+bulc+PCPhzjyY32O/AB4HPAsLTd/wr8JE0bArwOHFTz3meAz6XH3wHuT9sxAvg34Js159E7wD+l5ewGfBP4PsW52gEcDaipeao60DU7f2M6aQJ4FNgzTbsCuLPb/A+nA/gR4F1gZA/LvBn4erfXXmBr6FcAF9RMOxn4dclQjwDeBsbVnHS3p8d/Afyy23t/AEzrR6jrbnsPoZ7TbXv+hL6H+oc1z48G1gK71Lx2F3BNzfy31Ey7CFha8/z3gfW9HPMATqp5/jfAo+nxuWwb6g5gAUWgH+oKQclj3DDUPUybCLzR7Vz6Rnr8CeANipAqnQcfq5n3CODFmvNoMzC0ZvrXgPuA/VqVp3bqfn82IkZQ7IgDgb3T6+OAz6fuynpJ64E/ogj0WOD1iHijh+WNAy7r9r6xwKiaeVbWPH6p27S6ImIDRQ/hzPTSWWwd2BkHHNZtvV8EPlxm2T1sQ71t724U229PX9W+fxSwMiLe7bbM0TXPf1Pz+Lc9PN+9D+uru/8jYgtFCA8Gro+UDsod44YkDZP0A0kvSXoLeBzYU1vHFf4Z+IIkAecAP4qITcA+FK37gpr1P5Re77IuIn5X8/xbFD2gR9Klx5V9qbWMthj9rhURv0jXmtcBn6U48HdGxJe7zyvpI8BekvaMiPXdJq+k+O36jV5WN7bm8UeB1V1llCj1LmCapMcpup9za9b7i4g4ocQyGqm77T1Yw/bb01e1270aGCtpl5pgd3WTm2UsWwfravf/NiSNBqYBM4DrJX0yharMMS7jMuAA4LCIWCtpIkUXWwAR8ZSkzRS9ly+kHyguP34LfCIiXqmz7G3OpdQgXEbxy+hgYI6kpyPi0R3chve0U0td6wbgBEmHADOB0ySdmAaDhqYBiDERsYZiUOumNNjRIemYtIxbgAskHabCcEmnSBpRs56vSBojaS/g7yiuzaFocT4gaY9eanyQoqX4GnBPzYn/APBxSeekejokfVLSQf3YD3W3vYd5fwRMTdszEtjRFmAe8L8Ug3gdKgboTgPu3sHl1vpqOm5jgYvZuv/fk1rHO4DbgL+i+OX19TS5zDEuYwRFONenc2FaD/P8EPgusCUingBIx/wW4DuSPpjqHS3pxHorSgN7+6XtepPiz7jv1pu/P9oy1BGxjmIn/kNErATOoBiwWUfx2/mrbK39HGALxSDRq8AlaRnzgS9THIg3KLo853Zb1b8AjwDLgV9TDO4QEc9TtMTLU7dqu+5cainuBSan5XS9voHievZMipZnLVsHSvq6Hxpte61bKK63fwUsTLX1W0RspgjxZyhapJuAL6V90yz3UVwrL6K4nLmth3mmAh8E/j51u88DzpN0dMljXMYNFINYrwFPUXShu7uTovs/s9vrV6T1PpW67j+naPXr2T/NsxH4T+CmiJjby/x91jXg8L6j4gMv50fEz6uu5f1IUgD7R8SyqmspI/2Z6lXg0Ij476rr6U1bttRmbeivgafbPdDQhgNlZu0m9epEMXDb9t633W+zXLn7bZaZlnS/d9WQGMrwVizazIDf8TabY5N6mtaSUA9lOIfp+FYs2syAeb18VsXdb7PMONRmmXGozTLjUJtlxqE2y4xDbZaZUqGWdJKKe20ta8WXus2seRqGOt394XsUX8GbAJwlaUKrCzOz/inTUn8KWBYRy9N3bO+m+I6vmbWhMqEezbb3klrFtvepAkDSFEnzJc3fwqZm1WdmfdS0gbKImB4RnRHR2dH3m3yYWZOUCfUrbHtDuzHpNTNrQ2VC/TSwv6TxknaluPfW/a0ty8z6q+G3tCLiHRX/YuZhiv+CcHtE1P0fTGZWrVJfvYyIByluiWtmbc6fKDPLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzpf6XVu4eXr2o6hIsYyeOmrhT1+eW2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLTMNQSxoraa6kJZIWS7p4ZxRmZv1T5sMn7wCXRcRCSSOABZJ+FhFLWlybmfVDw5Y6ItZExML0eAOwFBjd6sLMrH/69DFRSfsCk4B5PUybAkwBGMqwJpRmZv1ReqBM0u7Aj4FLIuKt7tMjYnpEdEZEZwdDmlmjmfVBqVBL6qAI9KyIuLe1JZnZjigz+i3gNmBpRHy79SWZ2Y4o01IfBZwDHCdpUfo5ucV1mVk/NRwoi4gnAO2EWsysCfyJMrPMONRmmXGozTLjUJtlxjceNEt29g0CW8UttVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcZ3E7UBKZc7f7aCW2qzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDOlQy1pkKRnJD3QyoLMbMf0paW+GFjaqkLMrDlKhVrSGOAU4NbWlmNmO6psS30DcDnwbr0ZJE2RNF/S/C1sakpxZtZ3DUMt6VTg1YhY0Nt8ETE9IjojorODIU0r0Mz6pkxLfRRwuqQVwN3AcZJmtrQqM+u3hqGOiKsiYkxE7AucCcyJiLNbXpmZ9Yv/Tm2WmT59nzoiHgMea0klZtYUbqnNMuNQm2XGoTbLjENtlhmH2iwzvpuoAfCxey5oyXL3u/SplizX6nNLbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWYcarPMONRmmfHdRIETR02suoTK7Yfv+pkLt9RmmXGozTLjUJtlxqE2y4xDbZYZh9osMw61WWZKhVrSnpJmS3pe0lJJR7S6MDPrn7IfPrkReCgi/kzSrsCwFtZkZjugYagl7QEcA5wLEBGbgc2tLcvM+qtM93s8sA6YIekZSbdKGt59JklTJM2XNH8Lm5peqJmVUybUg4FDgZsjYhLwNnBl95kiYnpEdEZEZwdDmlymmZVVJtSrgFURMS89n00RcjNrQw1DHRFrgZWSDkgvHQ8saWlVZtZvZUe/LwJmpZHv5cB5rSvJzHZEqVBHxCKgs8W1mFkT+BNlZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0y41CbZcahNsuMQ22WGYfaLDMOtVlmHGqzzDjUZplxqM0yUyrUki6VtFjSc5LukjS01YWZWf80DLWk0cBUoDMiDgYGAWe2ujAz65+y3e/BwG6SBgPDgNWtK8nMdkTDUEfEK8B1wMvAGuDNiHik+3ySpkiaL2n+FjY1v1IzK6VM93skcAYwHhgFDJd0dvf5ImJ6RHRGRGcHQ5pfqZmVUqb7PRl4MSLWRcQW4F7gyNaWZWb9VSbULwOHSxomScDxwNLWlmVm/VXmmnoeMBtYCDyb3jO9xXWZWT8NLjNTREwDprW4FjNrAn+izCwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8w41GaZcajNMuNQm2XGoTbLjENtlhmH2iwzDrVZZhxqs8woIpq/UGkd8FKJWfcGXmt6Aa0zkOodSLXCwKq3HWodFxH79DShJaEuS9L8iOisrIA+Gkj1DqRaYWDV2+61uvttlhmH2iwzVYd6oP3z+oFU70CqFQZWvW1da6XX1GbWfFW31GbWZA61WWYqC7WkkyS9IGmZpCurqqMRSWMlzZW0RNJiSRdXXVMZkgZJekbSA1XX0htJe0qaLel5SUslHVF1Tb2RdGk6D56TdJekoVXX1F0loZY0CPge8BlgAnCWpAlV1FLCO8BlETEBOBz4ShvXWutiYGnVRZRwI/BQRBwIHEIb1yxpNDAV6IyIg4FBwJnVVrW9qlrqTwHLImJ5RGwG7gbOqKiWXkXEmohYmB5voDjpRldbVe8kjQFOAW6tupbeSNoDOAa4DSAiNkfE+mqramgwsJukwcAwYHXF9WynqlCPBlbWPF9FmwcFQNK+wCRgXrWVNHQDcDnwbtWFNDAeWAfMSJcKt0oaXnVR9UTEK8B1wMvAGuDNiHik2qq254GykiTtDvwYuCQi3qq6nnoknQq8GhELqq6lhMHAocDNETEJeBto5/GVkRQ9yvHAKGC4pLOrrWp7VYX6FWBszfMx6bW2JKmDItCzIuLequtp4CjgdEkrKC5rjpM0s9qS6loFrIqIrp7PbIqQt6vJwIsRsS4itgD3AkdWXNN2qgr108D+ksZL2pVisOH+imrplSRRXPMtjYhvV11PIxFxVUSMiYh9KfbrnIhou9YEICLWAislHZBeOh5YUmFJjbwMHC5pWDovjqcNB/YGV7HSiHhH0oXAwxQjiLdHxOIqainhKOAc4FlJi9JrV0fEgxXWlJOLgFnpl/ty4LyK66krIuZJmg0spPiryDO04UdG/TFRs8x4oMwsMw61WWYcarPMONRmmXGozTLjUJtlxqE2y8z/A1DIMnrGMqKcAAAAAElFTkSuQmCC\n", | ||
"text/plain": [ | ||
"<Figure size 432x288 with 1 Axes>" | ||
] | ||
}, | ||
"metadata": { | ||
"tags": [], | ||
"needs_background": "light" | ||
} | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "I2APaCzDGeqP", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
} | ||
] | ||
} |
Oops, something went wrong.