-
Notifications
You must be signed in to change notification settings - Fork 616
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allowing more flexible cost functions for optimizers #959
Changes from 5 commits
fe79cc6
9362b82
5609f5d
2c14788
24d30e1
21e9e01
a820b55
ad0906a
ed4ac91
b30a552
cb2b328
5c03e65
df0a7a8
4cf358c
b9a447d
7624f24
5e0c40b
12598c8
4f35668
193cb53
095de9e
9f189f9
b4b0d71
83dcf94
a3c2534
b4ed411
b3c3857
29416d2
dfc0092
6cc787c
703942d
6ca34cb
4854c6b
16c6bde
7ddbcbc
c53adb7
d9d03a9
afd0cfa
751a030
90257ba
a35d782
033af1b
3c58644
f39a839
0eb4133
bfe0a4b
c3d1e49
761bfed
f7e9d67
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -19,6 +19,7 @@ | |||||
|
||||||
import numpy as onp | ||||||
import pytest | ||||||
from pytest_mock import mocker | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to import
Suggested change
|
||||||
|
||||||
import pennylane as qml | ||||||
from pennylane import numpy as np | ||||||
|
@@ -720,6 +721,11 @@ def test_update_stepsize(self): | |||||
assert opt._stepsize == eta2 | ||||||
|
||||||
|
||||||
def reset(opt): | ||||||
if getattr(opt, "reset", None): | ||||||
opt.reset() | ||||||
|
||||||
|
||||||
@pytest.mark.parametrize( | ||||||
"opt, opt_name", | ||||||
[ | ||||||
|
@@ -735,56 +741,76 @@ def test_update_stepsize(self): | |||||
class TestOverOpts: | ||||||
"""Tests keywords, multiple arguements, and non-training arguments in relevent optimizers""" | ||||||
|
||||||
def test_kwargs(self, opt, opt_name, tol): | ||||||
def test_kwargs(self, mocker, opt, opt_name, tol): | ||||||
"""Test that the keywords get passed and alter the function""" | ||||||
|
||||||
def func(x, c=1.0): | ||||||
return (x - c) ** 2 | ||||||
class func_wrapper: | ||||||
@staticmethod | ||||||
def func(x, c=1.0): | ||||||
return (x - c) ** 2 | ||||||
|
||||||
x = 1.0 | ||||||
|
||||||
x_new_one = opt.step(func, x, c=1.0) | ||||||
x_new_two = opt.step(func, x, c=2.0) | ||||||
wrapper = func_wrapper() | ||||||
spy = mocker.spy(wrapper, "func") | ||||||
|
||||||
x_new_one_wc, cost_one = opt.step_and_cost(func, x, c=1.0) | ||||||
x_new_two_wc, cost_two = opt.step_and_cost(func, x, c=2.0) | ||||||
x_new_two = opt.step(wrapper.func, x, c=2.0) | ||||||
reset(opt) | ||||||
|
||||||
if getattr(opt, "reset", None): | ||||||
opt.reset() | ||||||
args2, kwargs2 = spy.call_args_list[-1] | ||||||
|
||||||
assert x_new_one != pytest.approx(x_new_two, abs=tol) | ||||||
assert x_new_one_wc != pytest.approx(x_new_two_wc, abs=tol) | ||||||
x_new_three_wc, cost_three = opt.step_and_cost(wrapper.func, x, c=3.0) | ||||||
reset(opt) | ||||||
|
||||||
if opt_name != "nest": | ||||||
assert cost_one == pytest.approx(func(x, c=1.0), abs=tol) | ||||||
assert cost_two == pytest.approx(func(x, c=2.0), abs=tol) | ||||||
args3, kwargs3 = spy.call_args_list[-1] | ||||||
|
||||||
@pytest.mark.parametrize( | ||||||
"func, args", | ||||||
[ | ||||||
(lambda x, y: x * y, (1.0, 1.0)), | ||||||
(lambda x, y: x[0] * y[0], (np.array([1.0]), np.array([1.0]))), | ||||||
], | ||||||
) | ||||||
def test_multi_args(self, opt, opt_name, func, args, tol): | ||||||
"""Test multiple arguments to function""" | ||||||
x_new, y_new = opt.step(func, *args) | ||||||
x_new2, y_new2 = opt.step(func, x_new, y_new) | ||||||
if opt_name != "roto": | ||||||
assert args2 == (x,) | ||||||
assert args3 == (x,) | ||||||
else: | ||||||
assert x_new_two != pytest.approx(x, abs=tol) | ||||||
assert x_new_three_wc != pytest.approx(x, abs=tol) | ||||||
|
||||||
assert kwargs2 == {"c": 2.0} | ||||||
assert kwargs3 == {"c": 3.0} | ||||||
|
||||||
assert cost_three == pytest.approx(wrapper.func(x, c=3.0), abs=tol) | ||||||
Comment on lines
+773
to
+776
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💯 |
||||||
|
||||||
def test_multi_args(self, mocker, opt, opt_name, tol): | ||||||
"""Test passing multiple arguments to function""" | ||||||
|
||||||
class func_wrapper: | ||||||
@staticmethod | ||||||
def func(x, y, z): | ||||||
return x[0] * y[0] + z[0] | ||||||
|
||||||
wrapper = func_wrapper() | ||||||
spy = mocker.spy(wrapper, "func") | ||||||
|
||||||
x = np.array([1.0]) | ||||||
y = np.array([2.0]) | ||||||
z = np.array([3.0]) | ||||||
|
||||||
(x_new_wc, y_new_wc), cost = opt.step_and_cost(func, *args) | ||||||
(x_new2_wx, y_new2_wc), cost2 = opt.step_and_cost(func, x_new_wc, y_new_wc) | ||||||
(x_new, y_new, z_new), cost = opt.step_and_cost(wrapper.func, x, y, z) | ||||||
reset(opt) | ||||||
args_called1, kwargs1 = spy.call_args_list[-1] # just take last call | ||||||
|
||||||
if getattr(opt, "reset", None): | ||||||
opt.reset() | ||||||
x_new2, y_new2, z_new2 = opt.step(wrapper.func, x_new, y_new, z_new) | ||||||
reset(opt) | ||||||
args_called2, kwargs2 = spy.call_args_list[-1] # just take last call | ||||||
|
||||||
assert x_new != pytest.approx(args[0], abs=tol) | ||||||
assert y_new != pytest.approx(args[1], abs=tol) | ||||||
if opt_name != "roto": | ||||||
assert args_called1 == (x, y, z) | ||||||
assert args_called2 == (x_new, y_new, z_new) | ||||||
else: | ||||||
assert x_new != pytest.approx(x, abs=tol) | ||||||
assert y_new != pytest.approx(y, abs=tol) | ||||||
assert z_new != pytest.approx(z, abs=tol) | ||||||
|
||||||
assert x_new_wc != pytest.approx(args[0], abs=tol) | ||||||
assert y_new_wc != pytest.approx(args[1], abs=tol) | ||||||
assert kwargs1 == {} | ||||||
assert kwargs2 == {} | ||||||
|
||||||
if opt_name != "nest": | ||||||
assert cost == pytest.approx(func(*args), abs=tol) | ||||||
assert cost == pytest.approx(wrapper.func(x, y, z), abs=tol) | ||||||
|
||||||
def test_nontrainable_data(self, opt, opt_name, tol): | ||||||
"""Check non-trainable argument does not get updated""" | ||||||
|
@@ -796,71 +822,37 @@ def func(x, data): | |||||
data = np.array([1.0], requires_grad=False) | ||||||
|
||||||
args_new = opt.step(func, x, data) | ||||||
reset(opt) | ||||||
args_new_wc, cost = opt.step_and_cost(func, *args_new) | ||||||
|
||||||
if getattr(opt, "reset", None): | ||||||
opt.reset() | ||||||
reset(opt) | ||||||
|
||||||
assert len(args_new) == pytest.approx(2, abs=tol) | ||||||
assert args_new[0] != pytest.approx(x, abs=tol) | ||||||
assert args_new[1] == pytest.approx(data, abs=tol) | ||||||
|
||||||
if opt_name != "nest": | ||||||
assert cost == pytest.approx(func(args_new[0], data), abs=tol) | ||||||
|
||||||
def test_multiargs_data_kwargs(self, opt, opt_name, tol): | ||||||
""" Check all multiargs, non-trainable data, and keywords at the same time.""" | ||||||
assert cost == pytest.approx(func(*args_new), abs=tol) | ||||||
|
||||||
def func(x, data, y, c=1.0): | ||||||
return c * (x[0] + y[0] - data[0]) ** 2 | ||||||
|
||||||
x = np.array([1.0], requires_grad=True) | ||||||
y = np.array([1.0]) | ||||||
data = np.array([1.0], requires_grad=False) | ||||||
|
||||||
args_new, cost = opt.step_and_cost(func, x, data, y, c=0.5) | ||||||
args_new2 = opt.step(func, *args_new, c=0.5) | ||||||
|
||||||
if getattr(opt, "reset", None): | ||||||
opt.reset() | ||||||
def test_steps_the_same(self, opt, opt_name, tol): | ||||||
"""Tests whether separating the args into different inputs affects their | ||||||
optimization step. Assumes single argument optimization is correct, as tested elsewhere.""" | ||||||
|
||||||
assert args_new[0] != pytest.approx(x, abs=tol) | ||||||
assert args_new[1] == pytest.approx(data, abs=tol) | ||||||
assert args_new[2] != pytest.approx(y, abs=tol) | ||||||
def func1(x, y, z): | ||||||
return x[0] * y[0] * z[0] | ||||||
|
||||||
if opt_name != "nest": | ||||||
assert cost == pytest.approx(func(x, data, y, c=0.5), abs=tol) | ||||||
def func2(args): | ||||||
return args[0][0] * args[1][0] * args[2][0] | ||||||
|
||||||
def test_steps_the_same(self, opt, opt_name, tol): | ||||||
"""Tests optimizing single parameter same as with several at a time""" | ||||||
x = np.array([1.0]) | ||||||
y = np.array([2.0]) | ||||||
z = np.array([3.0]) | ||||||
args = (x, y, z) | ||||||
|
||||||
def func(x, y, z): | ||||||
return x[0] * y[0] * z[0] | ||||||
x_seperate, y_seperate, z_seperate = opt.step(func1, x, y, z) | ||||||
reset(opt) | ||||||
|
||||||
args_new = opt.step(func2, args) | ||||||
reset(opt) | ||||||
|
||||||
fx = lambda xp: func(xp, y, z) | ||||||
fy = lambda yp: func(x, yp, z) | ||||||
fz = lambda zp: func(x, y, zp) | ||||||
|
||||||
if getattr(opt, "reset", None): | ||||||
opt.reset() | ||||||
|
||||||
x_full, y_full, z_full = opt.step(func, x, y, z) | ||||||
if getattr(opt, "reset", None): | ||||||
opt.reset() | ||||||
|
||||||
x_part = opt.step(fx, x) | ||||||
if getattr(opt, "reset", None): | ||||||
opt.reset() | ||||||
y_part = opt.step(fy, y) | ||||||
if getattr(opt, "reset", None): | ||||||
opt.reset() | ||||||
z_part = opt.step(fz, z) | ||||||
if getattr(opt, "reset", None): | ||||||
opt.reset() | ||||||
|
||||||
assert x_full == pytest.approx(x_part, abs=tol) | ||||||
assert y_full == pytest.approx(y_part, abs=tol) | ||||||
assert z_full == pytest.approx(z_part, abs=tol) | ||||||
assert x_seperate == pytest.approx(args_new[0], abs=tol) | ||||||
assert y_seperate == pytest.approx(args_new[1], abs=tol) | ||||||
assert z_seperate == pytest.approx(args_new[2], abs=tol) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't forget to add your name to contributors! (unless you have done that already, and I missed it)