From c857a7204b6c28fbbcdd037d52f4a382ab533b1e Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Sat, 27 Jul 2024 12:35:49 -0400 Subject: [PATCH] Update torch name --- geoai/segmentation.py | 19 +++++++++++++++++++ requirements.txt | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/geoai/segmentation.py b/geoai/segmentation.py index 44c64e1..a781a9b 100644 --- a/geoai/segmentation.py +++ b/geoai/segmentation.py @@ -328,3 +328,22 @@ def visualize_predictions( plt.tight_layout() plt.show() + + +# Example usage +if __name__ == "__main__": + images_dir = "../datasets/Water-Bodies-Dataset/Images" + masks_dir = "../datasets/Water-Bodies-Dataset/Masks" + transform = get_transform() + train_dataset, val_dataset = prepare_datasets(images_dir, masks_dir, transform) + + model_save_path = "./fine_tuned_model" + train_model(train_dataset, val_dataset, model_save_path) + + image_path = "../datasets/Water-Bodies-Dataset/Images/water_body_44.jpg" + reference_image_path = image_path.replace("Images", "Masks") + segmented_mask = segment_image(image_path, model_save_path) + + visualize_predictions( + image_path, segmented_mask, reference_image_path=reference_image_path + ) diff --git a/requirements.txt b/requirements.txt index e6935e9..9828fc8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ albumentations -pytorch scikit-learn segment-geospatial +torch transformers