Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change types of DataPoint and DataPoint members. #22

Merged
merged 1 commit into from
Jan 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clrs/_src/algorithms/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def bridges(A: _Array) -> _Out:

low = np.zeros(A.shape[0])
is_bridge = (
np.zeros((A.shape[0], A.shape[0])) + _OutputClass.MASKED.value + adj)
np.zeros((A.shape[0], A.shape[0])) + _OutputClass.MASKED + adj)

for s in range(A.shape[0]):
if color[s] == 0:
Expand Down Expand Up @@ -1523,7 +1523,7 @@ def bipartite_matching(A: _Array, n: int, m: int, s: int, t: int) -> _Out:
't': probing.mask_one(t, A.shape[0])
})
in_matching = (
np.zeros((A.shape[0], A.shape[1])) + _OutputClass.MASKED.value + adj
np.zeros((A.shape[0], A.shape[1])) + _OutputClass.MASKED + adj
+ adj.T)
u = t
while True:
Expand Down
22 changes: 11 additions & 11 deletions clrs/_src/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _msg_passing_step(self,
hint_cur = jnp.expand_dims(hint_cur, -1)
hint_nxt = jnp.expand_dims(hint_nxt, -1)
gt_diffs[hint.location] += jnp.any(hint_cur != hint_nxt, axis=-1)
for loc in _Location:
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
gt_diffs[loc] = (gt_diffs[loc] > 0.0).astype(jnp.float32) * 1.0

(hiddens, output_preds_cand, hint_preds, diff_logits,
Expand All @@ -156,7 +156,7 @@ def _msg_passing_step(self,
if self.decode_hints:
if hints[0].data.shape[0] == 1 or repred:
diff_preds = {}
for loc in _Location:
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
diff_preds[loc] = (diff_logits[loc] > 0.0).astype(jnp.float32) * 1.0
else:
diff_preds = gt_diffs
Expand Down Expand Up @@ -730,7 +730,7 @@ def loss(params, feedback):
total_loss = 0.0
lengths = feedback.features.lengths
if self.decode_diffs:
for loc in _Location:
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
for i in range(len(gt_diffs)):
is_not_done = _is_not_done_broadcast(
lengths, i, diff_logits[i][loc])
Expand Down Expand Up @@ -767,7 +767,7 @@ def loss(params, feedback):
loss = jnp.mean(
jnp.maximum(pred, 0) - pred * truth.data[i + 1] +
jnp.log1p(jnp.exp(-jnp.abs(pred))) * is_not_done)
mask = (truth.data != _OutputClass.MASKED.value).astype(
mask = (truth.data != _OutputClass.MASKED).astype(
jnp.float32)
total_loss += jnp.sum(loss*mask)/jnp.sum(mask)
elif truth.type_ == _Type.MASK_ONE:
Expand All @@ -783,9 +783,9 @@ def loss(params, feedback):
pred) * is_not_done, axis=-1))
elif truth.type_ == _Type.CATEGORICAL:
unmasked_data = truth.data[
truth.data == _OutputClass.POSITIVE.value]
truth.data == _OutputClass.POSITIVE]
masked_truth = truth.data * (
truth.data != _OutputClass.MASKED.value).astype(jnp.float32)
truth.data != _OutputClass.MASKED).astype(jnp.float32)
if self.decode_diffs:
total_loss += jnp.sum(
-jnp.sum(
Expand Down Expand Up @@ -820,12 +820,12 @@ def loss(params, feedback):
elif truth.type_ == _Type.MASK:
loss = (jnp.maximum(pred, 0) - pred * truth.data +
jnp.log1p(jnp.exp(-jnp.abs(pred))))
mask = (truth.data != _OutputClass.MASKED.value).astype(jnp.float32)
mask = (truth.data != _OutputClass.MASKED).astype(jnp.float32)
total_loss += jnp.sum(loss*mask)/jnp.sum(mask)
elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]:
unmasked_data = truth.data[truth.data == _OutputClass.POSITIVE.value]
unmasked_data = truth.data[truth.data == _OutputClass.POSITIVE]
masked_truth = truth.data * (
truth.data != _OutputClass.MASKED.value).astype(jnp.float32)
truth.data != _OutputClass.MASKED).astype(jnp.float32)
total_loss += (
-jnp.sum(masked_truth * jax.nn.log_softmax(pred))
/ jnp.sum(unmasked_data))
Expand Down Expand Up @@ -864,14 +864,14 @@ def verbose_loss(self, feedback: _Feedback, extra_info) -> Dict[str, _Array]:

