Skip to content

Commit

Permalink
Implement scalar division for modules etc
Browse files Browse the repository at this point in the history
Thi commit addresses #249.
  • Loading branch information
jgosmann committed Jun 22, 2020
1 parent d36b68e commit 3063864
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 0 deletions.
7 changes: 7 additions & 0 deletions nengo_spa/ast/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ def __rmul__(self, other):
else:
return self._mul_with_dynamic(other, swap_inputs=True)

@binary_node_op
def __truediv__(self, other):
if isinstance(other, FixedScalar):
return self._mul_with_fixed(FixedScalar(1.0 / other.value))
else:
return NotImplemented

@binary_node_op
def dot(self, other):
type_ = infer_types(self, other)
Expand Down
5 changes: 5 additions & 0 deletions nengo_spa/ast/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ def __rmul__(self, other):
type_ = infer_types(self, other)
return PointerSymbol(other.expr + "*" + self.expr, type_)

@symbolic_op
def __truediv__(self, other):
type_ = infer_types(self, other)
return PointerSymbol(self.expr + "/" + other.expr, type_)

def dot(self, other):
other = as_symbolic_node(other)
if not isinstance(other, PointerSymbol):
Expand Down
18 changes: 18 additions & 0 deletions nengo_spa/ast/tests/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,24 @@ def test_binary_operation_on_modules(Simulator, algebra, op, suffix, seed, rng):
)


@pytest.mark.parametrize("suffix", ["", ".output"])
def test_division_with_fixed(Simulator, suffix, seed, rng):
vocab = spa.Vocabulary(16, pointer_gen=rng)
vocab.populate("A")

with spa.Network(seed=seed) as model:
a = spa.Transcode("A", output_vocab=vocab) # noqa: F841
x = eval("a" + suffix + "/ 2")
p = nengo.Probe(x.construct(), synapse=0.03)

with Simulator(model) as sim:
sim.run(0.3)

assert_sp_close(
sim.trange(), sim.data[p], vocab.parse("0.5 * A"), skip=0.2, atol=0.3
)


@pytest.mark.parametrize("op", ["+", "-", "*"])
@pytest.mark.parametrize("order", ["AB", "BA"])
def test_binary_operation_on_modules_with_pointer_symbol(
Expand Down
11 changes: 11 additions & 0 deletions nengo_spa/ast/tests/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ def test_multiply_fixed_scalar_and_pointer_symbol(scalar, rng):
assert_equal(node.output, vocab.parse("2 * A").v)


@pytest.mark.parametrize("scalar", [2, np.float64(2)])
def test_divide_pointer_symbol_by_fixed_scalar(scalar, rng):
vocab = spa.Vocabulary(16, pointer_gen=rng)
vocab.populate("A")

with spa.Network():
x = PointerSymbol("A", TVocabulary(vocab)) / scalar
node = x.construct()
assert_equal(node.output, vocab.parse("0.5 * A").v)


def test_fixed_dot(rng):
vocab = spa.Vocabulary(16, pointer_gen=rng)
vocab.populate("A; B")
Expand Down
1 change: 1 addition & 0 deletions nengo_spa/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def op_impl(self, other):
__rsub__ = __define_binary_op.__func__("__rsub__")
__mul__ = __define_binary_op.__func__("__mul__")
__rmul__ = __define_binary_op.__func__("__rmul__")
__truediv__ = __define_binary_op.__func__("__truediv__")
__matmul__ = __define_binary_op.__func__("__matmul__")
__rmatmul__ = __define_binary_op.__func__("__rmatmul__")

Expand Down

0 comments on commit 3063864

Please sign in to comment.