Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

baselines.py refactoring (6/N) #52

Merged
merged 1 commit into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 27 additions & 88 deletions clrs/_src/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ def __init__(
decode_hints: bool,
decode_diffs: bool,
kind: str,
inf_bias: bool,
inf_bias_edge: bool,
use_lstm: bool,
dropout_prob: float,
nb_heads: int,
Expand All @@ -85,8 +83,6 @@ def __init__(

self._dropout_prob = dropout_prob
self.spec = spec
self.inf_bias = inf_bias
self.inf_bias_edge = inf_bias_edge
self.hidden_dim = hidden_dim
self.encode_hints = encode_hints
self.decode_hints = decode_hints
Expand Down Expand Up @@ -191,12 +187,16 @@ def __call__(self, features: _Features, repred: bool):
nb_nodes = inp.data.shape[1]
break

self._construct_encoders_decoders()
self._construct_processor()
# Construct encoders and decoders.
(self.encoders, self.decoders,
self.diff_decoders) = self._construct_encoders_decoders()
self.processor = processors.construct_processor(
kind=self.kind, hidden_dim=self.hidden_dim, nb_heads=self.nb_heads)

nb_mp_steps = max(1, hints[0].data.shape[0] - 1)
hiddens = jnp.zeros((self.batch_size, nb_nodes, self.hidden_dim))

# Optionally construct LSTM.
if self.use_lstm:
self.lstm = hk.LSTM(
hidden_size=self.hidden_dim,
Expand Down Expand Up @@ -258,58 +258,25 @@ def invert(d):

def _construct_encoders_decoders(self):
"""Constructs encoders and decoders."""
self.encoders = {}
self.decoders = {}
encoders_ = {}
decoders_ = {}

for name, (stage, loc, t) in self.spec.items():
if stage == _Stage.INPUT or (stage == _Stage.HINT and self.encode_hints):
# Build input encoders.
self.encoders[name] = encoders.construct_encoders(
encoders_[name] = encoders.construct_encoders(
loc, t, hidden_dim=self.hidden_dim)

if stage == _Stage.OUTPUT or (stage == _Stage.HINT and self.decode_hints):
# Build output decoders.
self.decoders[name] = decoders.construct_decoders(
decoders_[name] = decoders.construct_decoders(
loc, t, hidden_dim=self.hidden_dim, nb_dims=self.nb_dims[name])

if self.decode_diffs:
# Optionally build diff decoders.
self.diff_decoders = decoders.construct_diff_decoders()

def _construct_processor(self):
"""Constructs processor."""

if self.kind in ['deepsets', 'mpnn', 'pgn']:
self.mpnn = processors.MPNN(
out_size=self.hidden_dim,
mid_act=jax.nn.relu,
activation=jax.nn.relu,
reduction=jnp.max,
msgs_mlp_sizes=[
self.hidden_dim,
self.hidden_dim,
])
elif self.kind in ['gat', 'gat_full']:
self.mpnn = processors.GAT(
out_size=self.hidden_dim,
nb_heads=self.nb_heads,
activation=jax.nn.relu,
residual=True)
elif self.kind in ['gatv2', 'gatv2_full']:
self.mpnn = processors.GATv2(
out_size=self.hidden_dim,
nb_heads=self.nb_heads,
activation=jax.nn.relu,
residual=True)
elif self.kind == 'memnet_full' or self.kind == 'memnet_masked':
self.memnet = processors.MemNet(
vocab_size=self.hidden_dim,
embedding_size=16,
sentence_size=self.hidden_dim,
linear_output_size=self.hidden_dim,
memory_size=128,
num_hops=1,
apply_embeddings=True)
diff_decoders = decoders.construct_diff_decoders()

return encoders_, decoders_, diff_decoders

def _one_step_pred(
self,
Expand All @@ -329,7 +296,6 @@ def _one_step_pred(
jnp.expand_dims(jnp.eye(nb_nodes), 0), self.batch_size, axis=0)

# ENCODE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Encode node/edge/graph features from inputs and (optionally) hints.
trajectories = [inputs]
if self.encode_hints:
Expand All @@ -348,33 +314,15 @@ def _one_step_pred(
raise Exception(f'Failed to process {dp}') from e

# PROCESS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

if self.kind == 'deepsets':
adj_mat = jnp.repeat(
jnp.expand_dims(jnp.eye(nb_nodes), 0), self.batch_size, axis=0)
elif (self.kind == 'mpnn' or self.kind == 'gat_full' or
self.kind == 'gatv2_full' or self.kind == 'memnet_full'):
adj_mat = jnp.ones_like(adj_mat)
elif (self.kind == 'pgn' or self.kind == 'gat' or self.kind == 'gatv2' or
self.kind == 'memnet_masked'):
adj_mat = (adj_mat > 0.0) * 1.0
else:
raise ValueError('Unsupported kind of model')

z = jnp.concatenate([node_fts, hidden], axis=-1)
if self.kind == 'memnet_full' or self.kind == 'memnet_masked':
node_and_graph_fts = jnp.concatenate(
[node_fts, graph_fts[:, None]], axis=1)
edge_fts_padded = jnp.pad(edge_fts * adj_mat[..., None],
((0, 0), (0, 1), (0, 1), (0, 0)))
nxt_hidden = jax.vmap(self.memnet, (1), 1)(node_and_graph_fts,
edge_fts_padded)
# Broadcast hidden state corresponding to graph features across the nodes.
nxt_hidden = nxt_hidden[:, :-1] + nxt_hidden[:, -1:]
else:
nxt_hidden = self.mpnn(z, edge_fts, graph_fts,
(adj_mat > 0.0).astype('float32'))

nxt_hidden = self.processor(
node_fts,
edge_fts,
graph_fts,
adj_mat,
hidden,
batch_size=self.batch_size,
nb_nodes=nb_nodes,
)
nxt_hidden = hk.dropout(hk.next_rng_key(), self._dropout_prob, nxt_hidden)

if self.use_lstm:
Expand All @@ -384,10 +332,9 @@ def _one_step_pred(
else:
nxt_lstm_state = None

h_t = jnp.concatenate([z, nxt_hidden], axis=-1)
h_t = jnp.concatenate([node_fts, hidden, nxt_hidden], axis=-1)

# DECODE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Decode features and (optionally) hints.
hint_preds, output_preds = decoders.decode_fts(
decoders=self.decoders,
Expand All @@ -396,8 +343,8 @@ def _one_step_pred(
adj_mat=adj_mat,
edge_fts=edge_fts,
graph_fts=graph_fts,
inf_bias=self.inf_bias,
inf_bias_edge=self.inf_bias_edge,
inf_bias=self.processor.inf_bias,
inf_bias_edge=self.processor.inf_bias_edge,
)

# Optionally decode diffs.
Expand Down Expand Up @@ -447,14 +394,6 @@ def __init__(
self._freeze_processor = freeze_processor
self.opt = optax.adam(learning_rate)

if kind == 'pgn_mask':
inf_bias = True
inf_bias_edge = True
kind = 'pgn'
else:
inf_bias = False
inf_bias_edge = False

self.nb_dims = {}
for inp in dummy_trajectory.features.inputs:
self.nb_dims[inp.name] = inp.data.shape[-1]
Expand All @@ -465,8 +404,8 @@ def __init__(

def _use_net(*args, **kwargs):
return Net(spec, hidden_dim, encode_hints, decode_hints, decode_diffs,
kind, inf_bias, inf_bias_edge, use_lstm, dropout_prob,
nb_heads, self.nb_dims)(*args, **kwargs)
kind, use_lstm, dropout_prob, nb_heads, self.nb_dims)(*args,
**kwargs)

self.net_fn = hk.transform(_use_net)
self.net_fn_apply = jax.jit(self.net_fn.apply, static_argnums=3)
Expand Down
6 changes: 3 additions & 3 deletions clrs/_src/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ def preprocess(dp: _DataPoint, nb_nodes: int) -> _Array:
def accum_adj_mat(dp: _DataPoint, data: _Array, adj_mat: _Array) -> _Array:
"""Accumulates adjacency matrix."""
if dp.location == _Location.NODE and dp.type_ == _Type.POINTER:
adj_mat += ((data + jnp.transpose(data, (0, 2, 1))) > 0.0).astype('float32')
adj_mat += ((data + jnp.transpose(data, (0, 2, 1))) > 0.0)
elif dp.location == _Location.EDGE and dp.type_ == _Type.MASK:
adj_mat += (data > 0.0).astype('float32')
adj_mat += (data > 0.0)

return adj_mat
return (adj_mat > 0.).astype('float32')


def accum_edge_fts(encoders, dp: _DataPoint, data: _Array,
Expand Down
Loading