Skip to content

Commit

Permalink
Change types of DataPoint and DataPoint members.
Browse files Browse the repository at this point in the history
DataPoint is now an attrs class and its members are no longer Enums. This is to conform to the tf.data.Datasets interface.

PiperOrigin-RevId: 421839387
  • Loading branch information
adria-p authored and copybara-github committed Jan 17, 2022
1 parent dcdc075 commit d79c4ae
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 53 deletions.
4 changes: 2 additions & 2 deletions clrs/_src/algorithms/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def bridges(A: _Array) -> _Out:

low = np.zeros(A.shape[0])
is_bridge = (
np.zeros((A.shape[0], A.shape[0])) + _OutputClass.MASKED.value + adj)
np.zeros((A.shape[0], A.shape[0])) + _OutputClass.MASKED + adj)

for s in range(A.shape[0]):
if color[s] == 0:
Expand Down Expand Up @@ -1523,7 +1523,7 @@ def bipartite_matching(A: _Array, n: int, m: int, s: int, t: int) -> _Out:
't': probing.mask_one(t, A.shape[0])
})
in_matching = (
np.zeros((A.shape[0], A.shape[1])) + _OutputClass.MASKED.value + adj
np.zeros((A.shape[0], A.shape[1])) + _OutputClass.MASKED + adj
+ adj.T)
u = t
while True:
Expand Down
22 changes: 11 additions & 11 deletions clrs/_src/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _msg_passing_step(self,
hint_cur = jnp.expand_dims(hint_cur, -1)
hint_nxt = jnp.expand_dims(hint_nxt, -1)
gt_diffs[hint.location] += jnp.any(hint_cur != hint_nxt, axis=-1)
for loc in _Location:
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
gt_diffs[loc] = (gt_diffs[loc] > 0.0).astype(jnp.float32) * 1.0

(hiddens, output_preds_cand, hint_preds, diff_logits,
Expand All @@ -156,7 +156,7 @@ def _msg_passing_step(self,
if self.decode_hints:
if hints[0].data.shape[0] == 1 or repred:
diff_preds = {}
for loc in _Location:
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
diff_preds[loc] = (diff_logits[loc] > 0.0).astype(jnp.float32) * 1.0
else:
diff_preds = gt_diffs
Expand Down Expand Up @@ -730,7 +730,7 @@ def loss(params, feedback):
total_loss = 0.0
lengths = feedback.features.lengths
if self.decode_diffs:
for loc in _Location:
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
for i in range(len(gt_diffs)):
is_not_done = _is_not_done_broadcast(
lengths, i, diff_logits[i][loc])
Expand Down Expand Up @@ -767,7 +767,7 @@ def loss(params, feedback):
loss = jnp.mean(
jnp.maximum(pred, 0) - pred * truth.data[i + 1] +
jnp.log1p(jnp.exp(-jnp.abs(pred))) * is_not_done)
mask = (truth.data != _OutputClass.MASKED.value).astype(
mask = (truth.data != _OutputClass.MASKED).astype(
jnp.float32)
total_loss += jnp.sum(loss*mask)/jnp.sum(mask)
elif truth.type_ == _Type.MASK_ONE:
Expand All @@ -783,9 +783,9 @@ def loss(params, feedback):
pred) * is_not_done, axis=-1))
elif truth.type_ == _Type.CATEGORICAL:
unmasked_data = truth.data[
truth.data == _OutputClass.POSITIVE.value]
truth.data == _OutputClass.POSITIVE]
masked_truth = truth.data * (
truth.data != _OutputClass.MASKED.value).astype(jnp.float32)
truth.data != _OutputClass.MASKED).astype(jnp.float32)
if self.decode_diffs:
total_loss += jnp.sum(
-jnp.sum(
Expand Down Expand Up @@ -820,12 +820,12 @@ def loss(params, feedback):
elif truth.type_ == _Type.MASK:
loss = (jnp.maximum(pred, 0) - pred * truth.data +
jnp.log1p(jnp.exp(-jnp.abs(pred))))
mask = (truth.data != _OutputClass.MASKED.value).astype(jnp.float32)
mask = (truth.data != _OutputClass.MASKED).astype(jnp.float32)
total_loss += jnp.sum(loss*mask)/jnp.sum(mask)
elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]:
unmasked_data = truth.data[truth.data == _OutputClass.POSITIVE.value]
unmasked_data = truth.data[truth.data == _OutputClass.POSITIVE]
masked_truth = truth.data * (
truth.data != _OutputClass.MASKED.value).astype(jnp.float32)
truth.data != _OutputClass.MASKED).astype(jnp.float32)
total_loss += (
-jnp.sum(masked_truth * jax.nn.log_softmax(pred))
/ jnp.sum(unmasked_data))
Expand Down Expand Up @@ -864,14 +864,14 @@ def verbose_loss(self, feedback: _Feedback, extra_info) -> Dict[str, _Array]:

