import inspect

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

from spektral.utils.keras import (
    deserialize_kwarg,
    is_keras_kwarg,
    is_layer_kwarg,
    serialize_kwarg,
)


class SRCPool(Layer):
    r"""
    A general class for graph pooling layers based on the "Select, Reduce,
    Connect" framework presented in:

    > [Understanding Pooling in Graph Neural Networks.](https://arxiv.org/abs/2110.05292)<br>
    > Daniele Grattarola et al.

    This layer computes:
    $$
    \begin{align}
        & \mathcal{S} = \left\{\mathcal{S}_k\right\}_{k=1:K} = \textrm{Sel}(\mathcal{G}) \\
        & \mathcal{X}' = \left\{\textrm{Red}( \mathcal{G}, \mathcal{S}_k )\right\}_{k=1:K} \\
        & \mathcal{E}' = \left\{\textrm{Con}( \mathcal{G}, \mathcal{S}_k, \mathcal{S}_l )\right\}_{k,l=1:K} \\
    \end{align}
    $$

    where \(\textrm{Sel}\) is a node-equivariant selection function that computes
    the supernode assignments \(\mathcal{S}_k\), \(\textrm{Red}\) is a
    permutation-invariant function that reduces the supernodes into the new node
    attributes, and \(\textrm{Con}\) is a permutation-invariant connection
    function that computes the links between the pooled nodes.

    By extending this class, it is possible to create any pooling layer in the
    SRC formalism.

    **Input**

    - `x`: Tensor of shape `([batch], N, F)` representing node features;
    - `a`: Tensor or SparseTensor of shape `([batch], N, N)` representing the
    adjacency matrix;
    - `i`: (optional) Tensor of integers with shape `(N, )` representing the
    batch index;

    **Output**

    - `x_pool`: Tensor of shape `([batch], K, F)`, representing the node
    features of the output. `K` is the number of output nodes and depends on the
    specific pooling strategy;
    - `a_pool`: Tensor or SparseTensor of shape `([batch], K, K)` representing
    the adjacency matrix of the output;
    - `i_pool`: (only if `i` was given as input) Tensor of integers with shape
    `(K, )` representing the batch index of the output;
    - `s`: (if `return_selection=True`) Tensor or SparseTensor representing the
    supernode assignments;

    **API**

    - `pool(x, a, i, **kwargs)`: pools the graph and returns the reduced node
    features and adjacency matrix. If the batch index `i` is not `None`, a
    reduced version of `i` will be returned as well.
    Any given `kwargs` will be passed as keyword arguments to `select()`,
    `reduce()` and `connect()` if any matching key is found.
    The mandatory arguments of `pool()` **must** be computed in `call()` by
    calling `self.get_inputs(inputs)`;
    - `select(x, a, i, **kwargs)`: computes supernode assignments mapping the
    nodes of the input graph to the nodes of the output;
    - `reduce(x, s, **kwargs)`: reduces the supernodes to form the nodes of the
    pooled graph;
    - `connect(a, s, **kwargs)`: connects the reduced supernodes;
    - `reduce_index(i, s, **kwargs)`: helper function to reduce the batch index
    (only called if `i` is given as input).

    When overriding any function of the API, it is possible to access the true
    number of nodes in the input graph as a Tensor via the instance attribute
    `self.n_nodes` (populated by `self.get_inputs()` at the beginning of
    `call()`).

    **Arguments**:

    - `return_selection`: if `True`, the Tensor used to represent supernode
    assignments will be returned with `x_pool`, `a_pool`, and `i_pool`;
    """

    def __init__(self, return_selection=False, **kwargs):
        # kwargs for the Layer class are handled automatically
        super().__init__(**{k: v for k, v in kwargs.items() if is_keras_kwarg(k)})
        self.supports_masking = True
        self.return_selection = return_selection

        # *_regularizer, *_constraint, *_initializer, activation, and use_bias
        # are dealt with automatically if passed to the constructor
        self.kwargs_keys = []
        for key in kwargs:
            if is_layer_kwarg(key):
                attr = kwargs[key]
                attr = deserialize_kwarg(key, attr)
                self.kwargs_keys.append(key)
                setattr(self, key, attr)

        # Signature of the SRC functions
        self.sel_signature = inspect.signature(self.select).parameters
        self.red_signature = inspect.signature(self.reduce).parameters
        self.con_signature = inspect.signature(self.connect).parameters
        self.i_red_signature = inspect.signature(self.reduce_index).parameters

        self._n_nodes = None

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        # Always start the call() method with get_inputs(inputs) to set self.n_nodes
        x, a, i = self.get_inputs(inputs)
        return self.pool(x, a, i, **kwargs)

    def pool(self, x, a, i, **kwargs):
        """
        This is the core method of the SRC class, which runs a full pass of
        selection, reduction and connection.

        It is usually not necessary to modify this function. Any previous/shared
        operations should be done in `call()` and their results can be passed to
        the three SRC functions via keyword arguments (any kwargs given to this
        function will be matched to the signature of `select()`, `reduce()` and
        `connect()` and propagated as input to the three functions).
        Any pooling logic should go in the SRC functions themselves.

        :param x: Tensor of shape `([batch], N, F)`;
        :param a: Tensor or SparseTensor of shape `([batch], N, N)`;
        :param i: only in single/disjoint mode, Tensor of integers with shape
        `(N, )`; otherwise, `None`;
        :param kwargs: additional keyword arguments for `select()`, `reduce()`
        and `connect()`. Any matching kwargs will be passed to each of the three
        functions.
        :return:
            - `x_pool`: Tensor of shape `([batch], K, F)`, where `K` is the
            number of output nodes and depends on the pooling strategy;
            - `a_pool`: Tensor or SparseTensor of shape `([batch], K, K)`;
            - `i_pool`: (only if `i` is not `None`) Tensor of integers with shape
            `(K, )`;
        """
        # Select
        sel_kwargs = self._get_kwargs(x, a, i, self.sel_signature, kwargs)
        s = self.select(x, a, i, **sel_kwargs)

        # Reduce
        red_kwargs = self._get_kwargs(x, a, i, self.red_signature, kwargs)
        x_pool = self.reduce(x, s, **red_kwargs)

        # Index reduce
        i_red_kwargs = self._get_kwargs(x, a, i, self.i_red_signature, kwargs)
        i_pool = self.reduce_index(i, s, **i_red_kwargs) if i is not None else None

        # Connect
        con_kwargs = self._get_kwargs(x, a, i, self.con_signature, kwargs)
        a_pool = self.connect(a, s, **con_kwargs)

        return self.get_outputs(x_pool, a_pool, i_pool, s)
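
    # Note (illustrative, not from the original source): thanks to the signature
    # matching in _get_kwargs(), a subclass that declares an optional argument
    # in one of the SRC functions, e.g.
    #
    #     def select(self, x, a, i, ratio=None):
    #         ...
    #
    # will have `ratio` filled in automatically when the layer is called as
    # `layer([x, a, i], ratio=0.5)`, since Keras forwards extra keyword
    # arguments through call() to pool().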

    def select(self, x, a, i, **kwargs):
        """
        Selection function. Given the graph, computes the supernode assignments
        that will eventually be mapped to the `K` nodes of the pooled graph.
        Supernode assignments are usually represented as a dense matrix of shape
        `(N, K)` or sparse indices of shape `(K, )`.

        :param x: Tensor of shape `([batch], N, F)`;
        :param a: Tensor or SparseTensor (depending on the implementation of the
        SRC functions) of shape `([batch], N, N)`;
        :param i: Tensor of integers with shape `(N, )` or `None`;
        :param kwargs: additional keyword arguments.
        :return: Tensor representing the supernode assignments.
        """
        # Default: identity selection, i.e., every node is kept as its own
        # supernode. Note that tf.range expects a scalar limit, so we take the
        # first entry of the shape of i.
        return tf.range(tf.shape(i)[0])
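
    # Illustrative example of the two common selection formats, for a graph
    # with N = 4 nodes pooled into K = 2 supernodes:
    #
    #     # Dense assignment matrix, shape (N, K)
    #     s_dense = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
    #     # Sparse, index-based selection keeping nodes 0 and 2, shape (K, )
    #     s_sparse = [0, 2]
    #
    # The default reduce()/connect() implementations below assume the
    # index-based format.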

    def reduce(self, x, s, **kwargs):
        """
        Reduction function. Given a selection, reduces the supernodes to form
        the nodes of the new graph.

        :param x: Tensor of shape `([batch], N, F)`;
        :param s: Tensor representing supernode assignments, as computed by
        `select()`;
        :param kwargs: additional keyword arguments; when overriding this
        function, any keyword argument defined explicitly as `key=None` will be
        automatically filled in when calling `pool(key=value)`.
        :return: Tensor of shape `([batch], K, F)` representing the node
        attributes of the pooled graph.
        """
        return tf.gather(x, s)

    def connect(self, a, s, **kwargs):
        """
        Connection function. Given a selection, connects the nodes of the pooled
        graph.

        :param a: Tensor or SparseTensor of shape `([batch], N, N)`;
        :param s: Tensor representing supernode assignments, as computed by
        `select()`;
        :param kwargs: additional keyword arguments; when overriding this
        function, any keyword argument defined explicitly as `key=None` will be
        automatically filled in when calling `pool(key=value)`.
        :return: Tensor or SparseTensor of shape `([batch], K, K)` representing
        the adjacency matrix of the pooled graph.
        """
        return sparse_connect(a, s, self.n_nodes)

    def reduce_index(self, i, s, **kwargs):
        """
        Helper function to reduce the batch index `i`. Given a selection,
        returns a new batch index for the pooled graph. This is only called by
        `pool()` when `i` is given as input to the layer.

        :param i: Tensor of integers with shape `(N, )`;
        :param s: Tensor representing supernode assignments, as computed by
        `select()`;
        :param kwargs: additional keyword arguments; when overriding this
        function, any keyword argument defined explicitly as `key=None` will be
        automatically filled in when calling `pool(key=value)`.
        :return: Tensor of integers of shape `(K, )`.
        """
        return tf.gather(i, s)

    @staticmethod
    def _get_kwargs(x, a, i, signature, kwargs):
        # Match the optional parameters in `signature` against the layer inputs
        # and the kwargs passed to pool():
        # - parameters without a default (and **kwargs) are skipped, since
        #   pool() passes the mandatory arguments positionally;
        # - optional parameters named `x`, `a`, or `i` are filled with the
        #   corresponding input;
        # - any other optional parameter must be provided via kwargs.
        output = {}
        for k in signature.keys():
            if signature[k].default is inspect.Parameter.empty or k == "kwargs":
                pass
            elif k == "x":
                output[k] = x
            elif k == "a":
                output[k] = a
            elif k == "i":
                output[k] = i
            elif k in kwargs:
                output[k] = kwargs[k]
            else:
                raise ValueError("Missing key {} for signature {}".format(k, signature))

        return output

    def get_inputs(self, inputs):
        if len(inputs) == 3:
            x, a, i = inputs
            if K.ndim(i) == 2:
                i = i[:, 0]
            assert K.ndim(i) == 1, "i must have rank 1"
        elif len(inputs) == 2:
            x, a = inputs
            i = None
        else:
            raise ValueError(
                "Expected 2 or 3 input tensors (x, a, i), got {}.".format(len(inputs))
            )

        self.n_nodes = tf.shape(x)[-2]

        return x, a, i

    def get_outputs(self, x_pool, a_pool, i_pool, s):
        output = [x_pool, a_pool]
        if i_pool is not None:
            output.append(i_pool)
        if self.return_selection:
            output.append(s)

        return output

    def get_config(self):
        config = {
            "return_selection": self.return_selection,
        }
        for key in self.kwargs_keys:
            config[key] = serialize_kwarg(key, getattr(self, key))
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_mask(self, inputs, mask=None):
        # After pooling all nodes are always valid
        return None

    @property
    def n_nodes(self):
        if self._n_nodes is None:
            raise ValueError(
                "self.n_nodes has not been defined. Have you called "
                "self.get_inputs(inputs) at the beginning of call()?"
            )
        return self._n_nodes

    @n_nodes.setter
    def n_nodes(self, value):
        self._n_nodes = value

    @n_nodes.deleter
    def n_nodes(self):
        self._n_nodes = None


def sparse_connect(A, S, N):
    """
    Default Connect function for sparse, index-based selections: keeps only
    the edges of `A` whose endpoints are both in `S`, re-indexed to the pooled
    nodes. Returns a SparseTensor of shape `(K, K)`.
    """
    N_sel = tf.cast(tf.shape(S), tf.int64)[0]

    # Map each selected node to its new index in [0, K); non-selected nodes
    # are mapped to -1
    m = tf.scatter_nd(S[:, None], tf.range(N_sel) + 1, (N,)) - 1

    # Keep only the edges whose endpoints are both selected
    row, col = A.indices[:, 0], A.indices[:, 1]
    r_mask = tf.gather(m, row)
    c_mask = tf.gather(m, col)
    mask_total = (r_mask >= 0) & (c_mask >= 0)
    r_new = tf.boolean_mask(r_mask, mask_total)
    c_new = tf.boolean_mask(c_mask, mask_total)
    v_new = tf.boolean_mask(A.values, mask_total)

    output = tf.SparseTensor(
        values=v_new, indices=tf.stack((r_new, c_new), 1), dense_shape=(N_sel, N_sel)
    )
    return tf.sparse.reorder(output)
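

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a minimal SRC subclass.
# `TopKNormPool` is a hypothetical example showing how select() can be
# overridden while reusing the default reduce()/connect()/reduce_index() of
# SRCPool, which assume index-based selections. SRCPool.call() already calls
# get_inputs() and pool(), so no call() override is needed. For simplicity,
# the selection ignores graph boundaries in disjoint mode; a real layer would
# select per graph.
# ---------------------------------------------------------------------------
class TopKNormPool(SRCPool):
    def __init__(self, k, **kwargs):
        super().__init__(**kwargs)
        self.k = k

    def select(self, x, a, i, **kwargs):
        # Keep the k nodes with the largest L2 norm of their features
        scores = tf.norm(x, axis=-1)
        _, s = tf.math.top_k(scores, k=self.k)
        return tf.sort(s)

    def get_config(self):
        return {"k": self.k, **super().get_config()}


# Usage sketch in disjoint mode (random data, for illustration only):
#
#     x = tf.random.normal((10, 3))                 # 10 nodes, 3 features
#     a = tf.sparse.from_dense(tf.eye(10))          # (10, 10) adjacency
#     i = tf.zeros(10, dtype=tf.int64)              # all nodes in graph 0
#     x_pool, a_pool, i_pool = TopKNormPool(k=4)([x, a, i])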