diff --git a/clrs/_src/baselines.py b/clrs/_src/baselines.py
index 03f343af..61f80d20 100644
--- a/clrs/_src/baselines.py
+++ b/clrs/_src/baselines.py
@@ -72,8 +72,6 @@ def __init__(
       decode_hints: bool,
       decode_diffs: bool,
       kind: str,
-      inf_bias: bool,
-      inf_bias_edge: bool,
       use_lstm: bool,
       dropout_prob: float,
       nb_heads: int,
@@ -85,8 +83,6 @@ def __init__(
     self._dropout_prob = dropout_prob
 
     self.spec = spec
-    self.inf_bias = inf_bias
-    self.inf_bias_edge = inf_bias_edge
     self.hidden_dim = hidden_dim
     self.encode_hints = encode_hints
     self.decode_hints = decode_hints
@@ -191,12 +187,16 @@ def __call__(self, features: _Features, repred: bool):
         nb_nodes = inp.data.shape[1]
         break
 
-    self._construct_encoders_decoders()
-    self._construct_processor()
+    # Construct encoders and decoders.
+    (self.encoders, self.decoders,
+     self.diff_decoders) = self._construct_encoders_decoders()
+    self.processor = processors.construct_processor(
+        kind=self.kind, hidden_dim=self.hidden_dim, nb_heads=self.nb_heads)
 
     nb_mp_steps = max(1, hints[0].data.shape[0] - 1)
     hiddens = jnp.zeros((self.batch_size, nb_nodes, self.hidden_dim))
 
+    # Optionally construct LSTM.
     if self.use_lstm:
       self.lstm = hk.LSTM(
           hidden_size=self.hidden_dim,
@@ -258,58 +258,25 @@ def invert(d):
 
   def _construct_encoders_decoders(self):
     """Constructs encoders and decoders."""
-    self.encoders = {}
-    self.decoders = {}
+    encoders_ = {}
+    decoders_ = {}
 
     for name, (stage, loc, t) in self.spec.items():
       if stage == _Stage.INPUT or (stage == _Stage.HINT and self.encode_hints):
         # Build input encoders.
-        self.encoders[name] = encoders.construct_encoders(
+        encoders_[name] = encoders.construct_encoders(
            loc, t, hidden_dim=self.hidden_dim)
 
       if stage == _Stage.OUTPUT or (stage == _Stage.HINT and self.decode_hints):
         # Build output decoders.
-        self.decoders[name] = decoders.construct_decoders(
+        decoders_[name] = decoders.construct_decoders(
            loc, t, hidden_dim=self.hidden_dim, nb_dims=self.nb_dims[name])
 
     if self.decode_diffs:
       # Optionally build diff decoders.
-      self.diff_decoders = decoders.construct_diff_decoders()
-
-  def _construct_processor(self):
-    """Constructs processor."""
-
-    if self.kind in ['deepsets', 'mpnn', 'pgn']:
-      self.mpnn = processors.MPNN(
-          out_size=self.hidden_dim,
-          mid_act=jax.nn.relu,
-          activation=jax.nn.relu,
-          reduction=jnp.max,
-          msgs_mlp_sizes=[
-              self.hidden_dim,
-              self.hidden_dim,
-          ])
-    elif self.kind in ['gat', 'gat_full']:
-      self.mpnn = processors.GAT(
-          out_size=self.hidden_dim,
-          nb_heads=self.nb_heads,
-          activation=jax.nn.relu,
-          residual=True)
-    elif self.kind in ['gatv2', 'gatv2_full']:
-      self.mpnn = processors.GATv2(
-          out_size=self.hidden_dim,
-          nb_heads=self.nb_heads,
-          activation=jax.nn.relu,
-          residual=True)
-    elif self.kind == 'memnet_full' or self.kind == 'memnet_masked':
-      self.memnet = processors.MemNet(
-          vocab_size=self.hidden_dim,
-          embedding_size=16,
-          sentence_size=self.hidden_dim,
-          linear_output_size=self.hidden_dim,
-          memory_size=128,
-          num_hops=1,
-          apply_embeddings=True)
+      diff_decoders = decoders.construct_diff_decoders()
+
+    return encoders_, decoders_, diff_decoders
 
   def _one_step_pred(
       self,
@@ -329,7 +296,6 @@ def _one_step_pred(
         jnp.expand_dims(jnp.eye(nb_nodes), 0), self.batch_size, axis=0)
 
     # ENCODE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-    # Encode node/edge/graph features from inputs and (optionally) hints.
     trajectories = [inputs]
 
     if self.encode_hints:
@@ -348,33 +314,15 @@ def _one_step_pred(
           raise Exception(f'Failed to process {dp}') from e
 
     # PROCESS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-    if self.kind == 'deepsets':
-      adj_mat = jnp.repeat(
-          jnp.expand_dims(jnp.eye(nb_nodes), 0), self.batch_size, axis=0)
-    elif (self.kind == 'mpnn' or self.kind == 'gat_full' or
-          self.kind == 'gatv2_full' or self.kind == 'memnet_full'):
-      adj_mat = jnp.ones_like(adj_mat)
-    elif (self.kind == 'pgn' or self.kind == 'gat' or self.kind == 'gatv2' or
-          self.kind == 'memnet_masked'):
-      adj_mat = (adj_mat > 0.0) * 1.0
-    else:
-      raise ValueError('Unsupported kind of model')
-
-    z = jnp.concatenate([node_fts, hidden], axis=-1)
-    if self.kind == 'memnet_full' or self.kind == 'memnet_masked':
-      node_and_graph_fts = jnp.concatenate(
-          [node_fts, graph_fts[:, None]], axis=1)
-      edge_fts_padded = jnp.pad(edge_fts * adj_mat[..., None],
-                                ((0, 0), (0, 1), (0, 1), (0, 0)))
-      nxt_hidden = jax.vmap(self.memnet, (1), 1)(node_and_graph_fts,
-                                                 edge_fts_padded)
-      # Broadcast hidden state corresponding to graph features across the nodes.
-      nxt_hidden = nxt_hidden[:, :-1] + nxt_hidden[:, -1:]
-    else:
-      nxt_hidden = self.mpnn(z, edge_fts, graph_fts,
-                             (adj_mat > 0.0).astype('float32'))
-
+    nxt_hidden = self.processor(
+        node_fts,
+        edge_fts,
+        graph_fts,
+        adj_mat,
+        hidden,
+        batch_size=self.batch_size,
+        nb_nodes=nb_nodes,
+    )
     nxt_hidden = hk.dropout(hk.next_rng_key(), self._dropout_prob, nxt_hidden)
 
     if self.use_lstm:
@@ -384,10 +332,9 @@ def _one_step_pred(
     else:
       nxt_lstm_state = None
 
-    h_t = jnp.concatenate([z, nxt_hidden], axis=-1)
+    h_t = jnp.concatenate([node_fts, hidden, nxt_hidden], axis=-1)
 
     # DECODE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-    # Decode features and (optionally) hints.
     hint_preds, output_preds = decoders.decode_fts(
         decoders=self.decoders,
@@ -396,8 +343,8 @@ def _one_step_pred(
         adj_mat=adj_mat,
         edge_fts=edge_fts,
         graph_fts=graph_fts,
-        inf_bias=self.inf_bias,
-        inf_bias_edge=self.inf_bias_edge,
+        inf_bias=self.processor.inf_bias,
+        inf_bias_edge=self.processor.inf_bias_edge,
     )
 
     # Optionally decode diffs.
@@ -447,14 +394,6 @@ def __init__(
     self._freeze_processor = freeze_processor
     self.opt = optax.adam(learning_rate)
 
-    if kind == 'pgn_mask':
-      inf_bias = True
-      inf_bias_edge = True
-      kind = 'pgn'
-    else:
-      inf_bias = False
-      inf_bias_edge = False
-
     self.nb_dims = {}
     for inp in dummy_trajectory.features.inputs:
       self.nb_dims[inp.name] = inp.data.shape[-1]
@@ -465,8 +404,8 @@ def __init__(
 
     def _use_net(*args, **kwargs):
       return Net(spec, hidden_dim, encode_hints, decode_hints, decode_diffs,
-                 kind, inf_bias, inf_bias_edge, use_lstm, dropout_prob,
-                 nb_heads, self.nb_dims)(*args, **kwargs)
+                 kind, use_lstm, dropout_prob, nb_heads, self.nb_dims)(*args,
+                                                                       **kwargs)
 
     self.net_fn = hk.transform(_use_net)
     self.net_fn_apply = jax.jit(self.net_fn.apply, static_argnums=3)
diff --git a/clrs/_src/encoders.py b/clrs/_src/encoders.py
index e1a52ed5..1ae8d57a 100644
--- a/clrs/_src/encoders.py
+++ b/clrs/_src/encoders.py
@@ -50,11 +50,11 @@ def preprocess(dp: _DataPoint, nb_nodes: int) -> _Array:
 def accum_adj_mat(dp: _DataPoint, data: _Array, adj_mat: _Array) -> _Array:
   """Accumulates adjacency matrix."""
   if dp.location == _Location.NODE and dp.type_ == _Type.POINTER:
-    adj_mat += ((data + jnp.transpose(data, (0, 2, 1))) > 0.0).astype('float32')
+    adj_mat += ((data + jnp.transpose(data, (0, 2, 1))) > 0.0)
   elif dp.location == _Location.EDGE and dp.type_ == _Type.MASK:
-    adj_mat += (data > 0.0).astype('float32')
+    adj_mat += (data > 0.0)
 
-  return adj_mat
+  return (adj_mat > 0.).astype('float32')
 
 
 def accum_edge_fts(encoders, dp: _DataPoint, data: _Array,
diff --git a/clrs/_src/processors.py b/clrs/_src/processors.py
index 65dde830..5ede62c4 100644
--- a/clrs/_src/processors.py
+++ b/clrs/_src/processors.py
@@ -15,6 +15,7 @@
 
 """JAX implementation of baseline processor networks."""
 
+import abc
 from typing import Any, Callable, List, Optional
 
 import chex
@@ -28,14 +29,51 @@
 _Fn = Callable[..., Any]
 
 
-class GAT(hk.Module):
+class Processor(hk.Module):
+  """Processor abstract base class."""
+
+  @abc.abstractmethod
+  def __call__(
+      self,
+      node_fts: _Array,
+      edge_fts: _Array,
+      graph_fts: _Array,
+      adj_mat: _Array,
+      hidden: _Array,
+      **kwargs,
+  ) -> _Array:
+    """Processor inference step.
+
+    Args:
+      node_fts: Node features.
+      edge_fts: Edge features.
+      graph_fts: Graph features.
+      adj_mat: Graph adjacency matrix.
+      hidden: Hidden features.
+      **kwargs: Extra kwargs.
+
+    Returns:
+      Output of processor inference step.
+    """
+    pass
+
+  @property
+  def inf_bias(self):
+    return False
+
+  @property
+  def inf_bias_edge(self):
+    return False
+
+
+class GAT(Processor):
   """Graph Attention Network (Velickovic et al., ICLR 2018)."""
 
   def __init__(
       self,
       out_size: int,
       nb_heads: int,
-      activation: Optional[_Fn] = None,
+      activation: Optional[_Fn] = jax.nn.relu,
       residual: bool = True,
       name: str = 'gat_aggr',
   ):
@@ -50,31 +88,25 @@ def __init__(
 
   def __call__(
       self,
-      features: _Array,
-      e_features: _Array,
-      g_features: _Array,
-      adj: _Array,
+      node_fts: _Array,
+      edge_fts: _Array,
+      graph_fts: _Array,
+      adj_mat: _Array,
+      hidden: _Array,
+      **unused_kwargs,
   ) -> _Array:
-    """GAT inference step.
+    """GAT inference step."""
 
-    Args:
-      features: Node features.
-      e_features: Edge features.
-      g_features: Graph features.
-      adj: Graph adjacency matrix.
-
-    Returns:
-      Output of GAT inference step.
- """ - b, n, _ = features.shape - assert e_features.shape[:-1] == (b, n, n) - assert g_features.shape[:-1] == (b,) - assert adj.shape == (b, n, n) + b, n, _ = node_fts.shape + assert edge_fts.shape[:-1] == (b, n, n) + assert graph_fts.shape[:-1] == (b,) + assert adj_mat.shape == (b, n, n) + z = jnp.concatenate([node_fts, hidden], axis=-1) m = hk.Linear(self.out_size) skip = hk.Linear(self.out_size) - bias_mat = (adj - 1.0) * 1e9 + bias_mat = (adj_mat - 1.0) * 1e9 bias_mat = jnp.tile(bias_mat[..., None], (1, 1, 1, self.nb_heads)) # [B, N, N, H] bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2)) # [B, H, N, N] @@ -84,16 +116,16 @@ def __call__( a_e = hk.Linear(self.nb_heads) a_g = hk.Linear(self.nb_heads) - values = m(features) # [B, N, H*F] + values = m(z) # [B, N, H*F] values = jnp.reshape( values, values.shape[:-1] + (self.nb_heads, self.head_size)) # [B, N, H, F] values = jnp.transpose(values, (0, 2, 1, 3)) # [B, H, N, F] - att_1 = jnp.expand_dims(a_1(features), axis=-1) - att_2 = jnp.expand_dims(a_2(features), axis=-1) - att_e = a_e(e_features) - att_g = jnp.expand_dims(a_g(g_features), axis=-1) + att_1 = jnp.expand_dims(a_1(z), axis=-1) + att_2 = jnp.expand_dims(a_2(z), axis=-1) + att_e = a_e(edge_fts) + att_g = jnp.expand_dims(a_g(graph_fts), axis=-1) logits = ( jnp.transpose(att_1, (0, 2, 1, 3)) + # + [B, H, N, 1] @@ -107,7 +139,7 @@ def __call__( ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,)) # [B, N, H*F] if self.residual: - ret += skip(features) + ret += skip(z) if self.activation is not None: ret = self.activation(ret) @@ -115,7 +147,16 @@ def __call__( return ret -class GATv2(hk.Module): +class GATFull(GAT): + """Graph Attention Network with full adjacency matrix.""" + + def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array, + adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array: + adj_mat = jnp.ones_like(adj_mat) + return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden) + + +class GATv2(Processor): """Graph Attention Network v2 (Brody et al., ICLR 2022).""" def __init__( @@ -123,7 +164,7 @@ def __init__( out_size: int, nb_heads: int, mid_size: Optional[int] = None, - activation: Optional[_Fn] = None, + activation: Optional[_Fn] = jax.nn.relu, residual: bool = True, name: str = 'gatv2_aggr', ): @@ -145,31 +186,25 @@ def __init__( def __call__( self, - features: _Array, - e_features: _Array, - g_features: _Array, - adj: _Array, + node_fts: _Array, + edge_fts: _Array, + graph_fts: _Array, + adj_mat: _Array, + hidden: _Array, + **unused_kwargs, ) -> _Array: - """GATv2 inference step. - - Args: - features: Node features. - e_features: Edge features. - g_features: Graph features. - adj: Graph adjacency matrix. + """GATv2 inference step.""" - Returns: - Output of GATv2 inference step. 
- """ - b, n, _ = features.shape - assert e_features.shape[:-1] == (b, n, n) - assert g_features.shape[:-1] == (b,) - assert adj.shape == (b, n, n) + b, n, _ = node_fts.shape + assert edge_fts.shape[:-1] == (b, n, n) + assert graph_fts.shape[:-1] == (b,) + assert adj_mat.shape == (b, n, n) + z = jnp.concatenate([node_fts, hidden], axis=-1) m = hk.Linear(self.out_size) skip = hk.Linear(self.out_size) - bias_mat = (adj - 1.0) * 1e9 + bias_mat = (adj_mat - 1.0) * 1e9 bias_mat = jnp.tile(bias_mat[..., None], (1, 1, 1, self.nb_heads)) # [B, N, N, H] bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2)) # [B, H, N, N] @@ -183,16 +218,16 @@ def __call__( for _ in range(self.nb_heads): a_heads.append(hk.Linear(1)) - values = m(features) # [B, N, H*F] + values = m(z) # [B, N, H*F] values = jnp.reshape( values, values.shape[:-1] + (self.nb_heads, self.head_size)) # [B, N, H, F] values = jnp.transpose(values, (0, 2, 1, 3)) # [B, H, N, F] - pre_att_1 = w_1(features) - pre_att_2 = w_2(features) - pre_att_e = w_e(e_features) - pre_att_g = w_g(g_features) + pre_att_1 = w_1(z) + pre_att_2 = w_2(z) + pre_att_e = w_e(edge_fts) + pre_att_g = w_g(graph_fts) pre_att = ( jnp.expand_dims(pre_att_1, axis=1) + # + [B, 1, N, H*F] @@ -226,7 +261,7 @@ def __call__( ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,)) # [B, N, H*F] if self.residual: - ret += skip(features) + ret += skip(z) if self.activation is not None: ret = self.activation(ret) @@ -234,15 +269,24 @@ def __call__( return ret -class MPNN(hk.Module): - """Message-Passing Neural Network (Gilmer et al., ICML 2017).""" +class GATv2Full(GATv2): + """Graph Attention Network v2 with full adjacency matrix.""" + + def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array, + adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array: + adj_mat = jnp.ones_like(adj_mat) + return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden) + + +class PGN(Processor): + """Pointer Graph Networks (Veličković et al., NeurIPS 2020).""" def __init__( self, out_size: int, mid_size: Optional[int] = None, mid_act: Optional[_Fn] = None, - activation: Optional[_Fn] = None, + activation: Optional[_Fn] = jax.nn.relu, reduction: _Fn = jnp.max, msgs_mlp_sizes: Optional[List[int]] = None, name: str = 'mpnn_aggr', @@ -260,27 +304,21 @@ def __init__( def __call__( self, - features: _Array, - e_features: _Array, - g_features: _Array, - adj: _Array, + node_fts: _Array, + edge_fts: _Array, + graph_fts: _Array, + adj_mat: _Array, + hidden: _Array, + **unused_kwargs, ) -> _Array: - """MPNN inference step. + """MPNN inference step.""" - Args: - features: Node features. - e_features: Edge features. - g_features: Graph features. - adj: Graph adjacency matrix. - - Returns: - Output of MPNN inference step. 
- """ - b, n, _ = features.shape - assert e_features.shape[:-1] == (b, n, n) - assert g_features.shape[:-1] == (b,) - assert adj.shape == (b, n, n) + b, n, _ = node_fts.shape + assert edge_fts.shape[:-1] == (b, n, n) + assert graph_fts.shape[:-1] == (b,) + assert adj_mat.shape == (b, n, n) + z = jnp.concatenate([node_fts, hidden], axis=-1) m_1 = hk.Linear(self.mid_size) m_2 = hk.Linear(self.mid_size) m_e = hk.Linear(self.mid_size) @@ -289,10 +327,10 @@ def __call__( o1 = hk.Linear(self.out_size) o2 = hk.Linear(self.out_size) - msg_1 = m_1(features) - msg_2 = m_2(features) - msg_e = m_e(e_features) - msg_g = m_g(g_features) + msg_1 = m_1(z) + msg_2 = m_2(z) + msg_e = m_e(edge_fts) + msg_g = m_g(graph_fts) msgs = ( jnp.expand_dims(msg_1, axis=1) + jnp.expand_dims(msg_2, axis=2) + @@ -304,12 +342,12 @@ def __call__( msgs = self.mid_act(msgs) if self.reduction == jnp.mean: - msgs = jnp.sum(msgs * jnp.expand_dims(adj, -1), axis=-1) - msgs = msgs / jnp.sum(adj, axis=-1, keepdims=True) + msgs = jnp.sum(msgs * jnp.expand_dims(adj_mat, -1), axis=-1) + msgs = msgs / jnp.sum(adj_mat, axis=-1, keepdims=True) else: - msgs = self.reduction(msgs * jnp.expand_dims(adj, -1), axis=1) + msgs = self.reduction(msgs * jnp.expand_dims(adj_mat, -1), axis=1) - h_1 = o1(features) + h_1 = o1(z) h_2 = o2(msgs) ret = h_1 + h_2 @@ -320,19 +358,39 @@ def __call__( return ret -def _position_encoding(sentence_size: int, embedding_size: int) -> np.ndarray: - """Position Encoding described in section 4.1 [1].""" - encoding = np.ones((embedding_size, sentence_size), dtype=np.float32) - ls = sentence_size + 1 - le = embedding_size + 1 - for i in range(1, le): - for j in range(1, ls): - encoding[i - 1, j - 1] = (i - (le - 1) / 2) * (j - (ls - 1) / 2) - encoding = 1 + 4 * encoding / embedding_size / sentence_size - return np.transpose(encoding) +class DeepSets(PGN): + """Deep Sets (Zaheer et al., NeurIPS 2017).""" + + def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array, + adj_mat: _Array, hidden: _Array, nb_nodes: int, + batch_size: int) -> _Array: + adj_mat = jnp.repeat( + jnp.expand_dims(jnp.eye(nb_nodes), 0), batch_size, axis=0) + return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden) + + +class MPNN(PGN): + """Message-Passing Neural Network (Gilmer et al., ICML 2017).""" + + def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array, + adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array: + adj_mat = jnp.ones_like(adj_mat) + return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden) + +class PGNMask(PGN): + """Masked Pointer Graph Networks (Veličković et al., NeurIPS 2020).""" -class MemNet(hk.Module): + @property + def inf_bias(self): + return True + + @property + def inf_bias_edge(self): + return True + + +class MemNetMasked(Processor): """Implementation of End-to-End Memory Networks. Inspired by the description in https://arxiv.org/abs/1503.08895. @@ -341,9 +399,9 @@ class MemNet(hk.Module): def __init__( self, vocab_size: int, - embedding_size: int, sentence_size: int, linear_output_size: int, + embedding_size: int = 16, memory_size: Optional[int] = 128, num_hops: int = 1, nonlin: Callable[[Any], Any] = jax.nn.relu, @@ -355,11 +413,11 @@ def __init__( Args: vocab_size: the number of words in the dictionary (each story, query and answer come contain symbols coming from this dictionary). - embedding_size: the dimensionality of the latent space to where all - memories are projected. sentence_size: the dimensionality of each memory. 
       linear_output_size: the dimensionality of the output of the last layer
        of the model.
+      embedding_size: the dimensionality of the latent space to where all
+        memories are projected.
       memory_size: the number of memories provided.
       num_hops: the number of layers in the model.
       nonlin: non-linear transformation applied at the end of each layer.
@@ -380,7 +438,30 @@ def __init__(
     # Encoding part: i.e. "I" of the paper.
     self._encodings = _position_encoding(sentence_size, embedding_size)
 
-  def __call__(self, queries: jnp.ndarray, stories: jnp.ndarray) -> jnp.ndarray:
+  def __call__(
+      self,
+      node_fts: _Array,
+      edge_fts: _Array,
+      graph_fts: _Array,
+      adj_mat: _Array,
+      hidden: _Array,
+      **unused_kwargs,
+  ) -> _Array:
+    """MemNet inference step."""
+
+    del hidden
+    node_and_graph_fts = jnp.concatenate([node_fts, graph_fts[:, None]],
+                                         axis=1)
+    edge_fts_padded = jnp.pad(edge_fts * adj_mat[..., None],
+                              ((0, 0), (0, 1), (0, 1), (0, 0)))
+    nxt_hidden = jax.vmap(self._apply, (1), 1)(node_and_graph_fts,
+                                               edge_fts_padded)
+
+    # Broadcast hidden state corresponding to graph features across the nodes.
+    nxt_hidden = nxt_hidden[:, :-1] + nxt_hidden[:, -1:]
+    return nxt_hidden
+
+  def _apply(self, queries: _Array, stories: _Array) -> _Array:
     """Apply Memory Network to the queries and stories.
 
     Args:
@@ -485,3 +566,64 @@ def __call__(self, queries: jnp.ndarray, stories: jnp.ndarray) -> jnp.ndarray:
 
     # This linear here is "W".
     return hk.Linear(self._vocab_size, with_bias=False)(output_layer)
+
+
+class MemNetFull(MemNetMasked):
+  """Memory Networks with full adjacency matrix."""
+
+  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
+               adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
+    adj_mat = jnp.ones_like(adj_mat)
+    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
+
+
+def construct_processor(kind: str, hidden_dim: int, nb_heads: int) -> Processor:
+  """Constructs a processor."""
+  if kind == 'deepsets':
+    processor = DeepSets(
+        out_size=hidden_dim, msgs_mlp_sizes=[hidden_dim, hidden_dim])
+  elif kind == 'gat':
+    processor = GAT(out_size=hidden_dim, nb_heads=nb_heads)
+  elif kind == 'gat_full':
+    processor = GATFull(out_size=hidden_dim, nb_heads=nb_heads)
+  elif kind == 'gatv2':
+    processor = GATv2(out_size=hidden_dim, nb_heads=nb_heads)
+  elif kind == 'gatv2_full':
+    processor = GATv2Full(out_size=hidden_dim, nb_heads=nb_heads)
+  elif kind == 'memnet_full':
+    processor = MemNetFull(
+        vocab_size=hidden_dim,
+        sentence_size=hidden_dim,
+        linear_output_size=hidden_dim,
+    )
+  elif kind == 'memnet_masked':
+    processor = MemNetMasked(
+        vocab_size=hidden_dim,
+        sentence_size=hidden_dim,
+        linear_output_size=hidden_dim,
+    )
+  elif kind == 'mpnn':
+    processor = MPNN(
+        out_size=hidden_dim, msgs_mlp_sizes=[hidden_dim, hidden_dim])
+  elif kind == 'pgn':
+    processor = PGN(
+        out_size=hidden_dim, msgs_mlp_sizes=[hidden_dim, hidden_dim])
+  elif kind == 'pgn_mask':
+    processor = PGNMask(
+        out_size=hidden_dim, msgs_mlp_sizes=[hidden_dim, hidden_dim])
+  else:
+    raise ValueError('Unexpected processor kind ' + kind)
+
+  return processor
+
+
+def _position_encoding(sentence_size: int, embedding_size: int) -> np.ndarray:
+  """Position Encoding described in section 4.1 [1]."""
+  encoding = np.ones((embedding_size, sentence_size), dtype=np.float32)
+  ls = sentence_size + 1
+  le = embedding_size + 1
+  for i in range(1, le):
+    for j in range(1, ls):
+      encoding[i - 1, j - 1] = (i - (le - 1) / 2) * (j - (ls - 1) / 2)
+  encoding = 1 + 4 * encoding / embedding_size / sentence_size
+  return np.transpose(encoding)
diff --git a/clrs/_src/processors_test.py b/clrs/_src/processors_test.py
index bb0f009e..273da26e 100644
--- a/clrs/_src/processors_test.py
+++ b/clrs/_src/processors_test.py
@@ -35,14 +35,14 @@ def test_simple_run_and_check_shapes(self):
     num_hops = 2
 
     def forward_fn(queries, stories):
-      model = processors.MemNet(
+      model = processors.MemNetFull(
          vocab_size=vocab_size,
          embedding_size=embedding_size,
          sentence_size=sentence_size,
          memory_size=memory_size,
          linear_output_size=linear_output_size,
          num_hops=num_hops)
-      return model(queries, stories)
+      return model._apply(queries, stories)
 
     forward = hk.transform(forward_fn)
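
Usage sketch (illustrative, not part of the patch): the refactor replaces the per-kind branching in `Net` with a `construct_processor` factory plus a uniform `Processor.__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden, **kwargs)` interface, and moves the `inf_bias` / `inf_bias_edge` flags onto the processor itself. The snippet below is a minimal sketch of that interface only; the shapes, the standalone `hk.transform` wrapper and the choice of the 'pgn_mask' kind are assumptions made for illustration — in the library the equivalent call happens inside `Net._one_step_pred` on encoded features.

    import haiku as hk
    import jax
    import jax.numpy as jnp

    from clrs._src import processors

    B, N, H = 4, 16, 128  # batch size, nb_nodes, hidden_dim (illustrative values)

    def step(node_fts, edge_fts, graph_fts, adj_mat, hidden):
      # The factory hides which Processor subclass is built for a given kind.
      processor = processors.construct_processor(
          kind='pgn_mask', hidden_dim=H, nb_heads=8)
      # Every processor shares this signature; batch_size/nb_nodes are extra
      # kwargs that only some subclasses (e.g. DeepSets) actually consume.
      return processor(node_fts, edge_fts, graph_fts, adj_mat, hidden,
                       batch_size=B, nb_nodes=N)

    step_fn = hk.transform(step)
    rng = jax.random.PRNGKey(0)
    node_fts = jnp.zeros((B, N, H))
    edge_fts = jnp.zeros((B, N, N, H))
    graph_fts = jnp.zeros((B, H))
    adj_mat = jnp.ones((B, N, N))
    hidden = jnp.zeros((B, N, H))
    params = step_fn.init(rng, node_fts, edge_fts, graph_fts, adj_mat, hidden)
    nxt_hidden = step_fn.apply(
        params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden)
    assert nxt_hidden.shape == (B, N, H)

Inside `Net`, the masking behaviour that `BaselineModel` used to derive from the 'pgn_mask' kind is now read back from the constructed processor through its `inf_bias` / `inf_bias_edge` properties when calling `decoders.decode_fts`.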