diff --git a/examples/confs/burn_scars.yaml b/examples/confs/burn_scars.yaml index 94144bd0..7720f679 100644 --- a/examples/confs/burn_scars.yaml +++ b/examples/confs/burn_scars.yaml @@ -40,7 +40,7 @@ data: - NIR_NARROW - SWIR_1 - SWIR_2 - output_bands: + input_bands: - BLUE - GREEN - RED diff --git a/examples/confs/multi_temporal_crop.yaml b/examples/confs/multi_temporal_crop.yaml index e8390f8b..baf145fe 100644 --- a/examples/confs/multi_temporal_crop.yaml +++ b/examples/confs/multi_temporal_crop.yaml @@ -44,7 +44,7 @@ data: - NIR_NARROW - SWIR_1 - SWIR_2 - output_bands: + input_bands: - BLUE - GREEN - RED diff --git a/examples/confs/sen1floods11_vit.yaml b/examples/confs/sen1floods11_vit.yaml index 9e80b10f..6f7651ac 100644 --- a/examples/confs/sen1floods11_vit.yaml +++ b/examples/confs/sen1floods11_vit.yaml @@ -31,7 +31,7 @@ data: - NIR_NARROW - SWIR_1 - SWIR_2 - output_bands: + input_bands: - BLUE - GREEN - RED diff --git a/examples/confs/sen1floods11_vit_local_ckpt.yaml b/examples/confs/sen1floods11_vit_local_ckpt.yaml index 62180b56..b52aedf7 100644 --- a/examples/confs/sen1floods11_vit_local_ckpt.yaml +++ b/examples/confs/sen1floods11_vit_local_ckpt.yaml @@ -30,7 +30,7 @@ data: - NIR_NARROW - SWIR_1 - SWIR_2 - output_bands: + input_bands: - BLUE - GREEN - RED diff --git a/examples/notebooks/Tutorial.ipynb b/examples/notebooks/Tutorial.ipynb index f689345b..475cb1e8 100644 --- a/examples/notebooks/Tutorial.ipynb +++ b/examples/notebooks/Tutorial.ipynb @@ -468,7 +468,7 @@ " HLSBands.SWIR_1,\n", " HLSBands.SWIR_2,\n", " ],\n", - " output_bands=[\n", + " input_bands=[\n", " HLSBands.BLUE,\n", " HLSBands.GREEN,\n", " HLSBands.RED,\n", @@ -871,7 +871,7 @@ ], "metadata": { "kernelspec": { - "display_name": "terratorch_os", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -885,7 +885,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.10.8" } }, "nbformat": 4, diff --git a/src/terratorch/datamodules/generic_pixel_wise_data_module.py b/src/terratorch/datamodules/generic_pixel_wise_data_module.py index eff10407..89db280e 100644 --- a/src/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/src/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -91,7 +91,7 @@ def __init__( allow_substring_split_file: bool = True, dataset_bands: list[HLSBands | int] | None = None, predict_dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + input_bands: list[HLSBands | int] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, @@ -133,7 +133,7 @@ def __init__( matches (e.g. eurosat). Defaults to True. dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. predict_dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. - output_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. + input_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. constant_scale (float, optional): _description_. Defaults to 1. rgb_indices (list[int] | None, optional): _description_. Defaults to None. train_transform (Albumentations.Compose | None): Albumentations transform @@ -183,7 +183,7 @@ def __init__( self.dataset_bands = dataset_bands self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands - self.output_bands = output_bands + self.input_bands = input_bands self.rgb_indices = rgb_indices self.expand_temporal_dimension = expand_temporal_dimension self.reduce_zero_label = reduce_zero_label @@ -213,7 +213,7 @@ def setup(self, stage: str) -> None: ignore_split_file_extensions=self.ignore_split_file_extensions, allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, - output_bands=self.output_bands, + input_bands=self.input_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.train_transform, @@ -233,7 +233,7 @@ def setup(self, stage: str) -> None: ignore_split_file_extensions=self.ignore_split_file_extensions, allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, - output_bands=self.output_bands, + input_bands=self.input_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.val_transform, @@ -253,7 +253,7 @@ def setup(self, stage: str) -> None: ignore_split_file_extensions=self.ignore_split_file_extensions, allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, - output_bands=self.output_bands, + input_bands=self.input_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.test_transform, @@ -267,7 +267,7 @@ def setup(self, stage: str) -> None: self.predict_root, self.num_classes, dataset_bands=self.predict_dataset_bands, - output_bands=self.output_bands, + input_bands=self.input_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.test_transform, @@ -330,7 +330,7 @@ def __init__( allow_substring_split_file: bool = True, dataset_bands: list[HLSBands | int] | None = None, predict_dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + input_bands: list[HLSBands | int] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, @@ -371,7 +371,7 @@ def __init__( matches (e.g. eurosat). Defaults to True. dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. predict_dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. - output_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. + input_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. constant_scale (float, optional): _description_. Defaults to 1. rgb_indices (list[int] | None, optional): _description_. Defaults to None. train_transform (Albumentations.Compose | None): Albumentations transform @@ -421,7 +421,7 @@ def __init__( self.dataset_bands = dataset_bands self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands - self.output_bands = output_bands + self.input_bands = input_bands self.rgb_indices = rgb_indices # self.aug = AugmentationSequential( @@ -447,7 +447,7 @@ def setup(self, stage: str) -> None: ignore_split_file_extensions=self.ignore_split_file_extensions, allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, - output_bands=self.output_bands, + input_bands=self.input_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.train_transform, @@ -466,7 +466,7 @@ def setup(self, stage: str) -> None: ignore_split_file_extensions=self.ignore_split_file_extensions, allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, - output_bands=self.output_bands, + input_bands=self.input_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.val_transform, @@ -485,7 +485,7 @@ def setup(self, stage: str) -> None: ignore_split_file_extensions=self.ignore_split_file_extensions, allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, - output_bands=self.output_bands, + input_bands=self.input_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.test_transform, @@ -499,7 +499,7 @@ def setup(self, stage: str) -> None: self.predict_dataset = self.dataset_class( self.predict_root, dataset_bands=self.predict_dataset_bands, - output_bands=self.output_bands, + input_bands=self.input_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.test_transform, diff --git a/src/terratorch/datamodules/generic_scalar_label_data_module.py b/src/terratorch/datamodules/generic_scalar_label_data_module.py index 230db10a..6c2e14ac 100644 --- a/src/terratorch/datamodules/generic_scalar_label_data_module.py +++ b/src/terratorch/datamodules/generic_scalar_label_data_module.py @@ -77,7 +77,7 @@ def __init__( allow_substring_split_file: bool = True, dataset_bands: list[HLSBands | int] | None = None, predict_dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + input_bands: list[HLSBands | int] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, @@ -112,7 +112,7 @@ def __init__( matches (e.g. eurosat). Defaults to True. dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. predict_dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. - output_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. + input_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. constant_scale (float, optional): _description_. Defaults to 1. rgb_indices (list[int] | None, optional): _description_. Defaults to None. train_transform (Albumentations.Compose | None): Albumentations transform @@ -152,7 +152,7 @@ def __init__( self.dataset_bands = dataset_bands self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands - self.output_bands = output_bands + self.input_bands = input_bands self.rgb_indices = rgb_indices self.expand_temporal_dimension = expand_temporal_dimension @@ -178,7 +178,7 @@ def setup(self, stage: str) -> None: ignore_split_file_extensions=self.ignore_split_file_extensions, allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, - output_bands=self.output_bands, + input_bands=self.input_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.train_transform, @@ -193,7 +193,7 @@ def setup(self, stage: str) -> None: ignore_split_file_extensions=self.ignore_split_file_extensions, allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, - output_bands=self.output_bands, + input_bands=self.input_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.val_transform, @@ -208,7 +208,7 @@ def setup(self, stage: str) -> None: ignore_split_file_extensions=self.ignore_split_file_extensions, allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, - output_bands=self.output_bands, + input_bands=self.input_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.test_transform, @@ -220,7 +220,7 @@ def setup(self, stage: str) -> None: self.predict_root, self.num_classes, dataset_bands=self.predict_dataset_bands, - output_bands=self.output_bands, + input_bands=self.input_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.test_transform, @@ -277,7 +277,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: # allow_substring_split_file: bool = True, # dataset_bands: list[HLSBands | int] | None = None, # predict_dataset_bands: list[HLSBands | int] | None = None, -# output_bands: list[HLSBands | int] | None = None, +# input_bands: list[HLSBands | int] | None = None, # constant_scale: float = 1, # rgb_indices: list[int] | None = None, # train_transform: A.Compose | None | list[A.BasicTransform] = None, @@ -311,7 +311,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: # matches (e.g. eurosat). Defaults to True. # dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. # predict_dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. -# output_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. +# input_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None. # constant_scale (float, optional): _description_. Defaults to 1. # rgb_indices (list[int] | None, optional): _description_. Defaults to None. # train_transform (Albumentations.Compose | None): Albumentations transform @@ -350,7 +350,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: # self.constant_scale = constant_scale # self.dataset_bands = dataset_bands # self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands -# self.output_bands = output_bands +# self.input_bands = input_bands # self.rgb_indices = rgb_indices # # self.aug = AugmentationSequential( @@ -372,7 +372,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: # ignore_split_file_extensions=self.ignore_split_file_extensions, # allow_substring_split_file=self.allow_substring_split_file, # dataset_bands=self.dataset_bands, -# output_bands=self.output_bands, +# input_bands=self.input_bands, # constant_scale=self.constant_scale, # rgb_indices=self.rgb_indices, # transform=self.train_transform, @@ -386,7 +386,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: # ignore_split_file_extensions=self.ignore_split_file_extensions, # allow_substring_split_file=self.allow_substring_split_file, # dataset_bands=self.dataset_bands, -# output_bands=self.output_bands, +# input_bands=self.input_bands, # constant_scale=self.constant_scale, # rgb_indices=self.rgb_indices, # transform=self.val_transform, @@ -400,7 +400,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: # ignore_split_file_extensions=self.ignore_split_file_extensions, # allow_substring_split_file=self.allow_substring_split_file, # dataset_bands=self.dataset_bands, -# output_bands=self.output_bands, +# input_bands=self.input_bands, # constant_scale=self.constant_scale, # rgb_indices=self.rgb_indices, # transform=self.test_transform, @@ -412,7 +412,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: # self.predict_dataset = self.dataset_class( # self.predict_root, # dataset_bands=self.predict_dataset_bands, -# output_bands=self.output_bands, +# input_bands=self.input_bands, # constant_scale=self.constant_scale, # rgb_indices=self.rgb_indices, # transform=self.test_transform, diff --git a/src/terratorch/datasets/generic_pixel_wise_dataset.py b/src/terratorch/datasets/generic_pixel_wise_dataset.py index 4ebdd6da..f63c0204 100644 --- a/src/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/src/terratorch/datasets/generic_pixel_wise_dataset.py @@ -42,7 +42,7 @@ def __init__( allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + input_bands: list[HLSBands | int] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float | None = None, @@ -72,7 +72,7 @@ def __init__( matches (e.g. eurosat). Defaults to True. rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. - output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. + input_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. constant_scale (float): Factor to multiply image values by. Defaults to 1. transform (Albumentations.Compose | None): Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, @@ -97,8 +97,8 @@ def __init__( self.reduce_zero_label = reduce_zero_label self.expand_temporal_dimension = expand_temporal_dimension - if self.expand_temporal_dimension and output_bands is None: - msg = "Please provide output_bands when expand_temporal_dimension is True" + if self.expand_temporal_dimension and input_bands is None: + msg = "Please provide input_bands when expand_temporal_dimension is True" raise Exception(msg) if self.split_file is not None: with open(self.split_file) as f: @@ -119,16 +119,16 @@ def __init__( self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices self.dataset_bands = dataset_bands - self.output_bands = output_bands - if self.output_bands and not self.dataset_bands: + self.input_bands = input_bands + if self.input_bands and not self.dataset_bands: msg = "If output bands provided, dataset_bands must also be provided" return Exception(msg) # noqa: PLE0101 - if self.output_bands: - if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands): + if self.input_bands: + if len(set(self.input_bands) & set(self.dataset_bands)) != len(self.input_bands): msg = "Output bands must be a subset of dataset bands" raise Exception(msg) - self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands] + self.filter_indices = [self.dataset_bands.index(band) for band in self.input_bands] else: self.filter_indices = None # If no transform is given, apply only to transform to torch tensor @@ -142,7 +142,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: image = self._load_file(self.image_files[index], nan_replace = self.no_data_replace).to_numpy() # to channels last if self.expand_temporal_dimension: - image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands)) + image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.input_bands)) image = np.moveaxis(image, 0, -1) if self.filter_indices: @@ -180,7 +180,7 @@ def __init__( allow_substring_split_file: bool = True, rgb_indices: list[str] | None = None, dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + input_bands: list[HLSBands | int] | None = None, class_names: list[str] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, @@ -212,7 +212,7 @@ def __init__( matches (e.g. eurosat). Defaults to True. rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. - output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. + input_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. class_names (list[str], optional): Class names. Defaults to None. constant_scale (float): Factor to multiply image values by. Defaults to 1. transform (Albumentations.Compose | None): Albumentations transform to be applied. @@ -236,7 +236,7 @@ def __init__( allow_substring_split_file=allow_substring_split_file, rgb_indices=rgb_indices, dataset_bands=dataset_bands, - output_bands=output_bands, + input_bands=input_bands, constant_scale=constant_scale, transform=transform, no_data_replace=no_data_replace, @@ -347,7 +347,7 @@ def __init__( allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + input_bands: list[HLSBands | int] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float | None = None, @@ -377,7 +377,7 @@ def __init__( matches (e.g. eurosat). Defaults to True. rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. - output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. + input_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. constant_scale (float): Factor to multiply image values by. Defaults to 1. transform (Albumentations.Compose | None): Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, @@ -400,7 +400,7 @@ def __init__( allow_substring_split_file=allow_substring_split_file, rgb_indices=rgb_indices, dataset_bands=dataset_bands, - output_bands=output_bands, + input_bands=input_bands, constant_scale=constant_scale, transform=transform, no_data_replace=no_data_replace, diff --git a/src/terratorch/datasets/generic_scalar_label_dataset.py b/src/terratorch/datasets/generic_scalar_label_dataset.py index 0a47b2da..5aaa92ce 100644 --- a/src/terratorch/datasets/generic_scalar_label_dataset.py +++ b/src/terratorch/datasets/generic_scalar_label_dataset.py @@ -41,7 +41,7 @@ def __init__( allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + input_bands: list[HLSBands | int] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float = 0, @@ -63,7 +63,7 @@ def __init__( matches (e.g. eurosat). Defaults to True. rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. - output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. + input_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. constant_scale (float): Factor to multiply image values by. Defaults to 1. transform (Albumentations.Compose | None): Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, @@ -80,8 +80,8 @@ def __init__( self.constant_scale = constant_scale self.no_data_replace = no_data_replace self.expand_temporal_dimension = expand_temporal_dimension - if self.expand_temporal_dimension and output_bands is None: - msg = "Please provide output_bands when expand_temporal_dimension is True" + if self.expand_temporal_dimension and input_bands is None: + msg = "Please provide input_bands when expand_temporal_dimension is True" raise Exception(msg) if self.split_file is not None: with open(self.split_file) as f: @@ -109,16 +109,16 @@ def is_valid_file(x): self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices self.dataset_bands = dataset_bands - self.output_bands = output_bands - if self.output_bands and not self.dataset_bands: + self.input_bands = input_bands + if self.input_bands and not self.dataset_bands: msg = "If output bands provided, dataset_bands must also be provided" return Exception(msg) # noqa: PLE0101 - if self.output_bands: - if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands): + if self.input_bands: + if len(set(self.input_bands) & set(self.dataset_bands)) != len(self.input_bands): msg = "Output bands must be a subset of dataset bands" raise Exception(msg) - self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands] + self.filter_indices = [self.dataset_bands.index(band) for band in self.input_bands] else: self.filter_indices = None # If no transform is given, apply only to transform to torch tensor @@ -131,7 +131,7 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> dict[str, Any]: image, label = ImageFolder.__getitem__(self, index) if self.expand_temporal_dimension: - image = rearrange(image, "h w (channels time) -> time h w channels", channels=len(self.output_bands)) + image = rearrange(image, "h w (channels time) -> time h w channels", channels=len(self.input_bands)) if self.filter_indices: image = image[..., self.filter_indices] @@ -164,7 +164,7 @@ def __init__( allow_substring_split_file: bool = True, rgb_indices: list[str] | None = None, dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + input_bands: list[HLSBands | int] | None = None, class_names: list[str] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, @@ -188,7 +188,7 @@ def __init__( matches (e.g. eurosat). Defaults to True. rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. - output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. + input_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. class_names (list[str], optional): Class names. Defaults to None. constant_scale (float): Factor to multiply image values by. Defaults to 1. transform (Albumentations.Compose | None): Albumentations transform to be applied. @@ -206,7 +206,7 @@ def __init__( allow_substring_split_file=allow_substring_split_file, rgb_indices=rgb_indices, dataset_bands=dataset_bands, - output_bands=output_bands, + input_bands=input_bands, constant_scale=constant_scale, transform=transform, no_data_replace=no_data_replace, @@ -236,7 +236,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure # allow_substring_split_file: bool = True, # rgb_indices: list[int] | None = None, # dataset_bands: list[HLSBands | int] | None = None, -# output_bands: list[HLSBands | int] | None = None, +# input_bands: list[HLSBands | int] | None = None, # constant_scale: float = 1, # transform: A.Compose | None = None, # no_data_replace: float = 0, @@ -258,7 +258,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure # matches (e.g. eurosat). Defaults to True. # rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. # dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. -# output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. +# input_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. # constant_scale (float): Factor to multiply image values by. Defaults to 1. # transform (Albumentations.Compose | None): Albumentations transform to be applied. # Should end with ToTensorV2(). If used through the generic_data_module, @@ -275,7 +275,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure # allow_substring_split_file=allow_substring_split_file, # rgb_indices=rgb_indices, # dataset_bands=dataset_bands, -# output_bands=output_bands, +# input_bands=input_bands, # constant_scale=constant_scale, # transform=transform, # no_data_replace=no_data_replace,