Skip to content

Commit

Permalink
baselines.py refactoring (2/N)
Browse files Browse the repository at this point in the history
- Move encoder logic to encoders.py
- Remove repeated code over inputs/hints
- Maintain a single `self.encoders` (input/hint names are unique per alg)

PiperOrigin-RevId: 425651010
  • Loading branch information
dbudden authored and copybara-github committed Feb 1, 2022
1 parent 3e9d59c commit a689c82
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 81 deletions.
100 changes: 24 additions & 76 deletions clrs/_src/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,19 +261,16 @@ def invert(d):

def _construct_encoders_decoders(self):
"""Constructs encoders and decoders."""
self.enc_inp = {}
self.encoders = {}
self.dec_out = {}

if self.encode_hints:
self.enc_hint = {}

if self.decode_hints:
self.dec_hint = {}

for name, (stage, loc, t) in self.spec.items():
if stage == _Stage.INPUT:
# Build input encoders.
self.enc_inp[name] = encoders.construct_encoder(
self.encoders[name] = encoders.construct_encoder(
loc, t, hidden_dim=self.hidden_dim)

elif stage == _Stage.OUTPUT:
Expand All @@ -284,7 +281,7 @@ def _construct_encoders_decoders(self):
elif stage == _Stage.HINT:
# Optionally build hint encoder/decoders.
if self.encode_hints:
self.enc_hint[name] = encoders.construct_encoder(
self.encoders[name] = encoders.construct_encoder(
loc, t, hidden_dim=self.hidden_dim)

if self.decode_hints:
Expand All @@ -299,6 +296,8 @@ def _construct_encoders_decoders(self):

def _construct_processor(self):
"""Constructs processor."""

# TODO(budden): Move this logic to `processors.py`.
if self.kind in ['deepsets', 'mpnn', 'pgn']:
self.mpnn = processors.MPNN(
out_size=self.hidden_dim,
Expand Down Expand Up @@ -338,84 +337,33 @@ def _one_step_pred(
nb_nodes: int,
lstm_state: Optional[hk.LSTMState],
):
"""Generates one step predictions."""
"""Generates one-step predictions."""

# Initialise empty node/edge/graph features and adjacency matrix.
node_fts = jnp.zeros((self.batch_size, nb_nodes, self.hidden_dim))
edge_fts = jnp.zeros((self.batch_size, nb_nodes, nb_nodes, self.hidden_dim))
graph_fts = jnp.zeros((self.batch_size, self.hidden_dim))
adj_mat = jnp.repeat(
jnp.expand_dims(jnp.eye(nb_nodes), 0), self.batch_size, axis=0)

for inp in inputs:
# Extract shared logic with hints and loss
encoder = self.enc_inp[inp.name][0]
if inp.type_ == _Type.POINTER:
in_data = hk.one_hot(inp.data, nb_nodes)
else:
in_data = inp.data.astype(jnp.float32)
if inp.type_ == _Type.CATEGORICAL:
encoding = encoder(in_data)
else:
encoding = encoder(jnp.expand_dims(in_data, -1))
if inp.location == _Location.NODE:
if inp.type_ == _Type.POINTER:
edge_fts += encoding
adj_mat += ((in_data + jnp.transpose(in_data, (0, 2, 1))) >
0.0).astype('float32')
else:
node_fts += encoding
elif inp.location == _Location.EDGE:
if inp.type_ == _Type.POINTER:
# Aggregate pointer contributions across sender and receiver nodes
encoding_2 = self.enc_inp[inp.name][1](jnp.expand_dims(in_data, -1))
edge_fts += jnp.mean(encoding, axis=1) + jnp.mean(encoding_2, axis=2)
else:
edge_fts += encoding
if inp.type_ == _Type.MASK:
adj_mat += (in_data > 0.0).astype('float32')
elif inp.location == _Location.GRAPH:
if inp.type_ == _Type.POINTER:
node_fts += encoding
else:
graph_fts += encoding

# Encode node/edge/graph features from inputs and (optionally) hints.
trajectories = [inputs]
if self.encode_hints:
for hint in hints:
encoder = self.enc_hint[hint.name][0]
if hint.type_ == _Type.POINTER:
in_data = hk.one_hot(hint.data, nb_nodes)
else:
in_data = hint.data.astype(jnp.float32)
if hint.type_ == _Type.CATEGORICAL:
encoding = encoder(in_data)
else:
encoding = encoder(jnp.expand_dims(in_data, -1))
if hint.location == _Location.NODE:
if hint.type_ == _Type.POINTER:
edge_fts += encoding
adj_mat += ((in_data + jnp.transpose(in_data, (0, 2, 1))) >
0.0).astype('float32')
else:
node_fts += encoding
elif hint.location == _Location.EDGE:
if hint.type_ == _Type.POINTER:
# Aggregate pointer contributions across sender and receiver nodes
encoding_2 = self.enc_hint[hint.name][1](
jnp.expand_dims(in_data, -1))
edge_fts += jnp.mean(encoding, axis=1) + jnp.mean(
encoding_2, axis=2)
else:
edge_fts += encoding
if hint.type_ == _Type.MASK:
adj_mat += (in_data > 0.0).astype('float32')
elif hint.location == _Location.GRAPH:
if hint.type_ == _Type.POINTER:
node_fts += encoding
else:
graph_fts += encoding
else:
raise ValueError('Invalid hint location')

trajectories.append(hints)

for trajectory in trajectories:
for dp in trajectory:
try:
data = encoders.preprocess(dp, nb_nodes)
adj_mat = encoders.accum_adj_mat(dp, data, adj_mat)
encoder = self.encoders[dp.name]
edge_fts = encoders.accum_edge_fts(encoder, dp, data, edge_fts)
node_fts = encoders.accum_node_fts(encoder, dp, data, node_fts)
graph_fts = encoders.accum_graph_fts(encoder, dp, data, graph_fts)
except Exception as e:
raise Exception(f'Failed to process {dp}') from e

# TODO(budden): Move this logic to `processors.py`.
if self.kind == 'deepsets':
adj_mat = jnp.repeat(
jnp.expand_dims(jnp.eye(nb_nodes), 0), self.batch_size, axis=0)
Expand Down
83 changes: 78 additions & 5 deletions clrs/_src/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,96 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Encoder utilities."""

import chex
from clrs._src import probing
from clrs._src import specs
import haiku as hk
import jax.numpy as jnp


_Array = chex.Array
_DataPoint = probing.DataPoint
_Location = specs.Location
_Spec = specs.Spec
_Type = specs.Type


def construct_encoder(loc: str, t: str, hidden_dim: int):
  """Constructs an encoder.

  Args:
    loc: Location of the feature to encode (a `specs.Location` value).
    t: Type of the feature to encode (a `specs.Type` value).
    hidden_dim: Output dimensionality of each linear encoder.

  Returns:
    A list of `hk.Linear` modules: a single encoder for most features, or
    two encoders for edge pointers, which need two-way encoders.
  """
  encoders = [hk.Linear(hidden_dim)]
  if loc == _Location.EDGE and t == _Type.POINTER:
    # Edge pointers need two-way encoders.
    encoders.append(hk.Linear(hidden_dim))

  return encoders


def preprocess(dp: _DataPoint, nb_nodes: int) -> _Array:
  """Pre-processes the data of a single data point.

  Pointer data is expanded to a one-hot representation over the `nb_nodes`
  nodes; all other data is cast to float32 unchanged.
  """
  if dp.type_ == _Type.POINTER:
    return hk.one_hot(dp.data, nb_nodes)
  return dp.data.astype(jnp.float32)


def accum_adj_mat(dp: _DataPoint, data: _Array, adj_mat: _Array) -> _Array:
  """Accumulates the adjacency-matrix contribution of one data point."""
  node_pointer = (dp.location == _Location.NODE and
                  dp.type_ == _Type.POINTER)
  edge_mask = (dp.location == _Location.EDGE and dp.type_ == _Type.MASK)

  if node_pointer:
    # Symmetrise: connect i and j if either one points at the other.
    symmetric = data + jnp.transpose(data, (0, 2, 1))
    adj_mat = adj_mat + (symmetric > 0.0).astype('float32')
  elif edge_mask:
    adj_mat = adj_mat + (data > 0.0).astype('float32')

  return adj_mat


def accum_edge_fts(encoders, dp: _DataPoint, data: _Array,
                   edge_fts: _Array) -> _Array:
  """Encodes and accumulates edge features."""
  encoding = _encode_inputs(encoders, dp, data)

  if dp.location == _Location.EDGE:
    if dp.type_ == _Type.POINTER:
      # Aggregate pointer contributions across sender and receiver nodes.
      encoding_2 = encoders[1](jnp.expand_dims(data, -1))
      edge_fts = (edge_fts + jnp.mean(encoding, axis=1)
                  + jnp.mean(encoding_2, axis=2))
    else:
      edge_fts = edge_fts + encoding
  elif dp.location == _Location.NODE and dp.type_ == _Type.POINTER:
    # Node pointers produce pairwise (edge-shaped) encodings.
    edge_fts = edge_fts + encoding

  return edge_fts


def accum_node_fts(encoders, dp: _DataPoint, data: _Array,
                   node_fts: _Array) -> _Array:
  """Encodes and accumulates node features."""
  # Encode unconditionally so module calls (and hence parameter creation)
  # do not depend on the data point's location/type.
  encoding = _encode_inputs(encoders, dp, data)

  if dp.location == _Location.NODE:
    accumulate = dp.type_ != _Type.POINTER
  elif dp.location == _Location.GRAPH:
    accumulate = dp.type_ == _Type.POINTER
  else:
    accumulate = False

  if accumulate:
    node_fts = node_fts + encoding

  return node_fts


def accum_graph_fts(encoders, dp: _DataPoint, data: _Array,
                    graph_fts: _Array) -> _Array:
  """Encodes and accumulates graph features."""
  # Encode unconditionally so module calls (and hence parameter creation)
  # do not depend on the data point's location/type.
  encoding = _encode_inputs(encoders, dp, data)

  accumulate = (dp.location == _Location.GRAPH and
                dp.type_ != _Type.POINTER)
  return (graph_fts + encoding) if accumulate else graph_fts


def _encode_inputs(encoders, dp: _DataPoint, data: _Array) -> _Array:
  """Applies the primary encoder to pre-processed data.

  Categorical data is passed to the encoder as-is; every other type first
  gets a singleton feature axis appended.
  """
  if dp.type_ == _Type.CATEGORICAL:
    encoding = encoders[0](data)
  else:
    encoding = encoders[0](jnp.expand_dims(data, -1))
  return encoding

0 comments on commit a689c82

Please sign in to comment.