From d3b705645e1af3a4b9e69a710dbc3bc4a402d4f1 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Mon, 19 Aug 2024 11:51:20 -0700 Subject: [PATCH] Update ATL07 point cloud classifier notebook (#29) * Rename photon to point cloud ATL07 points are not actually photons, but an aggregate of ATL03 points. * Show inference results in confusion matrix and closing words Demonstrate how the trained model can be used to produce classification results in a new column of the GeoDataFrame, and compare those predictions with the groundtruth in a confusion matrix plot. Added some closing words about data-centric and model-centric ways of improving the results, and added the Petty et al. 2021 paper to the citation list as credit to Alek's help. --- book/_config.yml | 2 +- book/_toc.yml | 4 +- book/tutorials/index.md | 2 +- ...ier.ipynb => point_cloud_classifier.ipynb} | 300 +++++++++++++++--- ...lassifier.py => point_cloud_classifier.py} | 114 +++++-- 5 files changed, 348 insertions(+), 74 deletions(-) rename book/tutorials/machine-learning/{photon_classifier.ipynb => point_cloud_classifier.ipynb} (99%) rename book/tutorials/machine-learning/{photon_classifier.py => point_cloud_classifier.py} (83%) diff --git a/book/_config.yml b/book/_config.yml index 251f2d4..8a794ba 100644 --- a/book/_config.yml +++ b/book/_config.yml @@ -39,7 +39,7 @@ execute: - "**/geospatial-advanced.ipynb" - "cloud-computing/04-cloud-optimized-icesat2.ipynb" - "cloud-computing/atl08_parquet_files/atl08_parquet.ipynb" - - "machine-learning/photon_classifier.ipynb" + - "machine-learning/point_cloud_classifier.ipynb" allow_errors: false # Per-cell notebook execution limit (seconds) timeout: 300 diff --git a/book/_toc.yml b/book/_toc.yml index 6f7333b..5cfa31f 100644 --- a/book/_toc.yml +++ b/book/_toc.yml @@ -39,8 +39,8 @@ parts: - file: tutorials/cloud-computing/atl08_parquet_files/atl08_parquet options: - titlesonly: true - - file: tutorials/mental-health/index - - file: 
tutorials/machine-learning/photon_classifier.ipynb + - file: tutorials/mental-health/index + - file: tutorials/machine-learning/point_cloud_classifier.ipynb - caption: Projects chapters: - file: projects/index diff --git a/book/tutorials/index.md b/book/tutorials/index.md index 44e87a6..d34dac7 100644 --- a/book/tutorials/index.md +++ b/book/tutorials/index.md @@ -11,4 +11,4 @@ Below you'll find a table keeping track of all tutorials presented at this event | [ICESat-2 Mission](./mission-overview/icesat-2-mission-overview.ipynb) | ICESat-2 Mission and Products | n/a | Not recorded | | [Cloud Computing](./cloud-computing/00-goals-and-outline.ipynb) | Cloud Computing Tutorial | n/a | Not recorded | | [Notebooks to Packages](./nb-to-package/index.md) | All about Python classes to packages | n/a | Not recorded | -| [ICESat-2 photon classification](./machine-learning/photon_classifier.ipynb) | Machine Learning, PyTorch | ATL07 | Not recorded | +| [ICESat-2 point cloud classification](./machine-learning/point_cloud_classifier.ipynb) | Machine Learning, PyTorch | ATL07 | Not recorded | diff --git a/book/tutorials/machine-learning/photon_classifier.ipynb b/book/tutorials/machine-learning/point_cloud_classifier.ipynb similarity index 99% rename from book/tutorials/machine-learning/photon_classifier.ipynb rename to book/tutorials/machine-learning/point_cloud_classifier.ipynb index 849ba1f..b60412a 100644 --- a/book/tutorials/machine-learning/photon_classifier.ipynb +++ b/book/tutorials/machine-learning/point_cloud_classifier.ipynb @@ -7,7 +7,7 @@ "source": [ "# Machine Learning with ICESat-2 data\n", "\n", - "A machine learning pipeline from point clouds to photon classifications.\n", + "A machine learning pipeline for point cloud classification.\n", "\n", "Reimplementation of [Koo et al., 2023](https://doi.org/10.1016/j.rse.2023.113726),\n", "based on code available at https://github.com/YoungHyunKoo/IS2_ML." 
@@ -16,19 +16,17 @@ { "cell_type": "markdown", "id": "554bbfcf", - "metadata": { - "lines_to_next_cell": 2 - }, + "metadata": {}, "source": [ "```{admonition} Learning Objectives\n", "By the end of this tutorial, you should be able to:\n", "- Convert ICESat-2 point cloud data into an analysis/ML-ready format\n", "- Recognize the different levels of complexity of ML approaches and the\n", " benefits/challenges of each\n", - "- Learn the potential of using ML for ICESat-2 photon classification\n", + "- Learn the potential of using ML for ICESat-2 point cloud classification\n", "```\n", "\n", - "![ICESat-2 ATL07 sea ice photon classification ML pipeline](https://github.com/user-attachments/assets/509dab2d-d25d-417f-97ff-fc966f656ddf)" + "![ICESat-2 ATL07 sea ice point cloud classification ML pipeline](https://github.com/user-attachments/assets/d61521a4-d27a-4eb8-b886-3d92c5516c32)" ] }, { @@ -356,14 +354,6 @@ "- `item_collection` - Sentinel-2 optical satellite images" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "8b6ffe42", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "id": "369c2a5c", @@ -889,7 +879,7 @@ "source": [ "### Optical imagery to label point clouds\n", "\n", - "Let's use the Sentinel-2 satellite image we found to label each ATL07 photon. We'll\n", + "Let's use the Sentinel-2 satellite image we found to label each ATL07 point. We'll\n", "make a new column called `sea_ice_label` that can have either of these classifications:\n", "\n", "0. thick/snow-covered sea ice\n", @@ -13921,8 +13911,8 @@ "id": "c60e8b12", "metadata": {}, "source": [ - "Let's save the ATL07 photon data to a GeoParquet file so we don't have to run all the\n", - "pre-processing code above again." + "Let's save the ATL07 point cloud data to a GeoParquet file so we don't have to run all\n", + "the pre-processing code above again." 
] }, { @@ -13932,7 +13922,7 @@ "metadata": {}, "outputs": [], "source": [ - "gdf.to_parquet(path=\"ATL07_photons.gpq\", compression=\"zstd\", schema_version=\"1.1.0\")" + "gdf.to_parquet(path=\"ATL07_point_cloud.gpq\", compression=\"zstd\", schema_version=\"1.1.0\")" ] }, { @@ -13957,7 +13947,7 @@ "outputs": [], "source": [ "# Load GeoParquet file back into geopandas.GeoDataFrame\n", - "gdf = gpd.read_parquet(path=\"ATL07_photons.gpq\")" + "gdf = gpd.read_parquet(path=\"ATL07_point_cloud.gpq\")" ] }, { @@ -13979,7 +13969,7 @@ "pipeline. We will create:\n", "\n", "1. A 'DataLoader', which is a fancy data container we can loop over; and\n", - "2. A neural network 'model' that will take our input ATL07 data and output photon\n", + "2. A neural network 'model' that will take our input ATL07 data and output point cloud\n", " classifications." ] }, @@ -14010,9 +14000,7 @@ "cell_type": "code", "execution_count": 25, "id": "cc9cf774", - "metadata": { - "lines_to_next_cell": 1 - }, + "metadata": {}, "outputs": [], "source": [ "# Select data variables from DataFrame that will be used for training\n", @@ -14054,11 +14042,13 @@ { "cell_type": "markdown", "id": "79141dca", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 2 + }, "source": [ "### Choosing a Machine Learning algorithm\n", "\n", - "Next is to pick a supervised learning 'model' for our photon classification task.\n", + "Next is to pick a supervised learning 'model' for our point cloud classification task.\n", "There are a variety of machine learning methods to choose with different levels of\n", "complexity:\n", "\n", @@ -14106,7 +14096,7 @@ "metadata": {}, "outputs": [], "source": [ - "class PhotonClassificationModel(torch.nn.Module):\n", + "class PointCloudClassificationModel(torch.nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.linear1 = torch.nn.Linear(in_features=6, out_features=50)\n", @@ -14131,7 +14121,7 @@ { "data": { "text/plain": [ - "PhotonClassificationModel(\n", + 
"PointCloudClassificationModel(\n", " (linear1): Linear(in_features=6, out_features=50, bias=True)\n", " (linear2): Linear(in_features=50, out_features=50, bias=True)\n", " (linear3): Linear(in_features=50, out_features=3, bias=True)\n", @@ -14144,7 +14134,7 @@ } ], "source": [ - "model = PhotonClassificationModel()\n", + "model = PointCloudClassificationModel()\n", "# model = model.to(device=\"cuda\") # uncomment this line if running on GPU\n", "model" ] @@ -14171,7 +14161,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "id": "3bfafafb", "metadata": {}, "outputs": [], @@ -14183,12 +14173,49 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 29, "id": "f7ff6874", "metadata": { "lines_to_next_cell": 2 }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 67%|██████▋ | 2/3 [00:00<00:00, 8.63it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss: 7949.509277 [ 8140/ 9454]\n", + "loss: 3678.343994 [ 8140/ 9454]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 3/3 [00:00<00:00, 8.64it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss: 1735.765503 [ 8140/ 9454]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# Main training loop\n", "max_epochs: int = 3\n", @@ -14219,17 +14246,15 @@ " optimizer.step()\n", " optimizer.zero_grad()\n", "\n", - " # Report metrics\n", - " current = (i + 1) * len(x)\n", - " print(f\"loss: {loss:>7f} [{current:>5d}/{size:>5d}]\")" + " # Report metrics\n", + " current = (i + 1) * len(x)\n", + " print(f\"loss: {loss:>7f} [{current:>5d}/{size:>5d}]\")" ] }, { "cell_type": "markdown", "id": "66358b37", - "metadata": { - "lines_to_next_cell": 2 - }, + "metadata": {}, "source": [ "Did the model learn something? 
A good sign to check is if the loss value is\n", "decreasing, which means the error between the predicted and groundtruth value is\n", @@ -14238,25 +14263,202 @@ }, { "cell_type": "markdown", - "id": "709f7a1b", - "metadata": { - "lines_to_next_cell": 2 - }, + "id": "b82c7138", + "metadata": {}, "source": [ - "## References\n", - "- Koo, Y., Xie, H., Kurtz, N. T., Ackley, S. F., & Wang, W. (2023).\n", - " Sea ice surface type classification of ICESat-2 ATL07 data by using data-driven\n", - " machine learning model: Ross Sea, Antarctic as an example. Remote Sensing of\n", - " Environment, 296, 113726. https://doi.org/10.1016/j.rse.2023.113726" + "\n", + "### Inference results\n", + "\n", + "Besides monitoring the loss value, it is also good to calculate a metric like\n", + "Precision, Recall or F1-score. Let's first run the model in 'inference' mode to get\n", + "predictions." ] }, { "cell_type": "code", - "execution_count": null, - "id": "5014ed28", + "execution_count": 30, + "id": "fd38a5fa", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "gdf[\"predicted_surface_type\"] = None # create new column with NaN to store results\n", + "with torch.inference_mode():\n", + " for i, batch in enumerate(dataloader):\n", + " minibatch: torch.Tensor = batch[0]\n", + " x = minibatch[:, :6]\n", + " prediction = model(x=x) # one-hot encoded predictions\n", + " prediction_labels = torch.argmax(input=prediction, dim=1) # 0/1/2 labels\n", + "\n", + " start_index = i * dataloader.batch_size\n", + " stop_index = start_index + len(minibatch) - 1\n", + " gdf.loc[start_index:stop_index, \"predicted_surface_type\"] = prediction_labels" + ] + }, + { + "cell_type": "markdown", + "id": "fbfa6348", + "metadata": {}, + "source": [ + "\n", + "```{caution}\n", + "Ideally, you would want to run inference on a hold-out validation or test set, rather\n", + "than the points the model was trained on! 
See e.g.\n", + "[`sklearn.model_selection.train_test_split`](https://scikit-learn.org/1.5/modules/generated/sklearn.model_selection.train_test_split.html)\n", + "on how this can be done.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "a578ef64", + "metadata": {}, + "source": [ + "\n", + "Now that we have the predicted results in the `predicted_surface_type` column, we can\n", + "compare it with the 'groundtruth' labels in the 'surface_type' column by visualizing\n", + "it in a confusion matrix." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "61bd172e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
| Predicted | 1 | 2 | All |
|---|---|---|---|
| Actual | | | |
| 0 | 0 | 725 | 725 |
| 1 | 1 | 577 | 578 |
| 2 | 2 | 8149 | 8151 |
| All | 3 | 9451 | 9454 |
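
The notebook text suggests calculating a metric like Precision, Recall or F1-score in addition to watching the loss. As a minimal sketch, not part of the patch itself, the counts in the confusion matrix above can be turned into per-class metrics with plain Python. The helper name `per_class_metrics` is hypothetical, and the 0/1/2 column ordering (with an all-zero column for class 0, which the model never predicted) is an assumption based on the row labels.

```python
# Confusion matrix counts copied from the table above.
# Rows are actual classes 0/1/2; columns are predicted classes 0/1/2.
# Assumption: class 0 was never predicted, so its column is filled with zeros.
confusion = [
    [0, 0, 725],   # actual class 0
    [0, 1, 577],   # actual class 1
    [0, 2, 8149],  # actual class 2
]


def per_class_metrics(matrix):
    """Return a list of (precision, recall, f1) tuples, one per class.

    Hypothetical helper for illustration; divisions by zero yield 0.0.
    """
    n_classes = len(matrix)
    results = []
    for k in range(n_classes):
        true_positive = matrix[k][k]
        predicted_total = sum(matrix[row][k] for row in range(n_classes))
        actual_total = sum(matrix[k])
        precision = true_positive / predicted_total if predicted_total else 0.0
        recall = true_positive / actual_total if actual_total else 0.0
        denominator = precision + recall
        f1 = 2 * precision * recall / denominator if denominator else 0.0
        results.append((precision, recall, f1))
    return results


metrics = per_class_metrics(confusion)

# Overall accuracy: correctly classified points divided by all points.
total_points = sum(sum(row) for row in confusion)
accuracy = sum(confusion[k][k] for k in range(len(confusion))) / total_points

for label, (precision, recall, f1) in enumerate(metrics):
    print(f"class {label}: precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
print(f"accuracy={accuracy:.3f}")
```

Because almost every point is assigned to class 2, accuracy alone can look reasonable while recall for classes 0 and 1 stays near zero, which is one reason to report per-class metrics alongside the loss.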