Skip to content

Commit

Permalink
Fix dataset loading in tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
ffl096 committed Jan 3, 2025
1 parent 1b79b9f commit f17c334
Show file tree
Hide file tree
Showing 8 changed files with 4 additions and 54 deletions.
1 change: 1 addition & 0 deletions topomodelx/nn/simplicial/scone.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(
# Lookup table used to speed up vectorizing of trajectories
self.edge_lookup_table = {}
for i, edge in enumerate(self.sc.skeleton(1)):
edge = tuple(edge)
self.edge_lookup_table[edge] = (1, i)
self.edge_lookup_table[edge[::-1]] = (-1, i)

Expand Down
8 changes: 0 additions & 8 deletions tutorials/cell/ccxn_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@
"source": [
"shrec, _ = tnx.datasets.shrec_16(size=\"small\")\n",
"\n",
"shrec = {key: np.array(value) for key, value in shrec.items()}\n",
"x_0s = shrec[\"node_feat\"]\n",
"x_1s = shrec[\"edge_feat\"]\n",
"x_2s = shrec[\"face_feat\"]\n",
Expand Down Expand Up @@ -634,13 +633,6 @@
" flush=True,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
8 changes: 0 additions & 8 deletions tutorials/cell/cwn_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@
"source": [
"shrec, _ = tnx.datasets.shrec_16(size=\"small\")\n",
"\n",
"shrec = {key: np.array(value) for key, value in shrec.items()}\n",
"x_0s = shrec[\"node_feat\"]\n",
"x_1s = shrec[\"edge_feat\"]\n",
"x_2s = shrec[\"face_feat\"]\n",
Expand Down Expand Up @@ -578,13 +577,6 @@
" flush=True,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
21 changes: 0 additions & 21 deletions tutorials/hypergraph/dhgcn_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@
"source": [
"shrec, _ = tnx.datasets.mesh.shrec_16(size=\"small\")\n",
"\n",
"shrec = {key: np.array(value) for key, value in shrec.items()}\n",
"x_0s = shrec[\"node_feat\"]\n",
"x_1s = shrec[\"edge_feat\"]\n",
"x_2s = shrec[\"face_feat\"]\n",
Expand All @@ -109,26 +108,6 @@
"simplexes = shrec[\"complexes\"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((100,), (100, 750, 10), (100, 500, 7), (100,), (100,))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_0s.shape, x_1s.shape, x_2s.shape, ys.shape, simplexes.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
Expand Down
1 change: 0 additions & 1 deletion tutorials/hypergraph/hypergat_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@
"source": [
"shrec, _ = tnx.datasets.shrec_16(size=\"small\")\n",
"\n",
"shrec = {key: np.array(value) for key, value in shrec.items()}\n",
"x_0s = shrec[\"node_feat\"]\n",
"x_1s = shrec[\"edge_feat\"]\n",
"x_2s = shrec[\"face_feat\"]\n",
Expand Down
9 changes: 1 addition & 8 deletions tutorials/simplicial/sccnn_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
],
"source": [
"shrec, _ = tnx.datasets.shrec_16(size=\"small\")\n",
"shrec = {key: np.array(value) for key, value in shrec.items()}\n",
"\n",
"x_0s = shrec[\"node_feat\"]\n",
"x_1s = shrec[\"edge_feat\"]\n",
"x_2s = shrec[\"face_feat\"]\n",
Expand Down Expand Up @@ -977,13 +977,6 @@
" )\n",
" print(f\"Test_acc: {test_accuracy:.4f}\", flush=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
8 changes: 0 additions & 8 deletions tutorials/simplicial/scn2_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@
"source": [
"shrec, _ = tnx.datasets.shrec_16(size=\"small\")\n",
"\n",
"shrec = {key: np.array(value) for key, value in shrec.items()}\n",
"x_0s = shrec[\"node_feat\"]\n",
"x_1s = shrec[\"edge_feat\"]\n",
"x_2s = shrec[\"face_feat\"]\n",
Expand Down Expand Up @@ -461,13 +460,6 @@
" test_loss = loss_fn(y_hat, y)\n",
" print(f\"Test_loss: {test_loss:.4f}\", flush=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
2 changes: 2 additions & 0 deletions tutorials/simplicial/scone_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@
" \"\"\"\n",
" # Plot triangles\n",
" for idx in sc.skeleton(2):\n",
" idx = tuple(idx)\n",
" pts = np.array([coords[idx[0]], coords[idx[1]], coords[idx[2]]])\n",
" poly = plt.Polygon(pts, color=\"green\", alpha=0.25)\n",
" plt.gca().add_patch(poly)\n",
Expand Down Expand Up @@ -397,6 +398,7 @@
" # Lookup table used to speed up vectorizing of trajectories\n",
" self.edge_lookup_table = {}\n",
" for i, edge in enumerate(self.sc.skeleton(1)):\n",
" edge = tuple(edge)\n",
" self.edge_lookup_table[edge] = (1, i)\n",
" self.edge_lookup_table[edge[::-1]] = (-1, i)\n",
"\n",
Expand Down

0 comments on commit f17c334

Please sign in to comment.