From af7b7ab082bb87e1920866ea3ae3735a2d801678 Mon Sep 17 00:00:00 2001
From: Robin Cole
Date: Tue, 28 May 2024 08:06:17 +0000
Subject: [PATCH 01/10] Update VHR10 dataset visualization method

---
 torchgeo/datasets/vhr10.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py
index 30ea07c908c..5327be2e279 100644
--- a/torchgeo/datasets/vhr10.py
+++ b/torchgeo/datasets/vhr10.py
@@ -22,6 +22,7 @@
     download_and_extract_archive,
     download_url,
     lazy_import,
+    percentile_normalization
 )


@@ -386,10 +387,13 @@ def plot(
         .. versionadded:: 0.4
         """
         assert show_feats in {'boxes', 'masks', 'both'}
+        image = percentile_normalization(sample['image'].permute(1, 2, 0).numpy())
+        boxes = sample['boxes'].cpu().numpy()
+        labels = sample['labels'].cpu().numpy()

         if self.split == 'negative':
             fig, axs = plt.subplots(squeeze=False)
-            axs[0, 0].imshow(sample['image'].permute(1, 2, 0))
+            axs[0, 0].imshow(image)
             axs[0, 0].axis('off')

             if suptitle is not None:
@@ -399,9 +403,6 @@ def plot(

         if show_feats != 'boxes':
             skimage = lazy_import('skimage')

-        image = sample['image'].permute(1, 2, 0).numpy()
-        boxes = sample['boxes'].cpu().numpy()
-        labels = sample['labels'].cpu().numpy()
         if 'masks' in sample:
             masks = [mask.squeeze().cpu().numpy() for mask in sample['masks']]

From dee5199ee3272eec24175c8030de9b6b8093adfb Mon Sep 17 00:00:00 2001
From: Robin Cole
Date: Tue, 28 May 2024 08:08:44 +0000
Subject: [PATCH 02/10] Update VHR10 info in readme

---
 README.md | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 38883a212d7..6e213577b86 100644
--- a/README.md
+++ b/README.md
@@ -120,22 +120,15 @@ TorchGeo includes a number of [_benchmark datasets_](https://torchgeo.readthedoc
 If you've used [torchvision](https://pytorch.org/vision) before, these datasets should seem very familiar. In this example, we'll create a dataset for the Northwestern Polytechnical University (NWPU) very-high-resolution ten-class ([VHR-10](https://github.com/chaozhong2010/VHR-10_dataset_coco)) geospatial object detection dataset. This dataset can be automatically downloaded, checksummed, and extracted, just like with torchvision.

 ```python
-from torch.utils.data import DataLoader
-
-from torchgeo.datamodules.utils import collate_fn_detection
 from torchgeo.datasets import VHR10
+from torchgeo.datamodules import VHR10DataModule

 # Initialize the dataset
 dataset = VHR10(root="...", download=True, checksum=True)

 # Initialize the dataloader with the custom collate function
-dataloader = DataLoader(
-    dataset,
-    batch_size=128,
-    shuffle=True,
-    num_workers=4,
-    collate_fn=collate_fn_detection,
-)
+datamodule = VHR10DataModule(root="data", batch_size=32, num_workers=16)
+datamodule.setup("fit")

 # Training loop
 for batch in dataloader:

From 087223c66957ba072ddd637b10b2c2c6f351acd8 Mon Sep 17 00:00:00 2001
From: Robin Cole
Date: Tue, 28 May 2024 08:24:51 +0000
Subject: [PATCH 03/10] isort

---
 torchgeo/datasets/vhr10.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py
index 5327be2e279..fe674fcaf9d 100644
--- a/torchgeo/datasets/vhr10.py
+++ b/torchgeo/datasets/vhr10.py
@@ -17,13 +17,8 @@

 from .errors import DatasetNotFoundError
 from .geo import NonGeoDataset
-from .utils import (
-    check_integrity,
-    download_and_extract_archive,
-    download_url,
-    lazy_import,
-    percentile_normalization
-)
+from .utils import (check_integrity, download_and_extract_archive,
+                    download_url, lazy_import, percentile_normalization)


 def convert_coco_poly_to_mask(

From 5a9776b84acfb17dd90469e1b725abbc7c445ec3 Mon Sep 17 00:00:00 2001
From: Robin Cole
Date: Tue, 28 May 2024 13:59:16 +0000
Subject: [PATCH 04/10] Missing annotations handling in plotting

---
 torchgeo/datasets/vhr10.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py
index fe674fcaf9d..fc9b4845bb6 100644
--- a/torchgeo/datasets/vhr10.py
+++ b/torchgeo/datasets/vhr10.py
@@ -247,7 +247,12 @@ def __getitem__(self, index: int) -> dict[str, Any]:
             sample['labels'] = sample['label']['labels']
             sample['boxes'] = sample['label']['boxes']
             sample['masks'] = sample['label']['masks']
-        del sample['label']
+        else:
+            # Ensure the keys are always present even if there are no annotations
+            sample['labels'] = torch.empty((0,), dtype=torch.int64)
+            sample['boxes'] = torch.empty((0, 4), dtype=torch.float32)
+
+        del sample['label']

         if self.transforms is not None:
             sample = self.transforms(sample)

From cab99c6b9e06e5f73420a48ba8a32bcaaa8f68d2 Mon Sep 17 00:00:00 2001
From: Robin Cole
Date: Tue, 28 May 2024 15:38:22 +0000
Subject: [PATCH 05/10] Update VHR10 dataset handling and visualization

---
 torchgeo/datasets/vhr10.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py
index fc9b4845bb6..7b3f57cfc5d 100644
--- a/torchgeo/datasets/vhr10.py
+++ b/torchgeo/datasets/vhr10.py
@@ -246,13 +246,8 @@ def __getitem__(self, index: int) -> dict[str, Any]:
             sample = self.coco_convert(sample)
             sample['labels'] = sample['label']['labels']
             sample['boxes'] = sample['label']['boxes']
-            sample['masks'] = sample['label']['masks']
-        else:
-            # Ensure the keys are always present even if there are no annotations
-            sample['labels'] = torch.empty((0,), dtype=torch.int64)
-            sample['boxes'] = torch.empty((0, 4), dtype=torch.float32)
-
-        del sample['label']
+            sample['masks'] = sample['label']['masks'] 
+        del sample['label']

         if self.transforms is not None:
             sample = self.transforms(sample)
@@ -388,8 +383,6 @@ def plot(
         """
         assert show_feats in {'boxes', 'masks', 'both'}
         image = percentile_normalization(sample['image'].permute(1, 2, 0).numpy())
-        boxes = sample['boxes'].cpu().numpy()
-        labels = sample['labels'].cpu().numpy()

         if self.split == 'negative':
             fig, axs = plt.subplots(squeeze=False)
@@ -403,6 +396,9 @@ def plot(

         if show_feats != 'boxes':
             skimage = lazy_import('skimage')

+        boxes = sample['boxes'].cpu().numpy()
+        labels = sample['labels'].cpu().numpy()
+
         if 'masks' in sample:
             masks = [mask.squeeze().cpu().numpy() for mask in sample['masks']]

From a0e17f0e3ce05d397c2dd9ed0f5db85cbca890ca Mon Sep 17 00:00:00 2001
From: Robin Cole
Date: Tue, 28 May 2024 15:39:58 +0000
Subject: [PATCH 06/10] remove whitespace

---
 torchgeo/datasets/vhr10.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py
index 7b3f57cfc5d..bf3ba262a77 100644
--- a/torchgeo/datasets/vhr10.py
+++ b/torchgeo/datasets/vhr10.py
@@ -246,7 +246,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
             sample = self.coco_convert(sample)
             sample['labels'] = sample['label']['labels']
             sample['boxes'] = sample['label']['boxes']
-            sample['masks'] = sample['label']['masks'] 
+        sample['masks'] = sample['label']['masks']
         del sample['label']

         if self.transforms is not None:

From 1baac805aa8a3540292fcaa2191fb6a6d33b5ac1 Mon Sep 17 00:00:00 2001
From: Robin Cole
Date: Tue, 28 May 2024 15:40:22 +0000
Subject: [PATCH 07/10] Fix whitespace issue in VHR10 dataset handling

---
 torchgeo/datasets/vhr10.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py
index bf3ba262a77..41f85f851f1 100644
--- a/torchgeo/datasets/vhr10.py
+++ b/torchgeo/datasets/vhr10.py
@@ -246,7 +246,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
             sample = self.coco_convert(sample)
             sample['labels'] = sample['label']['labels']
             sample['boxes'] = sample['label']['boxes']
-        sample['masks'] = sample['label']['masks']
+            sample['masks'] = sample['label']['masks']
         del sample['label']

         if self.transforms is not None:

From 44cb52ac31f8945cc0506a255656d0238151bcce Mon Sep 17 00:00:00 2001
From: Robin Cole
Date: Tue, 28 May 2024 16:18:39 +0000
Subject: [PATCH 08/10] black format

---
 torchgeo/datasets/vhr10.py | 209 +++++++++++++++++++------------------
 1 file changed, 107 insertions(+), 102 deletions(-)

diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py
index 41f85f851f1..8c090621edf 100644
--- a/torchgeo/datasets/vhr10.py
+++ b/torchgeo/datasets/vhr10.py
@@ -17,8 +17,13 @@

 from .errors import DatasetNotFoundError
 from .geo import NonGeoDataset
-from .utils import (check_integrity, download_and_extract_archive,
-                    download_url, lazy_import, percentile_normalization)
+from .utils import (
+    check_integrity,
+    download_and_extract_archive,
+    download_url,
+    lazy_import,
+    percentile_normalization,
+)


 def convert_coco_poly_to_mask(
@@ -42,7 +47,7 @@ def convert_coco_poly_to_mask(
     Raises:
         DependencyNotFoundError: If pycocotools is not installed.
""" - pycocotools = lazy_import('pycocotools') + pycocotools = lazy_import("pycocotools") masks = [] for polygons in segmentations: rles = pycocotools.mask.frPyObjects(polygons, height, width) @@ -65,28 +70,28 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: Returns: Processed sample """ - image = sample['image'] + image = sample["image"] _, h, w = image.size() - target = sample['label'] + target = sample["label"] - image_id = target['image_id'] + image_id = target["image_id"] image_id = torch.tensor([image_id]) - anno = target['annotations'] + anno = target["annotations"] - anno = [obj for obj in anno if obj['iscrowd'] == 0] + anno = [obj for obj in anno if obj["iscrowd"] == 0] - bboxes = [obj['bbox'] for obj in anno] + bboxes = [obj["bbox"] for obj in anno] # guard against no boxes via resizing boxes = torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4) boxes[:, 2:] += boxes[:, :2] boxes[:, 0::2].clamp_(min=0, max=w) boxes[:, 1::2].clamp_(min=0, max=h) - categories = [obj['category_id'] for obj in anno] + categories = [obj["category_id"] for obj in anno] classes = torch.tensor(categories, dtype=torch.int64) - segmentations = [obj['segmentation'] for obj in anno] + segmentations = [obj["segmentation"] for obj in anno] masks = convert_coco_poly_to_mask(segmentations, h, w) @@ -94,17 +99,17 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: boxes = boxes[keep] classes = classes[keep] - target = {'boxes': boxes, 'labels': classes, 'image_id': image_id} + target = {"boxes": boxes, "labels": classes, "image_id": image_id} if masks.nelement() > 0: masks = masks[keep] - target['masks'] = masks + target["masks"] = masks # for conversion to coco api - area = torch.tensor([obj['area'] for obj in anno]) - iscrowd = torch.tensor([obj['iscrowd'] for obj in anno]) - target['area'] = area - target['iscrowd'] = iscrowd - return {'image': image, 'label': target} + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) + target["area"] = area + target["iscrowd"] = iscrowd + return {"image": image, "label": target} class VHR10(NonGeoDataset): @@ -155,34 +160,34 @@ class VHR10(NonGeoDataset): """ image_meta = { - 'url': 'https://drive.google.com/file/d/1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE', - 'filename': 'NWPU VHR-10 dataset.rar', - 'md5': 'd30a7ff99d92123ebb0b3a14d9102081', + "url": "https://drive.google.com/file/d/1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE", + "filename": "NWPU VHR-10 dataset.rar", + "md5": "d30a7ff99d92123ebb0b3a14d9102081", } target_meta = { - 'url': 'https://raw.githubusercontent.com/chaozhong2010/VHR-10_dataset_coco/ce0ba0f5f6a0737031f1cbe05e785ddd5ef05bd7/NWPU%20VHR-10_dataset_coco/annotations.json', # noqa: E501 - 'filename': 'annotations.json', - 'md5': '7c76ec50c17a61bb0514050d20f22c08', + "url": "https://raw.githubusercontent.com/chaozhong2010/VHR-10_dataset_coco/ce0ba0f5f6a0737031f1cbe05e785ddd5ef05bd7/NWPU%20VHR-10_dataset_coco/annotations.json", # noqa: E501 + "filename": "annotations.json", + "md5": "7c76ec50c17a61bb0514050d20f22c08", } categories = [ - 'background', - 'airplane', - 'ships', - 'storage tank', - 'baseball diamond', - 'tennis court', - 'basketball court', - 'ground track field', - 'harbor', - 'bridge', - 'vehicle', + "background", + "airplane", + "ships", + "storage tank", + "baseball diamond", + "tennis court", + "basketball court", + "ground track field", + "harbor", + "bridge", + "vehicle", ] def __init__( self, - root: str = 'data', - split: str = 'positive', + root: str = 
"data", + split: str = "positive", transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, checksum: bool = False, @@ -203,7 +208,7 @@ def __init__( DependencyNotFoundError: if ``split="positive"`` and pycocotools is not installed. """ - assert split in ['positive', 'negative'] + assert split in ["positive", "negative"] self.root = root self.split = split @@ -216,11 +221,11 @@ def __init__( if not self._check_integrity(): raise DatasetNotFoundError(self) - if split == 'positive': - pc = lazy_import('pycocotools.coco') + if split == "positive": + pc = lazy_import("pycocotools.coco") self.coco = pc.COCO( os.path.join( - self.root, 'NWPU VHR-10 dataset', self.target_meta['filename'] + self.root, "NWPU VHR-10 dataset", self.target_meta["filename"] ) ) self.coco_convert = ConvertCocoAnnotations() @@ -238,16 +243,16 @@ def __getitem__(self, index: int) -> dict[str, Any]: id_ = index % len(self) + 1 sample: dict[str, Any] = { - 'image': self._load_image(id_), - 'label': self._load_target(id_), + "image": self._load_image(id_), + "label": self._load_target(id_), } - if sample['label']['annotations']: + if sample["label"]["annotations"]: sample = self.coco_convert(sample) - sample['labels'] = sample['label']['labels'] - sample['boxes'] = sample['label']['boxes'] - sample['masks'] = sample['label']['masks'] - del sample['label'] + sample["labels"] = sample["label"]["labels"] + sample["boxes"] = sample["label"]["boxes"] + sample["masks"] = sample["label"]["masks"] + del sample["label"] if self.transforms is not None: sample = self.transforms(sample) @@ -260,7 +265,7 @@ def __len__(self) -> int: Returns: length of the dataset """ - if self.split == 'positive': + if self.split == "positive": return len(self.ids) else: return 150 @@ -276,9 +281,9 @@ def _load_image(self, id_: int) -> Tensor: """ filename = os.path.join( self.root, - 'NWPU VHR-10 dataset', - self.split + ' image set', - f'{id_:03d}.jpg', + "NWPU VHR-10 dataset", + self.split + " image set", + f"{id_:03d}.jpg", ) with Image.open(filename) as img: array: np.typing.NDArray[np.int_] = np.array(img) @@ -299,7 +304,7 @@ def _load_target(self, id_: int) -> dict[str, Any]: """ # Images in the "negative" image set have no annotations annot = [] - if self.split == 'positive': + if self.split == "positive": annot = self.coco.loadAnns(self.coco.getAnnIds(id_ - 1)) target = dict(image_id=id_, annotations=annot) @@ -313,18 +318,18 @@ def _check_integrity(self) -> bool: True if dataset files are found and/or MD5s match, else False """ image: bool = check_integrity( - os.path.join(self.root, self.image_meta['filename']), - self.image_meta['md5'] if self.checksum else None, + os.path.join(self.root, self.image_meta["filename"]), + self.image_meta["md5"] if self.checksum else None, ) # Annotations only needed for "positive" image set target = True - if self.split == 'positive': + if self.split == "positive": target = check_integrity( os.path.join( - self.root, 'NWPU VHR-10 dataset', self.target_meta['filename'] + self.root, "NWPU VHR-10 dataset", self.target_meta["filename"] ), - self.target_meta['md5'] if self.checksum else None, + self.target_meta["md5"] if self.checksum else None, ) return image and target @@ -332,25 +337,25 @@ def _check_integrity(self) -> bool: def _download(self) -> None: """Download the dataset and extract it.""" if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return # Download images download_and_extract_archive( - 
self.image_meta['url'], + self.image_meta["url"], self.root, - filename=self.image_meta['filename'], - md5=self.image_meta['md5'] if self.checksum else None, + filename=self.image_meta["filename"], + md5=self.image_meta["md5"] if self.checksum else None, ) # Annotations only needed for "positive" image set - if self.split == 'positive': + if self.split == "positive": # Download annotations download_url( - self.target_meta['url'], - os.path.join(self.root, 'NWPU VHR-10 dataset'), - self.target_meta['filename'], - self.target_meta['md5'] if self.checksum else None, + self.target_meta["url"], + os.path.join(self.root, "NWPU VHR-10 dataset"), + self.target_meta["filename"], + self.target_meta["md5"] if self.checksum else None, ) def plot( @@ -358,7 +363,7 @@ def plot( sample: dict[str, Tensor], show_titles: bool = True, suptitle: str | None = None, - show_feats: str | None = 'both', + show_feats: str | None = "both", box_alpha: float = 0.7, mask_alpha: float = 0.7, ) -> Figure: @@ -381,42 +386,42 @@ def plot( .. versionadded:: 0.4 """ - assert show_feats in {'boxes', 'masks', 'both'} - image = percentile_normalization(sample['image'].permute(1, 2, 0).numpy()) + assert show_feats in {"boxes", "masks", "both"} + image = percentile_normalization(sample["image"].permute(1, 2, 0).numpy()) - if self.split == 'negative': + if self.split == "negative": fig, axs = plt.subplots(squeeze=False) axs[0, 0].imshow(image) - axs[0, 0].axis('off') + axs[0, 0].axis("off") if suptitle is not None: plt.suptitle(suptitle) return fig - if show_feats != 'boxes': - skimage = lazy_import('skimage') + if show_feats != "boxes": + skimage = lazy_import("skimage") - boxes = sample['boxes'].cpu().numpy() - labels = sample['labels'].cpu().numpy() + boxes = sample["boxes"].cpu().numpy() + labels = sample["labels"].cpu().numpy() - if 'masks' in sample: - masks = [mask.squeeze().cpu().numpy() for mask in sample['masks']] + if "masks" in sample: + masks = [mask.squeeze().cpu().numpy() for mask in sample["masks"]] n_gt = len(boxes) ncols = 1 - show_predictions = 'prediction_labels' in sample + show_predictions = "prediction_labels" in sample if show_predictions: show_pred_boxes = False show_pred_masks = False - prediction_labels = sample['prediction_labels'].numpy() - prediction_scores = sample['prediction_scores'].numpy() - if 'prediction_boxes' in sample: - prediction_boxes = sample['prediction_boxes'].numpy() + prediction_labels = sample["prediction_labels"].numpy() + prediction_scores = sample["prediction_scores"].numpy() + if "prediction_boxes" in sample: + prediction_boxes = sample["prediction_boxes"].numpy() show_pred_boxes = True - if 'prediction_masks' in sample: - prediction_masks = sample['prediction_masks'].numpy() + if "prediction_masks" in sample: + prediction_masks = sample["prediction_masks"].numpy() show_pred_masks = True n_pred = len(prediction_labels) @@ -425,25 +430,25 @@ def plot( # Display image fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 13)) axs[0, 0].imshow(image) - axs[0, 0].axis('off') + axs[0, 0].axis("off") - cm = plt.get_cmap('gist_rainbow') + cm = plt.get_cmap("gist_rainbow") for i in range(n_gt): class_num = labels[i] color = cm(class_num / len(self.categories)) # Add bounding boxes x1, y1, x2, y2 = boxes[i] - if show_feats in {'boxes', 'both'}: + if show_feats in {"boxes", "both"}: r = patches.Rectangle( (x1, y1), x2 - x1, y2 - y1, linewidth=2, alpha=box_alpha, - linestyle='dashed', + linestyle="dashed", edgecolor=color, - facecolor='none', + facecolor="none", ) axs[0, 
0].add_patch(r) @@ -451,26 +456,26 @@ def plot( label = self.categories[class_num] caption = label axs[0, 0].text( - x1, y1 - 8, caption, color='white', size=11, backgroundcolor='none' + x1, y1 - 8, caption, color="white", size=11, backgroundcolor="none" ) # Add masks - if show_feats in {'masks', 'both'} and 'masks' in sample: + if show_feats in {"masks", "both"} and "masks" in sample: mask = masks[i] contours = skimage.measure.find_contours(mask, 0.5) for verts in contours: verts = np.fliplr(verts) p = patches.Polygon( - verts, facecolor=color, alpha=mask_alpha, edgecolor='white' + verts, facecolor=color, alpha=mask_alpha, edgecolor="white" ) axs[0, 0].add_patch(p) if show_titles: - axs[0, 0].set_title('Ground Truth') + axs[0, 0].set_title("Ground Truth") if show_predictions: axs[0, 1].imshow(image) - axs[0, 1].axis('off') + axs[0, 1].axis("off") for i in range(n_pred): score = prediction_scores[i] if score < 0.5: @@ -488,22 +493,22 @@ def plot( y2 - y1, linewidth=2, alpha=box_alpha, - linestyle='dashed', + linestyle="dashed", edgecolor=color, - facecolor='none', + facecolor="none", ) axs[0, 1].add_patch(r) # Add labels label = self.categories[class_num] - caption = f'{label} {score:.3f}' + caption = f"{label} {score:.3f}" axs[0, 1].text( x1, y1 - 8, caption, - color='white', + color="white", size=11, - backgroundcolor='none', + backgroundcolor="none", ) # Add masks @@ -513,12 +518,12 @@ def plot( for verts in contours: verts = np.fliplr(verts) p = patches.Polygon( - verts, facecolor=color, alpha=mask_alpha, edgecolor='white' + verts, facecolor=color, alpha=mask_alpha, edgecolor="white" ) axs[0, 1].add_patch(p) if show_titles: - axs[0, 1].set_title('Prediction') + axs[0, 1].set_title("Prediction") if suptitle is not None: plt.suptitle(suptitle) From 8cff67879f3e30fa694a55cf39dab0ff91ff0379 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 28 May 2024 16:57:02 +0000 Subject: [PATCH 09/10] ruff format --- torchgeo/datasets/vhr10.py | 200 ++++++++++++++++++------------------- 1 file changed, 100 insertions(+), 100 deletions(-) diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index 8c090621edf..b1aae5d2a30 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -42,7 +42,7 @@ def convert_coco_poly_to_mask( Raises: DependencyNotFoundError: If pycocotools is not installed. 
""" - pycocotools = lazy_import("pycocotools") + pycocotools = lazy_import('pycocotools') masks = [] for polygons in segmentations: rles = pycocotools.mask.frPyObjects(polygons, height, width) @@ -70,28 +70,28 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: Returns: Processed sample """ - image = sample["image"] + image = sample['image'] _, h, w = image.size() - target = sample["label"] + target = sample['label'] - image_id = target["image_id"] + image_id = target['image_id'] image_id = torch.tensor([image_id]) - anno = target["annotations"] + anno = target['annotations'] - anno = [obj for obj in anno if obj["iscrowd"] == 0] + anno = [obj for obj in anno if obj['iscrowd'] == 0] - bboxes = [obj["bbox"] for obj in anno] + bboxes = [obj['bbox'] for obj in anno] # guard against no boxes via resizing boxes = torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4) boxes[:, 2:] += boxes[:, :2] boxes[:, 0::2].clamp_(min=0, max=w) boxes[:, 1::2].clamp_(min=0, max=h) - categories = [obj["category_id"] for obj in anno] + categories = [obj['category_id'] for obj in anno] classes = torch.tensor(categories, dtype=torch.int64) - segmentations = [obj["segmentation"] for obj in anno] + segmentations = [obj['segmentation'] for obj in anno] masks = convert_coco_poly_to_mask(segmentations, h, w) @@ -99,17 +99,17 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: boxes = boxes[keep] classes = classes[keep] - target = {"boxes": boxes, "labels": classes, "image_id": image_id} + target = {'boxes': boxes, 'labels': classes, 'image_id': image_id} if masks.nelement() > 0: masks = masks[keep] - target["masks"] = masks + target['masks'] = masks # for conversion to coco api - area = torch.tensor([obj["area"] for obj in anno]) - iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) - target["area"] = area - target["iscrowd"] = iscrowd - return {"image": image, "label": target} + area = torch.tensor([obj['area'] for obj in anno]) + iscrowd = torch.tensor([obj['iscrowd'] for obj in anno]) + target['area'] = area + target['iscrowd'] = iscrowd + return {'image': image, 'label': target} class VHR10(NonGeoDataset): @@ -160,34 +160,34 @@ class VHR10(NonGeoDataset): """ image_meta = { - "url": "https://drive.google.com/file/d/1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE", - "filename": "NWPU VHR-10 dataset.rar", - "md5": "d30a7ff99d92123ebb0b3a14d9102081", + 'url': 'https://drive.google.com/file/d/1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE', + 'filename': 'NWPU VHR-10 dataset.rar', + 'md5': 'd30a7ff99d92123ebb0b3a14d9102081', } target_meta = { - "url": "https://raw.githubusercontent.com/chaozhong2010/VHR-10_dataset_coco/ce0ba0f5f6a0737031f1cbe05e785ddd5ef05bd7/NWPU%20VHR-10_dataset_coco/annotations.json", # noqa: E501 - "filename": "annotations.json", - "md5": "7c76ec50c17a61bb0514050d20f22c08", + 'url': 'https://raw.githubusercontent.com/chaozhong2010/VHR-10_dataset_coco/ce0ba0f5f6a0737031f1cbe05e785ddd5ef05bd7/NWPU%20VHR-10_dataset_coco/annotations.json', # noqa: E501 + 'filename': 'annotations.json', + 'md5': '7c76ec50c17a61bb0514050d20f22c08', } categories = [ - "background", - "airplane", - "ships", - "storage tank", - "baseball diamond", - "tennis court", - "basketball court", - "ground track field", - "harbor", - "bridge", - "vehicle", + 'background', + 'airplane', + 'ships', + 'storage tank', + 'baseball diamond', + 'tennis court', + 'basketball court', + 'ground track field', + 'harbor', + 'bridge', + 'vehicle', ] def __init__( self, - root: str = "data", - split: str = "positive", + root: str = 
'data', + split: str = 'positive', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, checksum: bool = False, @@ -208,7 +208,7 @@ def __init__( DependencyNotFoundError: if ``split="positive"`` and pycocotools is not installed. """ - assert split in ["positive", "negative"] + assert split in ['positive', 'negative'] self.root = root self.split = split @@ -221,11 +221,11 @@ def __init__( if not self._check_integrity(): raise DatasetNotFoundError(self) - if split == "positive": - pc = lazy_import("pycocotools.coco") + if split == 'positive': + pc = lazy_import('pycocotools.coco') self.coco = pc.COCO( os.path.join( - self.root, "NWPU VHR-10 dataset", self.target_meta["filename"] + self.root, 'NWPU VHR-10 dataset', self.target_meta['filename'] ) ) self.coco_convert = ConvertCocoAnnotations() @@ -243,16 +243,16 @@ def __getitem__(self, index: int) -> dict[str, Any]: id_ = index % len(self) + 1 sample: dict[str, Any] = { - "image": self._load_image(id_), - "label": self._load_target(id_), + 'image': self._load_image(id_), + 'label': self._load_target(id_), } - if sample["label"]["annotations"]: + if sample['label']['annotations']: sample = self.coco_convert(sample) - sample["labels"] = sample["label"]["labels"] - sample["boxes"] = sample["label"]["boxes"] - sample["masks"] = sample["label"]["masks"] - del sample["label"] + sample['labels'] = sample['label']['labels'] + sample['boxes'] = sample['label']['boxes'] + sample['masks'] = sample['label']['masks'] + del sample['label'] if self.transforms is not None: sample = self.transforms(sample) @@ -265,7 +265,7 @@ def __len__(self) -> int: Returns: length of the dataset """ - if self.split == "positive": + if self.split == 'positive': return len(self.ids) else: return 150 @@ -281,9 +281,9 @@ def _load_image(self, id_: int) -> Tensor: """ filename = os.path.join( self.root, - "NWPU VHR-10 dataset", - self.split + " image set", - f"{id_:03d}.jpg", + 'NWPU VHR-10 dataset', + self.split + ' image set', + f'{id_:03d}.jpg', ) with Image.open(filename) as img: array: np.typing.NDArray[np.int_] = np.array(img) @@ -304,7 +304,7 @@ def _load_target(self, id_: int) -> dict[str, Any]: """ # Images in the "negative" image set have no annotations annot = [] - if self.split == "positive": + if self.split == 'positive': annot = self.coco.loadAnns(self.coco.getAnnIds(id_ - 1)) target = dict(image_id=id_, annotations=annot) @@ -318,18 +318,18 @@ def _check_integrity(self) -> bool: True if dataset files are found and/or MD5s match, else False """ image: bool = check_integrity( - os.path.join(self.root, self.image_meta["filename"]), - self.image_meta["md5"] if self.checksum else None, + os.path.join(self.root, self.image_meta['filename']), + self.image_meta['md5'] if self.checksum else None, ) # Annotations only needed for "positive" image set target = True - if self.split == "positive": + if self.split == 'positive': target = check_integrity( os.path.join( - self.root, "NWPU VHR-10 dataset", self.target_meta["filename"] + self.root, 'NWPU VHR-10 dataset', self.target_meta['filename'] ), - self.target_meta["md5"] if self.checksum else None, + self.target_meta['md5'] if self.checksum else None, ) return image and target @@ -337,25 +337,25 @@ def _check_integrity(self) -> bool: def _download(self) -> None: """Download the dataset and extract it.""" if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return # Download images download_and_extract_archive( - 
self.image_meta["url"], + self.image_meta['url'], self.root, - filename=self.image_meta["filename"], - md5=self.image_meta["md5"] if self.checksum else None, + filename=self.image_meta['filename'], + md5=self.image_meta['md5'] if self.checksum else None, ) # Annotations only needed for "positive" image set - if self.split == "positive": + if self.split == 'positive': # Download annotations download_url( - self.target_meta["url"], - os.path.join(self.root, "NWPU VHR-10 dataset"), - self.target_meta["filename"], - self.target_meta["md5"] if self.checksum else None, + self.target_meta['url'], + os.path.join(self.root, 'NWPU VHR-10 dataset'), + self.target_meta['filename'], + self.target_meta['md5'] if self.checksum else None, ) def plot( @@ -363,7 +363,7 @@ def plot( sample: dict[str, Tensor], show_titles: bool = True, suptitle: str | None = None, - show_feats: str | None = "both", + show_feats: str | None = 'both', box_alpha: float = 0.7, mask_alpha: float = 0.7, ) -> Figure: @@ -386,42 +386,42 @@ def plot( .. versionadded:: 0.4 """ - assert show_feats in {"boxes", "masks", "both"} - image = percentile_normalization(sample["image"].permute(1, 2, 0).numpy()) + assert show_feats in {'boxes', 'masks', 'both'} + image = percentile_normalization(sample['image'].permute(1, 2, 0).numpy()) - if self.split == "negative": + if self.split == 'negative': fig, axs = plt.subplots(squeeze=False) axs[0, 0].imshow(image) - axs[0, 0].axis("off") + axs[0, 0].axis('off') if suptitle is not None: plt.suptitle(suptitle) return fig - if show_feats != "boxes": - skimage = lazy_import("skimage") + if show_feats != 'boxes': + skimage = lazy_import('skimage') - boxes = sample["boxes"].cpu().numpy() - labels = sample["labels"].cpu().numpy() + boxes = sample['boxes'].cpu().numpy() + labels = sample['labels'].cpu().numpy() - if "masks" in sample: - masks = [mask.squeeze().cpu().numpy() for mask in sample["masks"]] + if 'masks' in sample: + masks = [mask.squeeze().cpu().numpy() for mask in sample['masks']] n_gt = len(boxes) ncols = 1 - show_predictions = "prediction_labels" in sample + show_predictions = 'prediction_labels' in sample if show_predictions: show_pred_boxes = False show_pred_masks = False - prediction_labels = sample["prediction_labels"].numpy() - prediction_scores = sample["prediction_scores"].numpy() - if "prediction_boxes" in sample: - prediction_boxes = sample["prediction_boxes"].numpy() + prediction_labels = sample['prediction_labels'].numpy() + prediction_scores = sample['prediction_scores'].numpy() + if 'prediction_boxes' in sample: + prediction_boxes = sample['prediction_boxes'].numpy() show_pred_boxes = True - if "prediction_masks" in sample: - prediction_masks = sample["prediction_masks"].numpy() + if 'prediction_masks' in sample: + prediction_masks = sample['prediction_masks'].numpy() show_pred_masks = True n_pred = len(prediction_labels) @@ -430,25 +430,25 @@ def plot( # Display image fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 13)) axs[0, 0].imshow(image) - axs[0, 0].axis("off") + axs[0, 0].axis('off') - cm = plt.get_cmap("gist_rainbow") + cm = plt.get_cmap('gist_rainbow') for i in range(n_gt): class_num = labels[i] color = cm(class_num / len(self.categories)) # Add bounding boxes x1, y1, x2, y2 = boxes[i] - if show_feats in {"boxes", "both"}: + if show_feats in {'boxes', 'both'}: r = patches.Rectangle( (x1, y1), x2 - x1, y2 - y1, linewidth=2, alpha=box_alpha, - linestyle="dashed", + linestyle='dashed', edgecolor=color, - facecolor="none", + facecolor='none', ) axs[0, 
0].add_patch(r) @@ -456,26 +456,26 @@ def plot( label = self.categories[class_num] caption = label axs[0, 0].text( - x1, y1 - 8, caption, color="white", size=11, backgroundcolor="none" + x1, y1 - 8, caption, color='white', size=11, backgroundcolor='none' ) # Add masks - if show_feats in {"masks", "both"} and "masks" in sample: + if show_feats in {'masks', 'both'} and 'masks' in sample: mask = masks[i] contours = skimage.measure.find_contours(mask, 0.5) for verts in contours: verts = np.fliplr(verts) p = patches.Polygon( - verts, facecolor=color, alpha=mask_alpha, edgecolor="white" + verts, facecolor=color, alpha=mask_alpha, edgecolor='white' ) axs[0, 0].add_patch(p) if show_titles: - axs[0, 0].set_title("Ground Truth") + axs[0, 0].set_title('Ground Truth') if show_predictions: axs[0, 1].imshow(image) - axs[0, 1].axis("off") + axs[0, 1].axis('off') for i in range(n_pred): score = prediction_scores[i] if score < 0.5: @@ -493,22 +493,22 @@ def plot( y2 - y1, linewidth=2, alpha=box_alpha, - linestyle="dashed", + linestyle='dashed', edgecolor=color, - facecolor="none", + facecolor='none', ) axs[0, 1].add_patch(r) # Add labels label = self.categories[class_num] - caption = f"{label} {score:.3f}" + caption = f'{label} {score:.3f}' axs[0, 1].text( x1, y1 - 8, caption, - color="white", + color='white', size=11, - backgroundcolor="none", + backgroundcolor='none', ) # Add masks @@ -518,12 +518,12 @@ def plot( for verts in contours: verts = np.fliplr(verts) p = patches.Polygon( - verts, facecolor=color, alpha=mask_alpha, edgecolor="white" + verts, facecolor=color, alpha=mask_alpha, edgecolor='white' ) axs[0, 1].add_patch(p) if show_titles: - axs[0, 1].set_title("Prediction") + axs[0, 1].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) From 36f796e820e1e7831c98d01922ba8806f91c6852 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 29 May 2024 11:59:57 +0000 Subject: [PATCH 10/10] revert readme changes --- README.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 6e213577b86..38883a212d7 100644 --- a/README.md +++ b/README.md @@ -120,15 +120,22 @@ TorchGeo includes a number of [_benchmark datasets_](https://torchgeo.readthedoc If you've used [torchvision](https://pytorch.org/vision) before, these datasets should seem very familiar. In this example, we'll create a dataset for the Northwestern Polytechnical University (NWPU) very-high-resolution ten-class ([VHR-10](https://github.com/chaozhong2010/VHR-10_dataset_coco)) geospatial object detection dataset. This dataset can be automatically downloaded, checksummed, and extracted, just like with torchvision. ```python +from torch.utils.data import DataLoader + +from torchgeo.datamodules.utils import collate_fn_detection from torchgeo.datasets import VHR10 -from torchgeo.datamodules import VHR10DataModule # Initialize the dataset dataset = VHR10(root="...", download=True, checksum=True) # Initialize the dataloader with the custom collate function -datamodule = VHR10DataModule(root="data", batch_size=32, num_workers=16) -datamodule.setup("fit") +dataloader = DataLoader( + dataset, + batch_size=128, + shuffle=True, + num_workers=4, + collate_fn=collate_fn_detection, +) # Training loop for batch in dataloader:
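Taken together, the series leaves the README as it was and changes only the dataset module: `plot` now percentile-normalizes the image before display and only reads the `boxes`/`labels` keys after the negative-split early return. A minimal usage sketch of the resulting behavior follows — it is not part of the patches themselves, and the `root` path, sample index, and `plt.show()` call are illustrative assumptions:

```python
import matplotlib.pyplot as plt

from torchgeo.datasets import VHR10

# Download and checksum the "positive" split (annotations require pycocotools)
dataset = VHR10(root="data", split="positive", download=True, checksum=True)

# An annotated sample is a dict with "image", "labels", "boxes", and "masks"
sample = dataset[0]

# plot() percentile-normalizes the image before drawing boxes and masks
fig = dataset.plot(sample, suptitle="VHR-10 sample", show_feats="both")
plt.show()
```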