Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

baselines.py refactoring (2/N) #36

Merged
merged 1 commit into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 21 additions & 75 deletions clrs/_src/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,19 +261,16 @@ def invert(d):

def _construct_encoders_decoders(self):
"""Constructs encoders and decoders."""
self.enc_inp = {}
self.encoders = {}
self.dec_out = {}

if self.encode_hints:
self.enc_hint = {}

if self.decode_hints:
self.dec_hint = {}

for name, (stage, loc, t) in self.spec.items():
if stage == _Stage.INPUT:
# Build input encoders.
self.enc_inp[name] = encoders.construct_encoder(
self.encoders[name] = encoders.construct_encoders(
loc, t, hidden_dim=self.hidden_dim)

elif stage == _Stage.OUTPUT:
Expand All @@ -284,7 +281,7 @@ def _construct_encoders_decoders(self):
elif stage == _Stage.HINT:
# Optionally build hint encoder/decoders.
if self.encode_hints:
self.enc_hint[name] = encoders.construct_encoder(
self.encoders[name] = encoders.construct_encoders(
loc, t, hidden_dim=self.hidden_dim)

if self.decode_hints:
Expand All @@ -299,6 +296,7 @@ def _construct_encoders_decoders(self):

def _construct_processor(self):
"""Constructs processor."""

if self.kind in ['deepsets', 'mpnn', 'pgn']:
self.mpnn = processors.MPNN(
out_size=self.hidden_dim,
Expand Down Expand Up @@ -338,83 +336,31 @@ def _one_step_pred(
nb_nodes: int,
lstm_state: Optional[hk.LSTMState],
):
"""Generates one step predictions."""
"""Generates one-step predictions."""

# Initialise empty node/edge/graph features and adjacency matrix.
node_fts = jnp.zeros((self.batch_size, nb_nodes, self.hidden_dim))
edge_fts = jnp.zeros((self.batch_size, nb_nodes, nb_nodes, self.hidden_dim))
graph_fts = jnp.zeros((self.batch_size, self.hidden_dim))
adj_mat = jnp.repeat(
jnp.expand_dims(jnp.eye(nb_nodes), 0), self.batch_size, axis=0)

for inp in inputs:
# Extract shared logic with hints and loss
encoder = self.enc_inp[inp.name][0]
if inp.type_ == _Type.POINTER:
in_data = hk.one_hot(inp.data, nb_nodes)
else:
in_data = inp.data.astype(jnp.float32)
if inp.type_ == _Type.CATEGORICAL:
encoding = encoder(in_data)
else:
encoding = encoder(jnp.expand_dims(in_data, -1))
if inp.location == _Location.NODE:
if inp.type_ == _Type.POINTER:
edge_fts += encoding
adj_mat += ((in_data + jnp.transpose(in_data, (0, 2, 1))) >
0.0).astype('float32')
else:
node_fts += encoding
elif inp.location == _Location.EDGE:
if inp.type_ == _Type.POINTER:
# Aggregate pointer contributions across sender and receiver nodes
encoding_2 = self.enc_inp[inp.name][1](jnp.expand_dims(in_data, -1))
edge_fts += jnp.mean(encoding, axis=1) + jnp.mean(encoding_2, axis=2)
else:
edge_fts += encoding
if inp.type_ == _Type.MASK:
adj_mat += (in_data > 0.0).astype('float32')
elif inp.location == _Location.GRAPH:
if inp.type_ == _Type.POINTER:
node_fts += encoding
else:
graph_fts += encoding

# Encode node/edge/graph features from inputs and (optionally) hints.
trajectories = [inputs]
if self.encode_hints:
for hint in hints:
encoder = self.enc_hint[hint.name][0]
if hint.type_ == _Type.POINTER:
in_data = hk.one_hot(hint.data, nb_nodes)
else:
in_data = hint.data.astype(jnp.float32)
if hint.type_ == _Type.CATEGORICAL:
encoding = encoder(in_data)
else:
encoding = encoder(jnp.expand_dims(in_data, -1))
if hint.location == _Location.NODE:
if hint.type_ == _Type.POINTER:
edge_fts += encoding
adj_mat += ((in_data + jnp.transpose(in_data, (0, 2, 1))) >
0.0).astype('float32')
else:
node_fts += encoding
elif hint.location == _Location.EDGE:
if hint.type_ == _Type.POINTER:
# Aggregate pointer contributions across sender and receiver nodes
encoding_2 = self.enc_hint[hint.name][1](
jnp.expand_dims(in_data, -1))
edge_fts += jnp.mean(encoding, axis=1) + jnp.mean(
encoding_2, axis=2)
else:
edge_fts += encoding
if hint.type_ == _Type.MASK:
adj_mat += (in_data > 0.0).astype('float32')
elif hint.location == _Location.GRAPH:
if hint.type_ == _Type.POINTER:
node_fts += encoding
else:
graph_fts += encoding
else:
raise ValueError('Invalid hint location')
trajectories.append(hints)

for trajectory in trajectories:
for dp in trajectory:
try:
data = encoders.preprocess(dp, nb_nodes)
adj_mat = encoders.accum_adj_mat(dp, data, adj_mat)
encoder = self.encoders[dp.name]
edge_fts = encoders.accum_edge_fts(encoder, dp, data, edge_fts)
node_fts = encoders.accum_node_fts(encoder, dp, data, node_fts)
graph_fts = encoders.accum_graph_fts(encoder, dp, data, graph_fts)
except Exception as e:
raise Exception(f'Failed to process {dp}') from e

if self.kind == 'deepsets':
adj_mat = jnp.repeat(
Expand Down
85 changes: 79 additions & 6 deletions clrs/_src/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,96 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Encoder utilities."""

import chex
from clrs._src import probing
from clrs._src import specs
import haiku as hk
import jax.numpy as jnp


_Array = chex.Array
_DataPoint = probing.DataPoint
_Location = specs.Location
_Spec = specs.Spec
_Type = specs.Type


def construct_encoders(loc: str, t: str, hidden_dim: int):
  """Constructs the list of linear encoders for one data point.

  Args:
    loc: location of the data point; compared against `_Location.EDGE`.
    t: type of the data point; compared against `_Type.POINTER`.
    hidden_dim: output size passed to every `hk.Linear` that is created.

  Returns:
    A list with one `hk.Linear(hidden_dim)`, or with two of them when the data
    point is an edge-located pointer.
  """
  encoders = [hk.Linear(hidden_dim)]
  if loc == _Location.EDGE and t == _Type.POINTER:
    # Edge pointers need two-way encoders.
    encoders.append(hk.Linear(hidden_dim))

  return encoders


# Former (singular) name, kept so call sites that still use it keep working.
construct_encoder = construct_encoders


def preprocess(dp: _DataPoint, nb_nodes: int) -> _Array:
  """Converts a data point's payload into the form fed to the encoders.

  Pointer-typed data is one-hot encoded with `nb_nodes` classes; data of any
  other type is cast to float32.

  Args:
    dp: data point providing `type_` and `data`.
    nb_nodes: number of classes used for the one-hot encoding of pointers.

  Returns:
    The pre-processed array.
  """
  if dp.type_ != _Type.POINTER:
    return dp.data.astype(jnp.float32)
  return hk.one_hot(dp.data, nb_nodes)


def accum_adj_mat(dp: _DataPoint, data: _Array, adj_mat: _Array) -> _Array:
  """Adds the connectivity implied by a data point to the adjacency matrix.

  Only node-located pointers and edge-located masks contribute; for any other
  combination `adj_mat` is returned as given.

  Args:
    dp: data point providing `location` and `type_`.
    data: pre-processed payload of `dp`.
    adj_mat: running adjacency accumulator.

  Returns:
    The updated adjacency accumulator.
  """
  where, kind = dp.location, dp.type_

  if where == _Location.NODE and kind == _Type.POINTER:
    # Mark entries where `data` plus its transpose over the last two axes is
    # positive.
    symmetric = data + jnp.transpose(data, (0, 2, 1))
    adj_mat += (symmetric > 0.0).astype('float32')
    return adj_mat

  if where == _Location.EDGE and kind == _Type.MASK:
    adj_mat += (data > 0.0).astype('float32')

  return adj_mat


def accum_edge_fts(encoders, dp: _DataPoint, data: _Array,
                   edge_fts: _Array) -> _Array:
  """Encodes a data point and accumulates it into the edge features.

  Node-located pointers and every edge-located data point contribute; anything
  else leaves `edge_fts` untouched, and in that case no encoder is run.

  Args:
    encoders: indexable collection of encoder callables for this data point.
      Index 0 is used for every contributing data point; index 1 is used
      additionally for edge-located pointers.
    dp: data point providing `location` and `type_`.
    data: pre-processed payload of `dp`.
    edge_fts: running edge-feature accumulator.

  Returns:
    The updated edge-feature accumulator.
  """
  is_node_pointer = (
      dp.location == _Location.NODE and dp.type_ == _Type.POINTER)
  if not is_node_pointer and dp.location != _Location.EDGE:
    # Nothing to add; skip the encoder call whose result would be discarded.
    return edge_fts

  encoding = _encode_inputs(encoders, dp, data)

  if dp.location == _Location.EDGE and dp.type_ == _Type.POINTER:
    # Aggregate pointer contributions across sender and receiver nodes.
    encoding_2 = encoders[1](jnp.expand_dims(data, -1))
    edge_fts += jnp.mean(encoding, axis=1) + jnp.mean(encoding_2, axis=2)
  else:
    edge_fts += encoding

  return edge_fts


def accum_node_fts(encoders, dp: _DataPoint, data: _Array,
                   node_fts: _Array) -> _Array:
  """Encodes a data point and accumulates it into the node features.

  Node-located non-pointers and graph-located pointers contribute; anything
  else leaves `node_fts` untouched, and in that case no encoder is run.

  Args:
    encoders: indexable collection of encoder callables for this data point;
      only index 0 is used here.
    dp: data point providing `location` and `type_`.
    data: pre-processed payload of `dp`.
    node_fts: running node-feature accumulator.

  Returns:
    The updated node-feature accumulator.
  """
  is_node_non_pointer = (
      dp.location == _Location.NODE and dp.type_ != _Type.POINTER)
  is_graph_pointer = (
      dp.location == _Location.GRAPH and dp.type_ == _Type.POINTER)
  if not (is_node_non_pointer or is_graph_pointer):
    # Nothing to add; skip the encoder call whose result would be discarded.
    return node_fts

  node_fts += _encode_inputs(encoders, dp, data)
  return node_fts


def accum_graph_fts(encoders, dp: _DataPoint, data: _Array,
                    graph_fts: _Array) -> _Array:
  """Encodes a data point and accumulates it into the graph features.

  Only graph-located non-pointers contribute; anything else leaves `graph_fts`
  untouched, and in that case no encoder is run.

  Args:
    encoders: indexable collection of encoder callables for this data point;
      only index 0 is used here.
    dp: data point providing `location` and `type_`.
    data: pre-processed payload of `dp`.
    graph_fts: running graph-feature accumulator.

  Returns:
    The updated graph-feature accumulator.
  """
  if dp.location != _Location.GRAPH or dp.type_ == _Type.POINTER:
    # Nothing to add; skip the encoder call whose result would be discarded.
    return graph_fts

  graph_fts += _encode_inputs(encoders, dp, data)
  return graph_fts


return encoder
def _encode_inputs(encoders, dp: _DataPoint, data: _Array) -> _Array:
  """Runs the first encoder in `encoders` on a data point's payload.

  Categorical data is passed to the encoder unchanged; data of any other type
  first gets a trailing singleton axis appended.

  Args:
    encoders: indexable collection of encoder callables; only index 0 is used.
    dp: data point providing `type_`.
    data: pre-processed payload of `dp`.

  Returns:
    The output of `encoders[0]`.
  """
  is_categorical = dp.type_ == _Type.CATEGORICAL
  encoder_input = data if is_categorical else jnp.expand_dims(data, -1)
  return encoders[0](encoder_input)