Skip to content

Commit

Permalink
Merge branch 'main' into improve/logging
Browse files Browse the repository at this point in the history
  • Loading branch information
Joao-L-S-Almeida committed Feb 24, 2025
2 parents 71cb84d + 92bd85a commit 1026956
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 9 deletions.
8 changes: 4 additions & 4 deletions examples/confs/eurosat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ trainer:
logger:
class_path: TensorBoardLogger
init_args:
save_dir: <your_path_here>/torchgeo_eurosat
save_dir: ./torchgeo_eurosat
name: eurosat
callbacks:
- class_path: RichProgressBar
Expand All @@ -26,7 +26,7 @@ trainer:
check_val_every_n_epoch: 1
log_every_n_steps: 50
enable_checkpointing: true
default_root_dir: <your_path_here>/torchgeo_eurosat
default_root_dir: ./torchgeo_eurosat

data:
class_path: terratorch.datamodules.TorchNonGeoDataModule
Expand All @@ -41,7 +41,7 @@ data:
batch_size: 32
num_workers: 8
dict_kwargs:
root: /dccstor/geofm-pre/EuroSat
root: ./EuroSat
download: True
bands:
- B02
Expand All @@ -57,7 +57,7 @@ model:
model_args:
decoder: IdentityDecoder
backbone_pretrained: true
backbone: prithvi_eo_v1_300
backbone: prithvi_eo_v2_300
head_dim_list:
- 384
- 128
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ dependencies = [
"torch",
"torchvision",
"rioxarray",
"albumentations==1.4.10",
"albumentations==1.4.0",
"albucore==0.0.16",
"rasterio",
"torchmetrics",
Expand Down
11 changes: 11 additions & 0 deletions terratorch/datamodules/generic_multimodal_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from terratorch.datamodules.generic_pixel_wise_data_module import Normalize
from terratorch.io.file import load_from_file_or_attribute

from .utils import check_dataset_stackability

logger = logging.getLogger("terratorch")

def collate_chunk_dicts(batch_list):
Expand Down Expand Up @@ -204,6 +206,7 @@ def __init__(
sample_replace: bool = False,
channel_position: int = -3,
concat_bands: bool = False,
check_stackability: bool = True,
**kwargs: Any,
) -> None:
"""Constructor
Expand Down Expand Up @@ -308,7 +311,9 @@ def __init__(
concat_bands (bool): Concatenate all image modalities along the band dimension into a single "image", so
that it can be processed by single-modal models. Concatenate in the order of provided modalities.
Works with image modalities only. Does not work with allow_missing_modalities. Defaults to False.
check_stackability (bool): Check if all the files in the dataset has the same size and can be stacked.
"""

if task == "segmentation":
dataset_class = GenericMultimodalSegmentationDataset
elif task == "regression":
Expand Down Expand Up @@ -364,6 +369,7 @@ def __init__(
self.reduce_zero_label = reduce_zero_label
self.channel_position = channel_position
self.concat_bands = concat_bands
self.check_stackability = check_stackability

if isinstance(train_transform, dict):
self.train_transform = {m: wrap_in_compose_is_list(train_transform[m]) if m in train_transform else None
Expand Down Expand Up @@ -526,6 +532,11 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
"""
dataset = self._valid_attribute(f"{split}_dataset", "dataset")
batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

if self.check_stackability:
print("Checking stackability.")
batch_size = check_dataset_stackability(dataset, batch_size)

if self.sample_num_modalities:
# Custom batch sampler for sampling modalities per batch
batch_sampler = MultiModalBatchSampler(
Expand Down
9 changes: 7 additions & 2 deletions terratorch/datamodules/generic_pixel_wise_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def __init__(
no_label_replace: int | None = None,
drop_last: bool = True,
pin_memory: bool = False,
check_stackability: bool = True,
**kwargs: Any,
) -> None:
"""Constructor
Expand Down Expand Up @@ -429,7 +430,7 @@ def __init__(
drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
pin_memory (bool): If ``True``, the data loader will copy Tensors
into device/CUDA pinned memory before returning them. Defaults to False.
check_stackability (bool): Check if all the files in the dataset has the same size and can be stacked.
"""
super().__init__(GenericNonGeoPixelwiseRegressionDataset, batch_size, num_workers, **kwargs)
self.img_grep = img_grep
Expand Down Expand Up @@ -475,6 +476,8 @@ def __init__(
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)

self.check_stackability = check_stackability

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
Expand Down Expand Up @@ -565,7 +568,9 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
dataset = self._valid_attribute(f"{split}_dataset", "dataset")
batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

batch_size = check_dataset_stackability(dataset, batch_size)
if self.check_stackability:
print("Checking stackability.")
batch_size = check_dataset_stackability(dataset, batch_size)

return DataLoader(
dataset=dataset,
Expand Down
10 changes: 8 additions & 2 deletions terratorch/datamodules/generic_scalar_label_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
expand_temporal_dimension: bool = False,
no_data_replace: float = 0,
drop_last: bool = True,
check_stackability: bool = True,
**kwargs: Any,
) -> None:
"""Constructor
Expand Down Expand Up @@ -130,6 +131,7 @@ def __init__(
expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
Defaults to False.
drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
check_stackability (bool): Check if all the files in the dataset has the same size and can be stacked.
"""
super().__init__(GenericNonGeoClassificationDataset, batch_size, num_workers, **kwargs)
self.num_classes = num_classes
Expand Down Expand Up @@ -169,6 +171,8 @@ def __init__(
# self.aug = Normalize(means, stds)
# self.collate_fn = collate_fn_list_dicts

self.check_stackability = check_stackability

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
Expand Down Expand Up @@ -243,8 +247,10 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
"""
dataset = self._valid_attribute(f"{split}_dataset", "dataset")
batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

batch_size = check_dataset_stackability(dataset, batch_size)

if self.check_stackability:
print("Checking stackability.")
batch_size = check_dataset_stackability(dataset, batch_size)

return DataLoader(
dataset=dataset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ data:
- 2
- 1
- 0
check_stackability: false
train_data_root: tests/resources/inputs
train_label_data_root: tests/resources/inputs
val_data_root: tests/resources/inputs
Expand Down

0 comments on commit 1026956

Please sign in to comment.