Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed Jan 12, 2024
1 parent 63b8ae2 commit 3e1389f
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ jobs:
- name: Install Library
run: pip install --upgrade pip && pip install --use-deprecated=legacy-resolver '.[dev]'
- name: Run source notebooks
run: find . -name "*.ipynb" -not -path "*/.ipynb_checkpoints/*" | xargs -I {} papermill {} {} -k python3
run: find . -name "*.ipynb" -not -path "*/tests/*" -not -path "*/.ipynb_checkpoints/*" | xargs -I {} papermill {} {} -k python3
- name: Expose 'internals' as artifact
uses: actions/upload-artifact@master
with:
Expand Down
29 changes: 19 additions & 10 deletions tests/nbs/00_typing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,16 @@
"source": [
"def func() -> sax.Model:\n",
" ...\n",
" \n",
"assert sax.is_model_factory(func) # yes, we only check the annotation for now...\n",
"\n",
"\n",
"assert sax.is_model_factory(func) # yes, we only check the annotation for now...\n",
"\n",
"\n",
"def func():\n",
" ...\n",
" \n",
"assert not sax.is_model_factory(func) # yes, we only check the annotation for now..."
"\n",
"\n",
"assert not sax.is_model_factory(func) # yes, we only check the annotation for now..."
]
},
{
Expand Down Expand Up @@ -319,7 +322,9 @@
"Si, Sj, Sx, port_map = sax.scoo(_sdense) # type: ignore\n",
"np.testing.assert_array_equal(Si, jnp.array([0, 0, 0, 1, 1, 1, 2, 2, 2]))\n",
"np.testing.assert_array_equal(Sj, jnp.array([0, 1, 2, 0, 1, 2, 0, 1, 2]))\n",
"np.testing.assert_array_almost_equal(Sx, jnp.array([0.0, 2.0, 1.0, 6.0, 8.0, 7.0, 3.0, 5.0, 4.0]))\n",
"np.testing.assert_array_almost_equal(\n",
" Sx, jnp.array([0.0, 2.0, 1.0, 6.0, 8.0, 7.0, 3.0, 5.0, 4.0])\n",
")\n",
"assert port_map == {\"in0\": 0, \"in1\": 1, \"out0\": 2}"
]
},
Expand All @@ -340,9 +345,13 @@
"source": [
"assert sax.sdense(_sdense) is _sdense\n",
"Sd, port_map = sax.sdense(_scoo) # type: ignore\n",
"Sd_ = jnp.array([[3.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],\n",
" [0.0 + 0.0j, 4.0 + 0.0j, 0.0 + 0.0j],\n",
" [1.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j]])\n",
"Sd_ = jnp.array(\n",
" [\n",
" [3.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],\n",
" [0.0 + 0.0j, 4.0 + 0.0j, 0.0 + 0.0j],\n",
" [1.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],\n",
" ]\n",
")\n",
"\n",
"np.testing.assert_array_almost_equal(Sd, Sd_)\n",
"assert port_map == {\"in0\": 0, \"in1\": 2, \"out0\": 1}"
Expand All @@ -351,9 +360,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "sax",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "sax"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down
103 changes: 82 additions & 21 deletions tests/nbs/01_utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@
"def coupler(coupling=0.5):\n",
" return {\n",
" (\"in0\", \"out0\"): coupling**0.5,\n",
" (\"in0\", \"out1\"): 1j*coupling**0.5,\n",
" (\"in1\", \"out0\"): 1j*coupling**0.5,\n",
" (\"in0\", \"out1\"): 1j * coupling**0.5,\n",
" (\"in1\", \"out0\"): 1j * coupling**0.5,\n",
" (\"in1\", \"out1\"): coupling**0.5,\n",
" }"
]
Expand Down Expand Up @@ -230,7 +230,7 @@
},
"outputs": [],
"source": [
"assert sax.get_settings(coupler) == {'coupling': 0.5}"
"assert sax.get_settings(coupler) == {\"coupling\": 0.5}"
]
},
{
Expand All @@ -244,13 +244,63 @@
"source": [
"# hide\n",
"\n",
"wls = jnp.array([2.19999, 2.20001, 2.22499, 2.22501, 2.24999, 2.25001, 2.27499, 2.27501, 2.29999, 2.30001, 2.32499, 2.32501, 2.34999, 2.35001, 2.37499, 2.37501, 2.39999, 2.40001, 2.42499, 2.42501, 2.44999, 2.45001])\n",
"phis = jnp.array([5.17317336, 5.1219654, 4.71259842, 4.66252492, 5.65699608, 5.60817922, 2.03697377, 1.98936119, 6.010146, 5.96358061, 4.96336733, 4.91777933, 5.13912198, 5.09451137, 0.22347545, 0.17979684, 2.74501894, 2.70224092, 0.10403192, 0.06214664, 4.83328794, 4.79225525])\n",
"wls = jnp.array(\n",
" [\n",
" 2.19999,\n",
" 2.20001,\n",
" 2.22499,\n",
" 2.22501,\n",
" 2.24999,\n",
" 2.25001,\n",
" 2.27499,\n",
" 2.27501,\n",
" 2.29999,\n",
" 2.30001,\n",
" 2.32499,\n",
" 2.32501,\n",
" 2.34999,\n",
" 2.35001,\n",
" 2.37499,\n",
" 2.37501,\n",
" 2.39999,\n",
" 2.40001,\n",
" 2.42499,\n",
" 2.42501,\n",
" 2.44999,\n",
" 2.45001,\n",
" ]\n",
")\n",
"phis = jnp.array(\n",
" [\n",
" 5.17317336,\n",
" 5.1219654,\n",
" 4.71259842,\n",
" 4.66252492,\n",
" 5.65699608,\n",
" 5.60817922,\n",
" 2.03697377,\n",
" 1.98936119,\n",
" 6.010146,\n",
" 5.96358061,\n",
" 4.96336733,\n",
" 4.91777933,\n",
" 5.13912198,\n",
" 5.09451137,\n",
" 0.22347545,\n",
" 0.17979684,\n",
" 2.74501894,\n",
" 2.70224092,\n",
" 0.10403192,\n",
" 0.06214664,\n",
" 4.83328794,\n",
" 4.79225525,\n",
" ]\n",
")\n",
"wl = jnp.array([2.21, 2.27, 1.31, 2.424])\n",
"phi = jnp.array(sax.grouped_interp(wl, wls, phis))\n",
"phi_ref = jnp.array([-1.4901831, 1.3595749, -1.110012 , 2.1775336])\n",
"phi_ref = jnp.array([-1.4901831, 1.3595749, -1.110012, 2.1775336])\n",
"\n",
"assert ((phi-phi_ref)**2 < 1e-5).all()"
"assert ((phi - phi_ref) ** 2 < 1e-5).all()"
]
},
{
Expand Down Expand Up @@ -289,8 +339,13 @@
},
"outputs": [],
"source": [
"assert sax.mode_combinations(modes=[\"te\", \"tm\"]) == (('te', 'te'), ('tm', 'tm'))\n",
"assert sax.mode_combinations(modes=[\"te\", \"tm\"], cross=True) == (('te', 'te'), ('te', 'tm'), ('tm', 'te'), ('tm', 'tm'))"
"assert sax.mode_combinations(modes=[\"te\", \"tm\"]) == ((\"te\", \"te\"), (\"tm\", \"tm\"))\n",
"assert sax.mode_combinations(modes=[\"te\", \"tm\"], cross=True) == (\n",
" (\"te\", \"te\"),\n",
" (\"te\", \"tm\"),\n",
" (\"tm\", \"te\"),\n",
" (\"tm\", \"tm\"),\n",
")"
]
},
{
Expand Down Expand Up @@ -318,6 +373,7 @@
"def model(x=jnp.array(3.0), y=jnp.array(4.0), z=jnp.array([3.0, 4.0])) -> sax.SDict:\n",
" return {(\"in0\", \"out0\"): jnp.array(3.0)}\n",
"\n",
"\n",
"renamings = {\"x\": \"a\", \"y\": \"z\", \"z\": \"y\"}\n",
"new_model = sax.rename_params(model, renamings)\n",
"settings = sax.get_settings(new_model)\n",
Expand All @@ -339,11 +395,17 @@
"origports = sax.get_ports(d)\n",
"renamings = {\"p0\": \"in0\", \"p1\": \"out0\", \"p2\": \"in1\"}\n",
"d_ = sax.rename_ports(d, renamings)\n",
"assert tuple(sorted(sax.get_ports(d_))) == tuple(sorted(renamings[p] for p in origports))\n",
"assert tuple(sorted(sax.get_ports(d_))) == tuple(\n",
" sorted(renamings[p] for p in origports)\n",
")\n",
"d_ = sax.rename_ports(sax.scoo(d), renamings)\n",
"assert tuple(sorted(sax.get_ports(d_))) == tuple(sorted(renamings[p] for p in origports))\n",
"assert tuple(sorted(sax.get_ports(d_))) == tuple(\n",
" sorted(renamings[p] for p in origports)\n",
")\n",
"d_ = sax.rename_ports(sax.sdense(d), renamings)\n",
"assert tuple(sorted(sax.get_ports(d_))) == tuple(sorted(renamings[p] for p in origports))"
"assert tuple(sorted(sax.get_ports(d_))) == tuple(\n",
" sorted(renamings[p] for p in origports)\n",
")"
]
},
{
Expand Down Expand Up @@ -466,8 +528,7 @@
},
"outputs": [],
"source": [
"good_sdict = sax.reciprocal({(\"p0\", \"p1\"): 0.1, \n",
" (\"p1\", \"p2\"): 0.2})\n",
"good_sdict = sax.reciprocal({(\"p0\", \"p1\"): 0.1, (\"p1\", \"p2\"): 0.2})\n",
"assert sax.validate_sdict(good_sdict) is None\n",
"\n",
"bad_sdict = {\n",
Expand All @@ -487,19 +548,19 @@
},
"outputs": [],
"source": [
"assert sax.get_inputs_outputs([\"in0\", \"out0\"]) == (('in0',), ('out0',))\n",
"assert sax.get_inputs_outputs([\"in0\", \"in1\"]) == (('in0', 'in1'), ())\n",
"assert sax.get_inputs_outputs([\"out0\", \"out1\"]) == ((), ('out0', 'out1'))\n",
"assert sax.get_inputs_outputs([\"out0\", \"dc0\"]) == (('dc0',), ('out0',))\n",
"assert sax.get_inputs_outputs([\"dc0\", \"in0\"]) == (('in0',), ('dc0',))"
"assert sax.get_inputs_outputs([\"in0\", \"out0\"]) == ((\"in0\",), (\"out0\",))\n",
"assert sax.get_inputs_outputs([\"in0\", \"in1\"]) == ((\"in0\", \"in1\"), ())\n",
"assert sax.get_inputs_outputs([\"out0\", \"out1\"]) == ((), (\"out0\", \"out1\"))\n",
"assert sax.get_inputs_outputs([\"out0\", \"dc0\"]) == ((\"dc0\",), (\"out0\",))\n",
"assert sax.get_inputs_outputs([\"dc0\", \"in0\"]) == ((\"in0\",), (\"dc0\",))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sax",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "sax"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down
12 changes: 6 additions & 6 deletions tests/nbs/02_multimode.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@
"outputs": [],
"source": [
"scoo_s = sax.singlemode(scoo_s)\n",
"assert (scoo_s[0]==jnp.array([0], dtype=int)).all()\n",
"assert (scoo_s[1]==jnp.array([1], dtype=int)).all()\n",
"assert (scoo_s[2]==jnp.array([1.0], dtype=float)).all()\n",
"assert scoo_s[3] == {'in0': 0, 'out0': 1}"
"assert (scoo_s[0] == jnp.array([0], dtype=int)).all()\n",
"assert (scoo_s[1] == jnp.array([1], dtype=int)).all()\n",
"assert (scoo_s[2] == jnp.array([1.0], dtype=float)).all()\n",
"assert scoo_s[3] == {\"in0\": 0, \"out0\": 1}"
]
},
{
Expand All @@ -138,9 +138,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "sax",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "sax"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down
70 changes: 50 additions & 20 deletions tests/nbs/03_backends.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,28 @@
"outputs": [],
"source": [
"instances = {\n",
" \"lft\": (\n",
" \"lft\": {\"component\": \"coupler\"},\n",
" \"top\": {\"component\": \"wg\"},\n",
" \"rgt\": {\"component\": \"mmi\"},\n",
"}\n",
"connections = {\"lft,out0\": \"rgt,in0\", \"lft,out1\": \"top,in0\", \"top,out0\": \"rgt,in1\"}\n",
"ports = {\"in0\": \"lft,in0\", \"out0\": \"rgt,out0\"}\n",
"models = {\n",
" \"wg\": lambda: {\n",
" (\"in0\", \"out0\"): -0.99477 - 0.10211j,\n",
" (\"out0\", \"in0\"): -0.99477 - 0.10211j,\n",
" },\n",
" \"mmi\": lambda: {\n",
" (\"in0\", \"out0\"): 0.7071067811865476,\n",
" (\"in0\", \"out1\"): 0.7071067811865476j,\n",
" (\"in1\", \"out0\"): 0.7071067811865476j,\n",
" (\"in1\", \"out1\"): 0.7071067811865476,\n",
" (\"out0\", \"in0\"): 0.7071067811865476,\n",
" (\"out1\", \"in0\"): 0.7071067811865476j,\n",
" (\"out0\", \"in1\"): 0.7071067811865476j,\n",
" (\"out1\", \"in1\"): 0.7071067811865476,\n",
" },\n",
" \"coupler\": lambda: (\n",
" jnp.array(\n",
" [\n",
" [\n",
Expand Down Expand Up @@ -74,20 +95,7 @@
" ),\n",
" {\"in0\": 0, \"out0\": 2, \"out1\": 4},\n",
" ),\n",
" \"top\": {(\"in0\", \"out0\"): -0.99477 - 0.10211j, (\"out0\", \"in0\"): -0.99477 - 0.10211j},\n",
" \"rgt\": {\n",
" (\"in0\", \"out0\"): 0.7071067811865476,\n",
" (\"in0\", \"out1\"): 0.7071067811865476j,\n",
" (\"in1\", \"out0\"): 0.7071067811865476j,\n",
" (\"in1\", \"out1\"): 0.7071067811865476,\n",
" (\"out0\", \"in0\"): 0.7071067811865476,\n",
" (\"out1\", \"in0\"): 0.7071067811865476j,\n",
" (\"out0\", \"in1\"): 0.7071067811865476j,\n",
" (\"out1\", \"in1\"): 0.7071067811865476,\n",
" },\n",
"}\n",
"connections = {\"lft,out0\": \"rgt,in0\", \"lft,out1\": \"top,in0\", \"top,out0\": \"rgt,in1\"}\n",
"ports = {\"in0\": \"lft,in0\", \"out0\": \"rgt,out0\"}"
"}"
]
},
{
Expand All @@ -97,7 +105,13 @@
"metadata": {},
"outputs": [],
"source": [
"sax.sdict(sax.backends.evaluate_circuit(sax.backends.analyze_circuit(connections, ports), instances))"
"analyzed_instances = sax.backends.analyze_instances(instances, models)\n",
"analyzed_circuit = sax.backends.analyze_circuit(analyzed_instances, connections, ports)\n",
"sax.sdict(\n",
" sax.backends.evaluate_circuit(\n",
" analyzed_circuit, {k: models[v[\"component\"]]() for k, v in instances.items()}\n",
" )\n",
")"
]
},
{
Expand All @@ -107,7 +121,15 @@
"metadata": {},
"outputs": [],
"source": [
"sdict_klu = sax.sdict(sax.backends.evaluate_circuit_klu(sax.backends.analyze_circuit_klu(connections, ports), instances))"
"analyzed_instances = sax.backends.analyze_instances_klu(instances, models)\n",
"analyzed_circuit = sax.backends.analyze_circuit_klu(\n",
" analyzed_instances, connections, ports\n",
")\n",
"sdict_klu = sax.sdict(\n",
" sax.backends.evaluate_circuit_klu(\n",
" analyzed_circuit, {k: models[v[\"component\"]]() for k, v in instances.items()}\n",
" )\n",
")"
]
},
{
Expand All @@ -117,7 +139,15 @@
"metadata": {},
"outputs": [],
"source": [
"sdict_fg = sax.sdict(sax.backends.evaluate_circuit_fg(sax.backends.analyze_circuit_fg(connections, ports), instances))"
"analyzed_instances = sax.backends.analyze_instances_fg(instances, models)\n",
"analyzed_circuit = sax.backends.analyze_circuit_fg(\n",
" analyzed_instances, connections, ports\n",
")\n",
"sdict_fg = sax.sdict(\n",
" sax.backends.evaluate_circuit_fg(\n",
" analyzed_circuit, {k: models[v[\"component\"]]() for k, v in instances.items()}\n",
" )\n",
")"
]
},
{
Expand All @@ -138,9 +168,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "sax",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "sax"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down

0 comments on commit 3e1389f

Please sign in to comment.