Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove scikit-learn #1063

Merged
merged 18 commits into from
Apr 24, 2023
Merged
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ dependencies:
- radiant-mlhub>=0.3
- rtree>=1
- scikit-image>=0.18
- scikit-learn>=0.24
- scipy>=1.6.2
- segmentation-models-pytorch>=0.2
- setuptools>=42
Expand Down
1 change: 0 additions & 1 deletion requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ pillow==8.0.0
pyproj==3.0.0
rasterio==1.2.0
rtree==1.0.0
scikit-learn==0.24
segmentation-models-pytorch==0.2.0
shapely==1.7.1
timm==0.4.12
Expand Down
1 change: 0 additions & 1 deletion requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ pillow==9.5.0
pyproj==3.5.0
rasterio==1.3.6
rtree==1.0.1
scikit-learn==1.2.2
segmentation-models-pytorch==0.3.2
shapely==2.0.1
timm==0.6.12
Expand Down
2 changes: 0 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ install_requires =
rasterio>=1.2,<2
# rtree 1+ required for len(index), index & index, index | index
rtree>=1,<2
# scikit-learn 0.24+ required for Python 3.9 wheels
scikit-learn>=0.24,<2
# segmentation-models-pytorch 0.2+ required for smp.losses module
segmentation-models-pytorch>=0.2,<0.4
# shapely 1.7.1+ required for Python 3.9 wheels
Expand Down
35 changes: 34 additions & 1 deletion tests/datamodules/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import re

import numpy as np
import pytest
import torch
from torch.utils.data import TensorDataset

from torchgeo.datamodules.utils import dataset_split
from torchgeo.datamodules.utils import dataset_split, group_shuffle_split


def test_dataset_split() -> None:
Expand All @@ -23,3 +27,32 @@ def test_dataset_split() -> None:
assert len(train_ds) == round(num_samples / 3)
assert len(val_ds) == round(num_samples / 3)
assert len(test_ds) == round(num_samples / 3)


def test_group_shuffle_split() -> None:
    """Test input validation and splitting behavior of ``group_shuffle_split``."""
    # Seed the generator so the sampled groups — and therefore the 21/5
    # train/test group counts asserted below — are reproducible across runs.
    # (An unseeded draw would make the `== 21` asserts probabilistically flaky.)
    rng = np.random.default_rng(0)
    alphabet = np.array(list("abcdefghijklmnopqrstuvwxyz"))
    groups = alphabet[rng.integers(0, 26, size=(1000,))]

    # Invalid argument combinations must raise ValueError with a clear message.
    with pytest.raises(ValueError, match="You must specify `train_size` *"):
        group_shuffle_split(groups, train_size=None, test_size=None)
    with pytest.raises(ValueError, match="`train_size` and `test_size` must sum to 1."):
        group_shuffle_split(groups, train_size=0.2, test_size=1.0)
    with pytest.raises(
        ValueError,
        match=re.escape("`train_size` and `test_size` must be in the range (0,1)."),
    ):
        group_shuffle_split(groups, train_size=-0.2, test_size=1.2)
    with pytest.raises(ValueError, match="26 groups were found, however the current *"):
        group_shuffle_split(groups, train_size=None, test_size=0.999)

    # A 0.8/0.2 split of 26 groups yields 21 train groups; the split must be
    # disjoint and behave the same whichever keyword specifies the proportion.
    train_indices, test_indices = group_shuffle_split(
        groups, train_size=None, test_size=0.2
    )
    assert len(set(train_indices) & set(test_indices)) == 0
    assert len(set(groups[train_indices])) == 21
    train_indices, test_indices = group_shuffle_split(
        groups, train_size=0.8, test_size=None
    )
    assert len(set(train_indices) & set(test_indices)) == 0
    assert len(set(groups[train_indices])) == 21
8 changes: 3 additions & 5 deletions torchgeo/datamodules/cyclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from typing import Any

from sklearn.model_selection import GroupShuffleSplit
from torch.utils.data import Subset

from ..datasets import TropicalCyclone
from .geo import NonGeoDataModule
from .utils import group_shuffle_split


class TropicalCycloneDataModule(NonGeoDataModule):
Expand Down Expand Up @@ -50,10 +50,8 @@ def setup(self, stage: str) -> None:
storm_id = item["href"].split("/")[0].split("_")[-2]
storm_ids.append(storm_id)