losses = {}
if self.decode_diffs:
for loc in _Location:
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
for i in range(len(gt_diffs)):
is_not_done = _is_not_done_broadcast(lengths, i, gt_diffs[i][loc])
diff_loss = (
jnp.maximum(diff_logits[i][loc], 0) -
diff_logits[i][loc] * gt_diffs[i][loc] +
jnp.log1p(jnp.exp(-jnp.abs(diff_logits[i][loc]))) * is_not_done)
losses[loc.name + '_diff_%d' % i] = jnp.mean(diff_loss)
losses[loc + '_diff_%d' % i] = jnp.mean(diff_loss)

if self.decode_hints:
for truth in feedback.features.hints:
Expand Down
4 changes: 2 additions & 2 deletions clrs/_src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ def evaluate(


def _eval_one(pred, truth):
mask = np.all(truth != specs.OutputClass.MASKED.value, axis=-1)
mask = np.all(truth != specs.OutputClass.MASKED, axis=-1)
return np.sum(
(np.argmax(pred, -1) == np.argmax(truth, -1)) * mask) / np.sum(mask)


def _mask_fn(pred, truth):
"""Evaluate outputs of type MASK, and account for any class imbalance."""
mask = (truth != specs.OutputClass.MASKED.value).astype(np.float32)
mask = (truth != specs.OutputClass.MASKED).astype(np.float32)

# Use F1 score for the masked outputs to address any imbalance
tp = np.sum((((pred > 0.5) * (truth > 0.5)) * 1.0) * mask)
Expand Down
73 changes: 48 additions & 25 deletions clrs/_src/probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@

from typing import Dict, List, Tuple, Union

import attr
from clrs._src import specs
import jax
import numpy as np
import tensorflow as tf


_Location = specs.Location
Expand All @@ -37,34 +39,48 @@

_Array = np.ndarray
_Data = Union[_Array, List[_Array]]
_DataOrType = Union[_Data, _Type]
_DataOrType = Union[_Data, str]

ProbesDict = Dict[_Stage, Dict[_Location, Dict[str, Dict[str, _DataOrType]]]]
ProbesDict = Dict[
str, Dict[str, Dict[str, Dict[str, _DataOrType]]]]


class ProbeError(Exception):
pass
def _convert_to_str(element):
if isinstance(element, tf.Tensor):
return element.numpy().decode('utf-8')
elif isinstance(element, (np.ndarray, bytes)):
return element.decode('utf-8')
else:
return element


# First anotation makes this object jax.jit/pmap friendly, second one makes this
# tf.data.Datasets friendly.
@jax.tree_util.register_pytree_node_class
@attr.define
class DataPoint:
"""Describes a data point."""

def __init__(
self,
name: str,
location: _Location,
type_: _Type,
data: _Array,
):
self.name = name
self.location = location
self.type_ = type_
self.data = data
_name: str
_location: str
_type_: str
data: _Array

@property
def name(self):
return _convert_to_str(self._name)

@property
def location(self):
return _convert_to_str(self._location)

@property
def type_(self):
return _convert_to_str(self._type_)

def __repr__(self):
s = f'DataPoint(name="{self.name}",\tlocation={self.location.name},\t'
return s + f'type={self.type_.name},\tdata=Array{self.data.shape})'
s = f'DataPoint(name="{self.name}",\tlocation={self.location},\t'
return s + f'type={self.type_},\tdata=Array{self.data.shape})'

def tree_flatten(self):
data = (self.data,)
Expand All @@ -78,6 +94,10 @@ def tree_unflatten(cls, meta, data):
return DataPoint(name, location, type_, subdata)


class ProbeError(Exception):
pass


def initialize(spec: specs.Spec) -> ProbesDict:
"""Initializes an empty `ProbesDict` corresponding with the provided spec."""
probes = dict()
Expand All @@ -89,21 +109,24 @@ def initialize(spec: specs.Spec) -> ProbesDict:
for name in spec:
stage, loc, t = spec[name]
probes[stage][loc][name] = {}
probes[stage][loc][name]['type_'] = t
probes[stage][loc][name]['data'] = []

return probes
probes[stage][loc][name]['type_'] = t
# Pytype thinks initialize() returns a ProbesDict with a str for all final
# values instead of _DataOrType.
return probes # pytype: disable=bad-return-type


def push(probes: ProbesDict, stage: _Stage, next_probe):
def push(probes: ProbesDict, stage: str, next_probe):
"""Pushes a probe into an existing `ProbesDict`."""
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
for name in probes[stage][loc]:
if name not in next_probe:
raise ProbeError(f'Missing probe for {name}.')
if isinstance(probes[stage][loc][name]['data'], _Array):
raise ProbeError('Attemping to push to finalized `ProbesDict`.')
probes[stage][loc][name]['data'].append(next_probe[name])
# Pytype thinks initialize() returns a ProbesDict with a str for all final
# values instead of _DataOrType.
probes[stage][loc][name]['data'].append(next_probe[name]) # pytype: disable=attribute-error


def finalize(probes: ProbesDict):
Expand Down Expand Up @@ -249,15 +272,15 @@ def strings_pair_cat(pair_probe: np.ndarray, nb_classes: int) -> np.ndarray:
probe_ret = np.zeros((n + m, n + m, nb_classes + 1))
for i in range(0, n):
for j in range(0, m):
probe_ret[i, j + n, int(pair_probe[i, j])] = _OutputClass.POSITIVE.value
probe_ret[i, j + n, int(pair_probe[i, j])] = _OutputClass.POSITIVE

# Fill the blank cells.
for i_1 in range(0, n):
for i_2 in range(0, n):
probe_ret[i_1, i_2, nb_classes] = _OutputClass.MASKED.value
probe_ret[i_1, i_2, nb_classes] = _OutputClass.MASKED
for j_1 in range(0, m):
for x in range(0, n + m):
probe_ret[j_1 + n, x, nb_classes] = _OutputClass.MASKED.value
probe_ret[j_1 + n, x, nb_classes] = _OutputClass.MASKED
return probe_ret


Expand Down
18 changes: 5 additions & 13 deletions clrs/_src/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,44 +30,36 @@
for representing sequential data where appropriate
"""

import enum
import types
from typing import Dict, Tuple


class _OrderedEnum(enum.Enum):

def __lt__(self, other):
assert self.__class__ is other.__class__
return self.value < other.value # pylint: disable=comparison-with-callable


class Stage(_OrderedEnum):
class Stage:
INPUT = 'input'
OUTPUT = 'output'
HINT = 'hint'


class Location(_OrderedEnum):
class Location:
NODE = 'node'
EDGE = 'edge'
GRAPH = 'graph'


class Type(_OrderedEnum):
class Type:
SCALAR = 'scalar'
CATEGORICAL = 'categorical'
MASK = 'mask'
MASK_ONE = 'mask_one'
POINTER = 'pointer'


class OutputClass(_OrderedEnum):
class OutputClass:
POSITIVE = 1
NEGATIVE = 0
MASKED = -1

Spec = Dict[str, Tuple[Stage, Location, Type]]
Spec = Dict[str, Tuple[str, str, str]]

CLRS_21_ALGS = [
'bellman_ford',
Expand Down