losses = {}
if self.decode_diffs:
for loc in _Location:
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
for i in range(len(gt_diffs)):
is_not_done = _is_not_done_broadcast(lengths, i, gt_diffs[i][loc])
diff_loss = (
jnp.maximum(diff_logits[i][loc], 0) -
diff_logits[i][loc] * gt_diffs[i][loc] +
jnp.log1p(jnp.exp(-jnp.abs(diff_logits[i][loc]))) * is_not_done)
losses[loc.name + '_diff_%d' % i] = jnp.mean(diff_loss)
losses[loc + '_diff_%d' % i] = jnp.mean(diff_loss)

if self.decode_hints:
for truth in feedback.features.hints:
Expand Down
4 changes: 2 additions & 2 deletions clrs/_src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ def evaluate(


def _eval_one(pred, truth):
mask = np.all(truth != specs.OutputClass.MASKED.value, axis=-1)
mask = np.all(truth != specs.OutputClass.MASKED, axis=-1)
return np.sum(
(np.argmax(pred, -1) == np.argmax(truth, -1)) * mask) / np.sum(mask)


def _mask_fn(pred, truth):
"""Evaluate outputs of type MASK, and account for any class imbalance."""
mask = (truth != specs.OutputClass.MASKED.value).astype(np.float32)
mask = (truth != specs.OutputClass.MASKED).astype(np.float32)

# Use F1 score for the masked outputs to address any imbalance
tp = np.sum((((pred > 0.5) * (truth > 0.5)) * 1.0) * mask)
Expand Down
73 changes: 48 additions & 25 deletions clrs/_src/probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@

from typing import Dict, List, Tuple, Union

import attr
from clrs._src import specs
import jax
import numpy as np
import tensorflow as tf


_Location = specs.Location
Expand All @@ -37,34 +39,48 @@

_Array = np.ndarray
_Data = Union[_Array, List[_Array]]
_DataOrType = Union[_Data, _Type]
_DataOrType = Union[_Data, str]

ProbesDict = Dict[_Stage, Dict[_Location, Dict[str, Dict[str, _DataOrType]]]]
ProbesDict = Dict[
str, Dict[str, Dict[str, Dict[str, _DataOrType]]]]


class ProbeError(Exception):
pass
def _convert_to_str(element):
if isinstance(element, tf.Tensor):
return element.numpy().decode('utf-8')
elif isinstance(element, (np.ndarray, bytes)):
return element.decode('utf-8')
else:
return element


# The first decorator makes this object jax.jit/pmap friendly; the second one
# makes it tf.data.Datasets friendly.
@jax.tree_util.register_pytree_node_class
@attr.define
class DataPoint:
"""Describes a data point."""

def __init__(
self,
name: str,
location: _Location,
type_: _Type,
data: _Array,
):
self.name = name
self.location = location
self.type_ = type_
self.data = data
_name: str
_location: str
_type_: str
data: _Array

@property
def name(self):
return _convert_to_str(self._name)

@property
def location(self):
return _convert_to_str(self._location)

@property
def type_(self):
return _convert_to_str(self._type_)

def __repr__(self):
s = f'DataPoint(name="{self.name}",\tlocation={self.location.name},\t'
return s + f'type={self.type_.name},\tdata=Array{self.data.shape})'
s = f'DataPoint(name="{self.name}",\tlocation={self.location},\t'
return s + f'type={self.type_},\tdata=Array{self.data.shape})'

def tree_flatten(self):
data = (self.data,)
Expand All @@ -78,6 +94,10 @@ def tree_unflatten(cls, meta, data):
return DataPoint(name, location, type_, subdata)


class ProbeError(Exception):
pass


def initialize(spec: specs.Spec) -> ProbesDict:
"""Initializes an empty `ProbesDict` corresponding with the provided spec."""
probes = dict()
Expand All @@ -89,21 +109,24 @@ def initialize(spec: specs.Spec) -> ProbesDict:
for name in spec:
stage, loc, t = spec[name]
probes[stage][loc][name] = {}
probes[stage][loc][name]['type_'] = t
probes[stage][loc][name]['data'] = []

return probes
probes[stage][loc][name]['type_'] = t
# Pytype thinks initialize() returns a ProbesDict with a str for all final
# values instead of _DataOrType.
return probes # pytype: disable=bad-return-type


def push(probes: ProbesDict, stage: _Stage, next_probe):
def push(probes: ProbesDict, stage: str, next_probe):
"""Pushes a probe into an existing `ProbesDict`."""
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
for name in probes[stage][loc]:
if name not in next_probe:
raise ProbeError(f'Missing probe for {name}.')
if isinstance(probes[stage][loc][name]['data'], _Array):
raise ProbeError('Attemping to push to finalized `ProbesDict`.')
probes[stage][loc][name]['data'].append(next_probe[name])
# Pytype thinks initialize() returns a ProbesDict with a str for all final
# values instead of _DataOrType.
probes[stage][loc][name]['data'].append(next_probe[name]) # pytype: disable=attribute-error


def finalize(probes: ProbesDict):
Expand Down Expand Up @@ -249,15 +272,15 @@ def strings_pair_cat(pair_probe: np.ndarray, nb_classes: int) -> np.ndarray:
probe_ret = np.zeros((n + m, n + m, nb_classes + 1))
for i in range(0, n):
for j in range(0, m):
probe_ret[i, j + n, int(pair_probe[i, j])] = _OutputClass.POSITIVE.value
probe_ret[i, j + n, int(pair_probe[i, j])] = _OutputClass.POSITIVE

# Fill the blank cells.
for i_1 in range(0, n):
for i_2 in range(0, n):
probe_ret[i_1, i_2, nb_classes] = _OutputClass.MASKED.value
probe_ret[i_1, i_2, nb_classes] = _OutputClass.MASKED
for j_1 in range(0, m):
for x in range(0, n + m):
probe_ret[j_1 + n, x, nb_classes] = _OutputClass.MASKED.value
probe_ret[j_1 + n, x, nb_classes] = _OutputClass.MASKED
return probe_ret


Expand Down
18 changes: 5 additions & 13 deletions clrs/_src/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,44 +30,36 @@
for representing sequential data where appropriate
"""

import enum
import types
from typing import Dict, Tuple


class _OrderedEnum(enum.Enum):

def __lt__(self, other):
assert self.__class__ is other.__class__
return self.value < other.value # pylint: disable=comparison-with-callable


class Stage(_OrderedEnum):
class Stage:
INPUT = 'input'
OUTPUT = 'output'
HINT = 'hint'


class Location(_OrderedEnum):
class Location:
NODE = 'node'
EDGE = 'edge'
GRAPH = 'graph'


class Type(_OrderedEnum):
class Type:
SCALAR = 'scalar'
CATEGORICAL = 'categorical'
MASK = 'mask'
MASK_ONE = 'mask_one'
POINTER = 'pointer'


class OutputClass(_OrderedEnum):
class OutputClass:
POSITIVE = 1
NEGATIVE = 0
MASKED = -1

Spec = Dict[str, Tuple[Stage, Location, Type]]
Spec = Dict[str, Tuple[str, str, str]]

CLRS_21_ALGS = [
'bellman_ford',
Expand Down

0 comments on commit d79c4ae

Please sign in to comment.