train_indices, val_indices = next(
GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=0).split(
storm_ids, groups=storm_ids
)
train_indices, val_indices = group_shuffle_split(
storm_ids, test_size=0.2, random_state=0
)

self.train_dataset = Subset(self.dataset, train_indices)
Expand Down
8 changes: 3 additions & 5 deletions torchgeo/datamodules/sen12ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from typing import Any

import torch
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.utils.data import Subset

from ..datasets import SEN12MS
from .geo import NonGeoDataModule
from .utils import group_shuffle_split


class SEN12MSDataModule(NonGeoDataModule):
Expand Down Expand Up @@ -87,10 +87,8 @@ def setup(self, stage: str) -> None:
scene_id = int(parts[3])
scenes.append(season_id + scene_id)

train_indices, val_indices = next(
GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=0).split(
scenes, groups=scenes
)
train_indices, val_indices = group_shuffle_split(
scenes, test_size=0.2, random_state=0
)

self.train_dataset = Subset(self.dataset, train_indices)
Expand Down
75 changes: 75 additions & 0 deletions torchgeo/datamodules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@

"""Common datamodule utilities."""

import math
from collections.abc import Iterable
from typing import Any, Optional, Union

import numpy as np
from torch import Generator
from torch.utils.data import Subset, TensorDataset, random_split

Expand Down Expand Up @@ -52,3 +55,75 @@ def dataset_split(
[train_length, val_length, test_length],
generator=Generator().manual_seed(0),
)


def group_shuffle_split(
    groups: Iterable[Any],
    train_size: Optional[float] = None,
    test_size: Optional[float] = None,
    random_state: Optional[int] = None,
) -> tuple["np.ndarray", "np.ndarray"]:
    """Method for performing a single group-wise shuffle split of data.

    Loosely based off of `sklearn.model_selection.GroupShuffleSplit`.

    Args:
        groups: a sequence of group values used to split. Should be in the same order as
            the data you want to split.
        train_size: the proportion of groups to include in the train split. If None,
            then it is set to complement `test_size`.
        test_size: the proportion of groups to include in the test split (rounded up).
            If None, then it is set to complement `train_size`.
        random_state: controls the random splits (passed a seed to a
            numpy.random.Generator), set for reproducible splits.

    Returns:
        train_indices, test_indices

    Raises:
        ValueError: if `train_size` and `test_size` do not sum to 1, aren't in the
            range (0,1), or are both None.
        ValueError: if the number of training or testing groups turns out to be 0.
    """
    if train_size is None and test_size is None:
        raise ValueError("You must specify `train_size`, `test_size`, or both.")
    # isclose() rather than == tolerates float artifacts like 0.8 + 0.2.
    if (train_size is not None and test_size is not None) and (
        not math.isclose(train_size + test_size, 1)
    ):
        raise ValueError("`train_size` and `test_size` must sum to 1.")

    if train_size is None and test_size is not None:
        train_size = 1 - test_size
    if test_size is None and train_size is not None:
        test_size = 1 - train_size

    assert train_size is not None and test_size is not None

    if train_size <= 0 or train_size >= 1 or test_size <= 0 or test_size >= 1:
        raise ValueError("`train_size` and `test_size` must be in the range (0,1).")

    # Materialize once: `groups` is only required to be Iterable, and we need
    # two passes (deduplication below, index assignment at the end).
    groups = list(groups)

    # Deduplicate while preserving first-occurrence order. Using `set(groups)`
    # here would make the candidate ordering depend on the per-process hash
    # seed (PYTHONHASHSEED), silently breaking the reproducibility promised by
    # `random_state`.
    group_vals = list(dict.fromkeys(groups))
    n_groups = len(group_vals)
    n_test_groups = round(n_groups * test_size)
    n_train_groups = n_groups - n_test_groups

    if n_train_groups == 0 or n_test_groups == 0:
        raise ValueError(
            f"{n_groups} groups were found, however the current settings of "
            "`train_size` and `test_size` result in 0 training or testing groups."
        )

    generator = np.random.default_rng(seed=random_state)
    train_group_vals = set(
        generator.choice(group_vals, size=n_train_groups, replace=False)
    )

    # Every sample whose group was drawn for training goes to train;
    # everything else goes to test, so the two index sets partition the data.
    train_idxs = []
    test_idxs = []
    for i, group_val in enumerate(groups):
        if group_val in train_group_vals:
            train_idxs.append(i)
        else:
            test_idxs.append(i)

    return np.array(train_idxs), np.array(test_idxs)