diff --git a/book/_config.yml b/book/_config.yml
index 251f2d4..8a794ba 100644
--- a/book/_config.yml
+++ b/book/_config.yml
@@ -39,7 +39,7 @@ execute:
- "**/geospatial-advanced.ipynb"
- "cloud-computing/04-cloud-optimized-icesat2.ipynb"
- "cloud-computing/atl08_parquet_files/atl08_parquet.ipynb"
- - "machine-learning/photon_classifier.ipynb"
+ - "machine-learning/point_cloud_classifier.ipynb"
allow_errors: false
# Per-cell notebook execution limit (seconds)
timeout: 300
diff --git a/book/_toc.yml b/book/_toc.yml
index 6f7333b..5cfa31f 100644
--- a/book/_toc.yml
+++ b/book/_toc.yml
@@ -39,8 +39,8 @@ parts:
- file: tutorials/cloud-computing/atl08_parquet_files/atl08_parquet
options:
- titlesonly: true
- - file: tutorials/mental-health/index
- - file: tutorials/machine-learning/photon_classifier.ipynb
+ - file: tutorials/mental-health/index
+ - file: tutorials/machine-learning/point_cloud_classifier.ipynb
- caption: Projects
chapters:
- file: projects/index
diff --git a/book/tutorials/index.md b/book/tutorials/index.md
index 44e87a6..d34dac7 100644
--- a/book/tutorials/index.md
+++ b/book/tutorials/index.md
@@ -11,4 +11,4 @@ Below you'll find a table keeping track of all tutorials presented at this event
| [ICESat-2 Mission](./mission-overview/icesat-2-mission-overview.ipynb) | ICESat-2 Mission and Products | n/a | Not recorded |
| [Cloud Computing](./cloud-computing/00-goals-and-outline.ipynb) | Cloud Computing Tutorial | n/a | Not recorded |
| [Notebooks to Packages](./nb-to-package/index.md) | All about Python classes to packages | n/a | Not recorded |
-| [ICESat-2 photon classification](./machine-learning/photon_classifier.ipynb) | Machine Learning, PyTorch | ATL07 | Not recorded |
+| [ICESat-2 point cloud classification](./machine-learning/point_cloud_classifier.ipynb) | Machine Learning, PyTorch | ATL07 | Not recorded |
diff --git a/book/tutorials/machine-learning/photon_classifier.ipynb b/book/tutorials/machine-learning/point_cloud_classifier.ipynb
similarity index 99%
rename from book/tutorials/machine-learning/photon_classifier.ipynb
rename to book/tutorials/machine-learning/point_cloud_classifier.ipynb
index 849ba1f..b60412a 100644
--- a/book/tutorials/machine-learning/photon_classifier.ipynb
+++ b/book/tutorials/machine-learning/point_cloud_classifier.ipynb
@@ -7,7 +7,7 @@
"source": [
"# Machine Learning with ICESat-2 data\n",
"\n",
- "A machine learning pipeline from point clouds to photon classifications.\n",
+ "A machine learning pipeline for point cloud classification.\n",
"\n",
"Reimplementation of [Koo et al., 2023](https://doi.org/10.1016/j.rse.2023.113726),\n",
"based on code available at https://github.com/YoungHyunKoo/IS2_ML."
@@ -16,19 +16,17 @@
{
"cell_type": "markdown",
"id": "554bbfcf",
- "metadata": {
- "lines_to_next_cell": 2
- },
+ "metadata": {},
"source": [
"```{admonition} Learning Objectives\n",
"By the end of this tutorial, you should be able to:\n",
"- Convert ICESat-2 point cloud data into an analysis/ML-ready format\n",
"- Recognize the different levels of complexity of ML approaches and the\n",
" benefits/challenges of each\n",
- "- Learn the potential of using ML for ICESat-2 photon classification\n",
+ "- Learn the potential of using ML for ICESat-2 point cloud classification\n",
"```\n",
"\n",
- "![ICESat-2 ATL07 sea ice photon classification ML pipeline](https://github.com/user-attachments/assets/509dab2d-d25d-417f-97ff-fc966f656ddf)"
+ "![ICESat-2 ATL07 sea ice point cloud classification ML pipeline](https://github.com/user-attachments/assets/d61521a4-d27a-4eb8-b886-3d92c5516c32)"
]
},
{
@@ -356,14 +354,6 @@
"- `item_collection` - Sentinel-2 optical satellite images"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8b6ffe42",
- "metadata": {},
- "outputs": [],
- "source": []
- },
{
"cell_type": "markdown",
"id": "369c2a5c",
@@ -889,7 +879,7 @@
"source": [
"### Optical imagery to label point clouds\n",
"\n",
- "Let's use the Sentinel-2 satellite image we found to label each ATL07 photon. We'll\n",
+ "Let's use the Sentinel-2 satellite image we found to label each ATL07 point. We'll\n",
"make a new column called `sea_ice_label` that can have either of these classifications:\n",
"\n",
"0. thick/snow-covered sea ice\n",
@@ -13921,8 +13911,8 @@
"id": "c60e8b12",
"metadata": {},
"source": [
- "Let's save the ATL07 photon data to a GeoParquet file so we don't have to run all the\n",
- "pre-processing code above again."
+ "Let's save the ATL07 point cloud data to a GeoParquet file so we don't have to run all\n",
+ "the pre-processing code above again."
]
},
{
@@ -13932,7 +13922,7 @@
"metadata": {},
"outputs": [],
"source": [
- "gdf.to_parquet(path=\"ATL07_photons.gpq\", compression=\"zstd\", schema_version=\"1.1.0\")"
+ "gdf.to_parquet(path=\"ATL07_point_cloud.gpq\", compression=\"zstd\", schema_version=\"1.1.0\")"
]
},
{
@@ -13957,7 +13947,7 @@
"outputs": [],
"source": [
"# Load GeoParquet file back into geopandas.GeoDataFrame\n",
- "gdf = gpd.read_parquet(path=\"ATL07_photons.gpq\")"
+ "gdf = gpd.read_parquet(path=\"ATL07_point_cloud.gpq\")"
]
},
{
@@ -13979,7 +13969,7 @@
"pipeline. We will create:\n",
"\n",
"1. A 'DataLoader', which is a fancy data container we can loop over; and\n",
- "2. A neural network 'model' that will take our input ATL07 data and output photon\n",
+ "2. A neural network 'model' that will take our input ATL07 data and output point cloud\n",
" classifications."
]
},
@@ -14010,9 +14000,7 @@
"cell_type": "code",
"execution_count": 25,
"id": "cc9cf774",
- "metadata": {
- "lines_to_next_cell": 1
- },
+ "metadata": {},
"outputs": [],
"source": [
"# Select data variables from DataFrame that will be used for training\n",
@@ -14054,11 +14042,13 @@
{
"cell_type": "markdown",
"id": "79141dca",
- "metadata": {},
+ "metadata": {
+ "lines_to_next_cell": 2
+ },
"source": [
"### Choosing a Machine Learning algorithm\n",
"\n",
- "Next is to pick a supervised learning 'model' for our photon classification task.\n",
+ "Next, we pick a supervised learning 'model' for our point cloud classification task.\n",
"There are a variety of machine learning methods to choose with different levels of\n",
"complexity:\n",
"\n",
@@ -14106,7 +14096,7 @@
"metadata": {},
"outputs": [],
"source": [
- "class PhotonClassificationModel(torch.nn.Module):\n",
+ "class PointCloudClassificationModel(torch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.linear1 = torch.nn.Linear(in_features=6, out_features=50)\n",
@@ -14131,7 +14121,7 @@
{
"data": {
"text/plain": [
- "PhotonClassificationModel(\n",
+ "PointCloudClassificationModel(\n",
" (linear1): Linear(in_features=6, out_features=50, bias=True)\n",
" (linear2): Linear(in_features=50, out_features=50, bias=True)\n",
" (linear3): Linear(in_features=50, out_features=3, bias=True)\n",
@@ -14144,7 +14134,7 @@
}
],
"source": [
- "model = PhotonClassificationModel()\n",
+ "model = PointCloudClassificationModel()\n",
"# model = model.to(device=\"cuda\") # uncomment this line if running on GPU\n",
"model"
]
@@ -14171,7 +14161,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 28,
"id": "3bfafafb",
"metadata": {},
"outputs": [],
@@ -14183,12 +14173,49 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 29,
"id": "f7ff6874",
"metadata": {
"lines_to_next_cell": 2
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 67%|██████▋ | 2/3 [00:00<00:00, 8.63it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "loss: 7949.509277 [ 8140/ 9454]\n",
+ "loss: 3678.343994 [ 8140/ 9454]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 3/3 [00:00<00:00, 8.64it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "loss: 1735.765503 [ 8140/ 9454]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
"source": [
"# Main training loop\n",
"max_epochs: int = 3\n",
@@ -14219,17 +14246,15 @@
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
- "        # Report metrics\n",
- "        current = (i + 1) * len(x)\n",
- "        print(f\"loss: {loss:>7f} [{current:>5d}/{size:>5d}]\")"
+ "    # Report metrics\n",
+ "    current = (i + 1) * len(x)\n",
+ "    print(f\"loss: {loss:>7f} [{current:>5d}/{size:>5d}]\")"
]
},
{
"cell_type": "markdown",
"id": "66358b37",
- "metadata": {
- "lines_to_next_cell": 2
- },
+ "metadata": {},
"source": [
"Did the model learn something? A good sign to check is if the loss value is\n",
"decreasing, which means the error between the predicted and groundtruth value is\n",
@@ -14238,25 +14263,202 @@
},
{
"cell_type": "markdown",
- "id": "709f7a1b",
- "metadata": {
- "lines_to_next_cell": 2
- },
+ "id": "b82c7138",
+ "metadata": {},
"source": [
- "## References\n",
- "- Koo, Y., Xie, H., Kurtz, N. T., Ackley, S. F., & Wang, W. (2023).\n",
- " Sea ice surface type classification of ICESat-2 ATL07 data by using data-driven\n",
- " machine learning model: Ross Sea, Antarctic as an example. Remote Sensing of\n",
- " Environment, 296, 113726. https://doi.org/10.1016/j.rse.2023.113726"
+ "\n",
+ "### Inference results\n",
+ "\n",
+ "Besides monitoring the loss value, it is also good to calculate a metric like\n",
+ "Precision, Recall or F1-score. Let's first run the model in 'inference' mode to get\n",
+ "predictions."
]
},
{
"cell_type": "code",
- "execution_count": null,
- "id": "5014ed28",
+ "execution_count": 30,
+ "id": "fd38a5fa",
"metadata": {},
"outputs": [],
- "source": []
+ "source": [
+ "gdf[\"predicted_surface_type\"] = None  # create new column filled with None to store results\n",
+ "with torch.inference_mode():\n",
+ "    # Note: the row indexing below assumes the dataloader yields rows in order (no shuffling)\n",
+ "    for i, batch in enumerate(dataloader):\n",
+ "        minibatch: torch.Tensor = batch[0]\n",
+ "        x = minibatch[:, :6]\n",
+ "        prediction = model(x=x)  # per-class scores, shape (batch_size, 3)\n",
+ "        prediction_labels = torch.argmax(input=prediction, dim=1)  # 0/1/2 labels\n",
+ "\n",
+ "        start_index = i * dataloader.batch_size\n",
+ "        stop_index = start_index + len(minibatch) - 1  # .loc slices are end-inclusive\n",
+ "        gdf.loc[start_index:stop_index, \"predicted_surface_type\"] = prediction_labels"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fbfa6348",
+ "metadata": {},
+ "source": [
+ "\n",
+ "```{caution}\n",
+ "Ideally, you would want to run inference on a hold-out validation or test set, rather\n",
+ "than the points the model was trained on! See e.g.\n",
+ "[`sklearn.model_selection.train_test_split`](https://scikit-learn.org/1.5/modules/generated/sklearn.model_selection.train_test_split.html)\n",
+ "on how this can be done.\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a578ef64",
+ "metadata": {},
+ "source": [
+ "\n",
+ "Now that we have the predicted results in the `predicted_surface_type` column, we can\n",
+ "compare them with the 'groundtruth' labels in the `surface_type` column by visualizing\n",
+ "both in a confusion matrix."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "61bd172e",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
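+ "text/html": [
+ "<div>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\">\n",
+ "      <th>Predicted</th>\n",
+ "      <th>1</th>\n",
+ "      <th>2</th>\n",
+ "      <th>All</th>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>Actual</th>\n",
+ "      <th></th>\n",
+ "      <th></th>\n",
+ "      <th></th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <th>0</th>\n",
+ "      <td>0</td>\n",
+ "      <td>725</td>\n",
+ "      <td>725</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>1</th>\n",
+ "      <td>1</td>\n",
+ "      <td>577</td>\n",
+ "      <td>578</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>2</th>\n",
+ "      <td>2</td>\n",
+ "      <td>8149</td>\n",
+ "      <td>8151</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>All</th>\n",
+ "      <td>3</td>\n",
+ "      <td>9451</td>\n",
+ "      <td>9454</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"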
+ ],
+ "text/plain": [
+ "Predicted 1 2 All\n",
+ "Actual \n",
+ "0 0 725 725\n",
+ "1 1 577 578\n",
+ "2 2 8149 8151\n",
+ "All 3 9451 9454"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.crosstab(\n",
+ "    index=gdf.surface_type,\n",
+ "    columns=gdf.predicted_surface_type,\n",
+ "    rownames=[\"Actual\"],\n",
+ "    colnames=[\"Predicted\"],\n",
+ "    margins=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "737847e3",
+ "metadata": {},
+ "source": [
+ "\n",
+ "```{attention}\n",
+ "Oo, it looks like our model isn't producing very good results! It's practically only\n",
+ "predicting thick sea ice (class: 2). There could be many reasons for this, and it's up\n",
+ "to you to figure out a solution, either by changing the data, or adjusting the model.\n",
+ "\n",
+ "Data-centric approaches:\n",
+ "- Add more data! Maybe <10000 points isn't enough, try getting more!\n",
+ "- Check the labels! Are there wrongly labelled points? Is the Sentinel-2\n",
+ " dark/gray/bright bins above too simplistic? Investigate!\n",
+ "- Normalize the data value range. The original paper by Koo et al., 2023 applied\n",
+ " min-max normalization on the 6 input columns, try and apply that too!\n",
+ "\n",
+ "Model-centric approaches:\n",
+ "- Manage class imbalance. There are many more thick sea ice points than thin sea ice\n",
+ "  or water points. Could we modify the loss function to weight rare classes higher?\n",
+ "- Adjust the model hyperparameters: try changing the learning rate, training the model\n",
+ "  for more epochs, etc.\n",
+ "- Tweak the model architecture. The original paper by Koo et al., 2023 used a\n",
+ " [`tanh`](https://pytorch.org/docs/2.4/generated/torch.nn.Tanh.html) activation\n",
+ " function in the neural network layers. Will adding that help?\n",
+ "\n",
+ "The list above isn't exhaustive, and different machine learning practitioners may have\n",
+ "other suggestions on what to try next. That said, you now have a machine-learning-ready\n",
+ "GeoParquet dataset to iterate on ideas more quickly. Good luck!\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "709f7a1b",
+ "metadata": {},
+ "source": [
+ "## References\n",
+ "- Koo, Y., Xie, H., Kurtz, N. T., Ackley, S. F., & Wang, W. (2023).\n",
+ " Sea ice surface type classification of ICESat-2 ATL07 data by using data-driven\n",
+ " machine learning model: Ross Sea, Antarctic as an example. Remote Sensing of\n",
+ " Environment, 296, 113726. https://doi.org/10.1016/j.rse.2023.113726\n",
+ "- Petty, A. A., Bagnardi, M., Kurtz, N. T., Tilling, R., Fons, S., Armitage, T.,\n",
+ " Horvat, C., & Kwok, R. (2021). Assessment of ICESat‐2 Sea Ice Surface\n",
+ " Classification with Sentinel‐2 Imagery: Implications for Freeboard and New Estimates\n",
+ " of Lead and Floe Geometry. Earth and Space Science, 8(3), e2020EA001491.\n",
+ " https://doi.org/10.1029/2020EA001491"
+ ]
}
],
"metadata": {
diff --git a/book/tutorials/machine-learning/photon_classifier.py b/book/tutorials/machine-learning/point_cloud_classifier.py
similarity index 83%
rename from book/tutorials/machine-learning/photon_classifier.py
rename to book/tutorials/machine-learning/point_cloud_classifier.py
index 05f07bd..3d51c7d 100644
--- a/book/tutorials/machine-learning/photon_classifier.py
+++ b/book/tutorials/machine-learning/point_cloud_classifier.py
@@ -16,7 +16,7 @@
# %% [markdown]
# # Machine Learning with ICESat-2 data
#
-# A machine learning pipeline from point clouds to photon classifications.
+# A machine learning pipeline for point cloud classification.
#
# Reimplementation of [Koo et al., 2023](https://doi.org/10.1016/j.rse.2023.113726),
# based on code available at https://github.com/YoungHyunKoo/IS2_ML.
@@ -27,11 +27,10 @@
# - Convert ICESat-2 point cloud data into an analysis/ML-ready format
# - Recognize the different levels of complexity of ML approaches and the
# benefits/challenges of each
-# - Learn the potential of using ML for ICESat-2 photon classification
+# - Learn the potential of using ML for ICESat-2 point cloud classification
# ```
#
-# ![ICESat-2 ATL07 sea ice photon classification ML pipeline](https://github.com/user-attachments/assets/509dab2d-d25d-417f-97ff-fc966f656ddf)
-
+# ![ICESat-2 ATL07 sea ice point cloud classification ML pipeline](https://github.com/user-attachments/assets/d61521a4-d27a-4eb8-b886-3d92c5516c32)
# %% [markdown]
# ## Part 0: Setup
@@ -155,8 +154,6 @@
# - `granule` - ICESat-2 ATL07 sea ice point cloud data
# - `item_collection` - Sentinel-2 optical satellite images
-# %%
-
# %% [markdown]
# ### Filter to strong beams and required data variables
#
@@ -253,7 +250,7 @@
# %% [markdown]
# ### Optical imagery to label point clouds
#
-# Let's use the Sentinel-2 satellite image we found to label each ATL07 photon. We'll
+# Let's use the Sentinel-2 satellite image we found to label each ATL07 point. We'll
# make a new column called `sea_ice_label` that can have either of these classifications:
#
# 0. thick/snow-covered sea ice
@@ -396,11 +393,11 @@
# ### Save to GeoParquet
# %% [markdown]
-# Let's save the ATL07 photon data to a GeoParquet file so we don't have to run all the
-# pre-processing code above again.
+# Let's save the ATL07 point cloud data to a GeoParquet file so we don't have to run all
+# the pre-processing code above again.
# %%
-gdf.to_parquet(path="ATL07_photons.gpq", compression="zstd", schema_version="1.1.0")
+gdf.to_parquet(path="ATL07_point_cloud.gpq", compression="zstd", schema_version="1.1.0")
# %% [markdown]
# ```{note} To compress or not?
@@ -413,7 +410,7 @@
# %%
# Load GeoParquet file back into geopandas.GeoDataFrame
-gdf = gpd.read_parquet(path="ATL07_photons.gpq")
+gdf = gpd.read_parquet(path="ATL07_point_cloud.gpq")
# %%
@@ -424,7 +421,7 @@
# pipeline. We will create:
#
# 1. A 'DataLoader', which is a fancy data container we can loop over; and
-# 2. A neural network 'model' that will take our input ATL07 data and output photon
+# 2. A neural network 'model' that will take our input ATL07 data and output point cloud
# classifications.
# %% [markdown]
@@ -479,7 +476,7 @@
# %% [markdown]
# ### Choosing a Machine Learning algorithm
#
-# Next is to pick a supervised learning 'model' for our photon classification task.
+# Next, we pick a supervised learning 'model' for our point cloud classification task.
# There are a variety of machine learning methods to choose with different levels of
# complexity:
#
@@ -519,8 +516,9 @@
# - Output layer with 3 nodes, for 3 surface types (open water, thin ice,
# thick/snow-covered ice)
+
# %%
-class PhotonClassificationModel(torch.nn.Module):
+class PointCloudClassificationModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.linear1 = torch.nn.Linear(in_features=6, out_features=50)
@@ -535,7 +533,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# %%
-model = PhotonClassificationModel()
+model = PointCloudClassificationModel()
# model = model.to(device="cuda") # uncomment this line if running on GPU
model
@@ -590,9 +588,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         optimizer.step()
         optimizer.zero_grad()
-        # Report metrics
-        current = (i + 1) * len(x)
-        print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
+    # Report metrics
+    current = (i + 1) * len(x)
+    print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
# %% [markdown]
@@ -600,6 +598,78 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# decreasing, which means the error between the predicted and groundtruth value is
# getting smaller.
+# %% [markdown]
+#
+# ### Inference results
+#
+# Besides monitoring the loss value, it is also good to calculate a metric like
+# Precision, Recall or F1-score. Let's first run the model in 'inference' mode to get
+# predictions.
+
+# %%
+gdf["predicted_surface_type"] = None  # create new column filled with None to store results
+with torch.inference_mode():
+    # Note: the row indexing below assumes the dataloader yields rows in order (no shuffling)
+    for i, batch in enumerate(dataloader):
+        minibatch: torch.Tensor = batch[0]
+        x = minibatch[:, :6]
+        prediction = model(x=x)  # per-class scores, shape (batch_size, 3)
+        prediction_labels = torch.argmax(input=prediction, dim=1)  # 0/1/2 labels
+
+        start_index = i * dataloader.batch_size
+        stop_index = start_index + len(minibatch) - 1  # .loc slices are end-inclusive
+        gdf.loc[start_index:stop_index, "predicted_surface_type"] = prediction_labels
+
+# %% [markdown]
+#
+# ```{caution}
+# Ideally, you would want to run inference on a hold-out validation or test set, rather
+# than the points the model was trained on! See e.g.
+# [`sklearn.model_selection.train_test_split`](https://scikit-learn.org/1.5/modules/generated/sklearn.model_selection.train_test_split.html)
+# on how this can be done.
+# ```
+
+# %% [markdown]
+#
+# Now that we have the predicted results in the `predicted_surface_type` column, we can
+# compare them with the 'groundtruth' labels in the `surface_type` column by visualizing
+# both in a confusion matrix.
+
+# %%
+pd.crosstab(
+    index=gdf.surface_type,
+    columns=gdf.predicted_surface_type,
+    rownames=["Actual"],
+    colnames=["Predicted"],
+    margins=True,
+)
+
+# %% [markdown]
+#
+# ```{attention}
+# Oo, it looks like our model isn't producing very good results! It's practically only
+# predicting thick sea ice (class: 2). There could be many reasons for this, and it's up
+# to you to figure out a solution, either by changing the data, or adjusting the model.
+#
+# Data-centric approaches:
+# - Add more data! Maybe <10000 points isn't enough, try getting more!
+# - Check the labels! Are there wrongly labelled points? Is the Sentinel-2
+# dark/gray/bright bins above too simplistic? Investigate!
+# - Normalize the data value range. The original paper by Koo et al., 2023 applied
+# min-max normalization on the 6 input columns, try and apply that too!
+#
+# Model-centric approaches:
+# - Manage class imbalance. There are many more thick sea ice points than thin sea ice
+#   or water points. Could we modify the loss function to weight rare classes higher?
+# - Adjust the model hyperparameters: try changing the learning rate, training the model
+#   for more epochs, etc.
+# - Tweak the model architecture. The original paper by Koo et al., 2023 used a
+# [`tanh`](https://pytorch.org/docs/2.4/generated/torch.nn.Tanh.html) activation
+# function in the neural network layers. Will adding that help?
+#
+# The list above isn't exhaustive, and different machine learning practitioners may have
+# other suggestions on what to try next. That said, you now have a machine-learning-ready
+# GeoParquet dataset to iterate on ideas more quickly. Good luck!
+# ```
# %% [markdown]
# ## References
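The "Inference results" cell added in this hunk mentions Precision, Recall and F1-score, but the diff stops at the confusion matrix. As a rough sketch (not part of the tutorial), the per-class metrics can be derived from the crosstab counts shown in the notebook output. The names `CONFUSION`, `per_class_metrics` and `accuracy` are hypothetical, and the all-zero column for predicted class 0 is an assumption filled in because the crosstab omits classes that were never predicted.

```python
# Confusion matrix copied from the crosstab output: rows = actual, columns = predicted.
# Class 0 was never predicted, so its (assumed all-zero) column is added explicitly.
CONFUSION = [
    [0, 0, 725],   # actual class 0
    [0, 1, 577],   # actual class 1
    [0, 2, 8149],  # actual class 2
]


def per_class_metrics(matrix: list[list[int]]) -> dict[int, dict[str, float]]:
    """Return precision, recall and F1-score per class (0.0 where undefined)."""
    n_classes = len(matrix)
    results = {}
    for c in range(n_classes):
        true_pos = matrix[c][c]
        predicted = sum(matrix[row][c] for row in range(n_classes))  # column sum
        actual = sum(matrix[c])  # row sum
        precision = true_pos / predicted if predicted else 0.0
        recall = true_pos / actual if actual else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        results[c] = {"precision": precision, "recall": recall, "f1": f1}
    return results


def accuracy(matrix: list[list[int]]) -> float:
    """Fraction of points whose predicted class equals the actual class."""
    correct = sum(matrix[c][c] for c in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


metrics = per_class_metrics(CONFUSION)
for label, scores in metrics.items():
    print(label, {name: round(value, 4) for name, value in scores.items()})
print("accuracy:", round(accuracy(CONFUSION), 4))
```

Overall accuracy comes out at about 0.86, yet recall for classes 0 and 1 is close to zero. That gap is why the tutorial recommends looking beyond the loss value.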
@@ -607,6 +677,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# Sea ice surface type classification of ICESat-2 ATL07 data by using data-driven
# machine learning model: Ross Sea, Antarctic as an example. Remote Sensing of
# Environment, 296, 113726. https://doi.org/10.1016/j.rse.2023.113726
-
-
-# %%
+# - Petty, A. A., Bagnardi, M., Kurtz, N. T., Tilling, R., Fons, S., Armitage, T.,
+# Horvat, C., & Kwok, R. (2021). Assessment of ICESat‐2 Sea Ice Surface
+# Classification with Sentinel‐2 Imagery: Implications for Freeboard and New Estimates
+# of Lead and Floe Geometry. Earth and Space Science, 8(3), e2020EA001491.
+# https://doi.org/10.1029/2020EA001491
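One of the model-centric suggestions in the attention block is to weight rare classes more heavily in the loss function. A minimal sketch of one common heuristic, inverse-frequency weighting, is below. It uses the per-class totals from the "All" column of the confusion matrix; `CLASS_COUNTS` and `inverse_frequency_weights` are hypothetical names, and this is only one of several possible weighting schemes, not something the tutorial or Koo et al. (2023) prescribes.

```python
# Points per actual class, taken from the "All" column of the confusion matrix.
CLASS_COUNTS = {0: 725, 1: 578, 2: 8151}


def inverse_frequency_weights(counts: dict[int, int]) -> dict[int, float]:
    """weight_c = total / (n_classes * count_c); a perfectly balanced dataset gives 1.0."""
    total = sum(counts.values())
    n_classes = len(counts)
    return {label: total / (n_classes * count) for label, count in counts.items()}


weights = inverse_frequency_weights(CLASS_COUNTS)
for label, weight in weights.items():
    print(f"class {label}: weight {weight:.3f}")
```

The loss function is not visible in this diff. If the training loop uses `torch.nn.CrossEntropyLoss`, these values could be passed as a tensor through its `weight` argument so that errors on the two rare classes count for more than errors on the dominant class.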