diff --git a/clrs/_src/algorithms/graphs.py b/clrs/_src/algorithms/graphs.py index 4d0b3652..09303e6e 100644 --- a/clrs/_src/algorithms/graphs.py +++ b/clrs/_src/algorithms/graphs.py @@ -542,7 +542,7 @@ def bridges(A: _Array) -> _Out: low = np.zeros(A.shape[0]) is_bridge = ( - np.zeros((A.shape[0], A.shape[0])) + _OutputClass.MASKED.value + adj) + np.zeros((A.shape[0], A.shape[0])) + _OutputClass.MASKED + adj) for s in range(A.shape[0]): if color[s] == 0: @@ -1523,7 +1523,7 @@ def bipartite_matching(A: _Array, n: int, m: int, s: int, t: int) -> _Out: 't': probing.mask_one(t, A.shape[0]) }) in_matching = ( - np.zeros((A.shape[0], A.shape[1])) + _OutputClass.MASKED.value + adj + np.zeros((A.shape[0], A.shape[1])) + _OutputClass.MASKED + adj + adj.T) u = t while True: diff --git a/clrs/_src/baselines.py b/clrs/_src/baselines.py index 219d387c..a4f91361 100644 --- a/clrs/_src/baselines.py +++ b/clrs/_src/baselines.py @@ -135,7 +135,7 @@ def _msg_passing_step(self, hint_cur = jnp.expand_dims(hint_cur, -1) hint_nxt = jnp.expand_dims(hint_nxt, -1) gt_diffs[hint.location] += jnp.any(hint_cur != hint_nxt, axis=-1) - for loc in _Location: + for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]: gt_diffs[loc] = (gt_diffs[loc] > 0.0).astype(jnp.float32) * 1.0 (hiddens, output_preds_cand, hint_preds, diff_logits, @@ -156,7 +156,7 @@ def _msg_passing_step(self, if self.decode_hints: if hints[0].data.shape[0] == 1 or repred: diff_preds = {} - for loc in _Location: + for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]: diff_preds[loc] = (diff_logits[loc] > 0.0).astype(jnp.float32) * 1.0 else: diff_preds = gt_diffs @@ -730,7 +730,7 @@ def loss(params, feedback): total_loss = 0.0 lengths = feedback.features.lengths if self.decode_diffs: - for loc in _Location: + for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]: for i in range(len(gt_diffs)): is_not_done = _is_not_done_broadcast( lengths, i, diff_logits[i][loc]) @@ -767,7 +767,7 @@ def loss(params, feedback): loss = jnp.mean( jnp.maximum(pred, 0) - pred * truth.data[i + 1] + jnp.log1p(jnp.exp(-jnp.abs(pred))) * is_not_done) - mask = (truth.data != _OutputClass.MASKED.value).astype( + mask = (truth.data != _OutputClass.MASKED).astype( jnp.float32) total_loss += jnp.sum(loss*mask)/jnp.sum(mask) elif truth.type_ == _Type.MASK_ONE: @@ -783,9 +783,9 @@ def loss(params, feedback): pred) * is_not_done, axis=-1)) elif truth.type_ == _Type.CATEGORICAL: unmasked_data = truth.data[ - truth.data == _OutputClass.POSITIVE.value] + truth.data == _OutputClass.POSITIVE] masked_truth = truth.data * ( - truth.data != _OutputClass.MASKED.value).astype(jnp.float32) + truth.data != _OutputClass.MASKED).astype(jnp.float32) if self.decode_diffs: total_loss += jnp.sum( -jnp.sum( @@ -820,12 +820,12 @@ def loss(params, feedback): elif truth.type_ == _Type.MASK: loss = (jnp.maximum(pred, 0) - pred * truth.data + jnp.log1p(jnp.exp(-jnp.abs(pred)))) - mask = (truth.data != _OutputClass.MASKED.value).astype(jnp.float32) + mask = (truth.data != _OutputClass.MASKED).astype(jnp.float32) total_loss += jnp.sum(loss*mask)/jnp.sum(mask) elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]: - unmasked_data = truth.data[truth.data == _OutputClass.POSITIVE.value] + unmasked_data = truth.data[truth.data == _OutputClass.POSITIVE] masked_truth = truth.data * ( - truth.data != _OutputClass.MASKED.value).astype(jnp.float32) + truth.data != _OutputClass.MASKED).astype(jnp.float32) total_loss += ( -jnp.sum(masked_truth * jax.nn.log_softmax(pred)) / jnp.sum(unmasked_data)) @@ -864,14 +864,14 @@ def verbose_loss(self, feedback: _Feedback, extra_info) -> Dict[str, _Array]: losses = {} if self.decode_diffs: - for loc in _Location: + for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]: for i in range(len(gt_diffs)): is_not_done = _is_not_done_broadcast(lengths, i, gt_diffs[i][loc]) diff_loss = ( jnp.maximum(diff_logits[i][loc], 0) - diff_logits[i][loc] * gt_diffs[i][loc] + jnp.log1p(jnp.exp(-jnp.abs(diff_logits[i][loc]))) * is_not_done) - losses[loc.name + '_diff_%d' % i] = jnp.mean(diff_loss) + losses[loc + '_diff_%d' % i] = jnp.mean(diff_loss) if self.decode_hints: for truth in feedback.features.hints: diff --git a/clrs/_src/model.py b/clrs/_src/model.py index 22c91487..f4bceaaf 100644 --- a/clrs/_src/model.py +++ b/clrs/_src/model.py @@ -73,14 +73,14 @@ def evaluate( def _eval_one(pred, truth): - mask = np.all(truth != specs.OutputClass.MASKED.value, axis=-1) + mask = np.all(truth != specs.OutputClass.MASKED, axis=-1) return np.sum( (np.argmax(pred, -1) == np.argmax(truth, -1)) * mask) / np.sum(mask) def _mask_fn(pred, truth): """Evaluate outputs of type MASK, and account for any class imbalance.""" - mask = (truth != specs.OutputClass.MASKED.value).astype(np.float32) + mask = (truth != specs.OutputClass.MASKED).astype(np.float32) # Use F1 score for the masked outputs to address any imbalance tp = np.sum((((pred > 0.5) * (truth > 0.5)) * 1.0) * mask) diff --git a/clrs/_src/probing.py b/clrs/_src/probing.py index b18083fe..06f09033 100644 --- a/clrs/_src/probing.py +++ b/clrs/_src/probing.py @@ -25,9 +25,11 @@ from typing import Dict, List, Tuple, Union +import attr from clrs._src import specs import jax import numpy as np +import tensorflow as tf _Location = specs.Location @@ -37,34 +39,48 @@ _Array = np.ndarray _Data = Union[_Array, List[_Array]] -_DataOrType = Union[_Data, _Type] +_DataOrType = Union[_Data, str] -ProbesDict = Dict[_Stage, Dict[_Location, Dict[str, Dict[str, _DataOrType]]]] +ProbesDict = Dict[ + str, Dict[str, Dict[str, Dict[str, _DataOrType]]]] -class ProbeError(Exception): - pass +def _convert_to_str(element): + if isinstance(element, tf.Tensor): + return element.numpy().decode('utf-8') + elif isinstance(element, (np.ndarray, bytes)): + return element.decode('utf-8') + else: + return element +# First anotation makes this object jax.jit/pmap friendly, second one makes this +# tf.data.Datasets friendly. @jax.tree_util.register_pytree_node_class +@attr.define class DataPoint: """Describes a data point.""" - def __init__( - self, - name: str, - location: _Location, - type_: _Type, - data: _Array, - ): - self.name = name - self.location = location - self.type_ = type_ - self.data = data + _name: str + _location: str + _type_: str + data: _Array + + @property + def name(self): + return _convert_to_str(self._name) + + @property + def location(self): + return _convert_to_str(self._location) + + @property + def type_(self): + return _convert_to_str(self._type_) def __repr__(self): - s = f'DataPoint(name="{self.name}",\tlocation={self.location.name},\t' - return s + f'type={self.type_.name},\tdata=Array{self.data.shape})' + s = f'DataPoint(name="{self.name}",\tlocation={self.location},\t' + return s + f'type={self.type_},\tdata=Array{self.data.shape})' def tree_flatten(self): data = (self.data,) @@ -78,6 +94,10 @@ def tree_unflatten(cls, meta, data): return DataPoint(name, location, type_, subdata) +class ProbeError(Exception): + pass + + def initialize(spec: specs.Spec) -> ProbesDict: """Initializes an empty `ProbesDict` corresponding with the provided spec.""" probes = dict() @@ -89,13 +109,14 @@ def initialize(spec: specs.Spec) -> ProbesDict: for name in spec: stage, loc, t = spec[name] probes[stage][loc][name] = {} - probes[stage][loc][name]['type_'] = t probes[stage][loc][name]['data'] = [] - - return probes + probes[stage][loc][name]['type_'] = t + # Pytype thinks initialize() returns a ProbesDict with a str for all final + # values instead of _DataOrType. + return probes # pytype: disable=bad-return-type -def push(probes: ProbesDict, stage: _Stage, next_probe): +def push(probes: ProbesDict, stage: str, next_probe): """Pushes a probe into an existing `ProbesDict`.""" for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]: for name in probes[stage][loc]: @@ -103,7 +124,9 @@ def push(probes: ProbesDict, stage: _Stage, next_probe): raise ProbeError(f'Missing probe for {name}.') if isinstance(probes[stage][loc][name]['data'], _Array): raise ProbeError('Attemping to push to finalized `ProbesDict`.') - probes[stage][loc][name]['data'].append(next_probe[name]) + # Pytype thinks initialize() returns a ProbesDict with a str for all final + # values instead of _DataOrType. + probes[stage][loc][name]['data'].append(next_probe[name]) # pytype: disable=attribute-error def finalize(probes: ProbesDict): @@ -249,15 +272,15 @@ def strings_pair_cat(pair_probe: np.ndarray, nb_classes: int) -> np.ndarray: probe_ret = np.zeros((n + m, n + m, nb_classes + 1)) for i in range(0, n): for j in range(0, m): - probe_ret[i, j + n, int(pair_probe[i, j])] = _OutputClass.POSITIVE.value + probe_ret[i, j + n, int(pair_probe[i, j])] = _OutputClass.POSITIVE # Fill the blank cells. for i_1 in range(0, n): for i_2 in range(0, n): - probe_ret[i_1, i_2, nb_classes] = _OutputClass.MASKED.value + probe_ret[i_1, i_2, nb_classes] = _OutputClass.MASKED for j_1 in range(0, m): for x in range(0, n + m): - probe_ret[j_1 + n, x, nb_classes] = _OutputClass.MASKED.value + probe_ret[j_1 + n, x, nb_classes] = _OutputClass.MASKED return probe_ret diff --git a/clrs/_src/specs.py b/clrs/_src/specs.py index 0cead2fd..d5d8eaa4 100644 --- a/clrs/_src/specs.py +++ b/clrs/_src/specs.py @@ -30,31 +30,23 @@ for representing sequential data where appropriate """ -import enum import types from typing import Dict, Tuple -class _OrderedEnum(enum.Enum): - - def __lt__(self, other): - assert self.__class__ is other.__class__ - return self.value < other.value # pylint: disable=comparison-with-callable - - -class Stage(_OrderedEnum): +class Stage: INPUT = 'input' OUTPUT = 'output' HINT = 'hint' -class Location(_OrderedEnum): +class Location: NODE = 'node' EDGE = 'edge' GRAPH = 'graph' -class Type(_OrderedEnum): +class Type: SCALAR = 'scalar' CATEGORICAL = 'categorical' MASK = 'mask' @@ -62,12 +54,12 @@ class Type(_OrderedEnum): POINTER = 'pointer' -class OutputClass(_OrderedEnum): +class OutputClass: POSITIVE = 1 NEGATIVE = 0 MASKED = -1 -Spec = Dict[str, Tuple[Stage, Location, Type]] +Spec = Dict[str, Tuple[str, str, str]] CLRS_21_ALGS = [ 'bellman_ford',