Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to crop raster source extent #1030

Merged
merged 15 commits into from
Oct 20, 2020
8 changes: 7 additions & 1 deletion rastervision_core/rastervision/core/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,17 @@ def to_npboxes(boxes):
npboxes[boxind, :] = box.npbox_format()
return npboxes

def __iter__(self):
return iter(self.tuple_format())

def __getitem__(self, i):
return self.tuple_format()[i]

def __str__(self): # pragma: no cover
return str(self.npbox_format())

def __repr__(self): # pragma: no cover
return str(self)
return f'{type(self).__name__}{self.tuple_format()}'

def geojson_coordinates(self):
"""Return Box as GeoJSON coordinates."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple
from pydantic import conint

import numpy as np
Expand All @@ -19,14 +19,16 @@ class MultiRasterSource(ActivateMixin, RasterSource):
their output along the channel dimension (assumed to be the last dimension).
"""

def __init__(self,
raster_sources: Sequence[RasterSource],
raw_channel_order: Sequence[conint(ge=0)],
allow_different_extents: bool = False,
force_same_dtype: bool = False,
channel_order: Optional[Sequence[conint(ge=0)]] = None,
crs_source: conint(ge=0) = 0,
raster_transformers: Sequence = []):
def __init__(
self,
raster_sources: Sequence[RasterSource],
raw_channel_order: Sequence[conint(ge=0)],
allow_different_extents: bool = False,
force_same_dtype: bool = False,
channel_order: Optional[Sequence[conint(ge=0)]] = None,
crs_source: conint(ge=0) = 0,
raster_transformers: Sequence = [],
extent_crop: Optional[Tuple[float, float, float, float]] = None):
"""Constructor.

Args:
Expand All @@ -47,6 +49,10 @@ def __init__(self,
that will be used by .get_chip(). Defaults to None.
raster_transformers (Sequence, optional): Sequence of transformers.
Defaults to [].
extent_crop (Tuple[float, float, float, float], optional): Relative
offsets (top, left, bottom, right) for cropping the extent.
Useful for using splitting a scene into different datasets.
Defaults to None i.e. no cropping.
"""
num_channels = len(raw_channel_order)
if not channel_order:
Expand All @@ -59,6 +65,7 @@ def __init__(self,
self.raster_sources = raster_sources
self.raw_channel_order = list(raw_channel_order)
self.crs_source = crs_source
self.extent_crop = extent_crop

self.validate_raster_sources()

Expand All @@ -77,7 +84,8 @@ def validate_raster_sources(self) -> None:
f'Got: {extents} '
'(carefully consider using allow_different_extents)')

sub_num_channels = sum(rs.num_channels for rs in self.raster_sources)
sub_num_channels = sum(
len(rs.channel_order) for rs in self.raster_sources)
if sub_num_channels != self.num_channels:
raise MultiRasterSourceError(
f'num_channels ({self.num_channels}) != sum of num_channels '
Expand All @@ -89,6 +97,12 @@ def _subcomponents_to_activate(self) -> None:
def get_extent(self) -> Box:
rs = self.raster_sources[0]
extent = rs.get_extent()
if self.extent_crop is not None:
h, w = extent.get_height(), extent.get_width()
top, left, bottom, right = self.extent_crop
ymin, xmin = int(h * top), int(w * left)
ymax, xmax = h - int(h * bottom), w - int(w * right)
return Box(ymin, xmin, ymax, xmax)
return extent

def get_dtype(self) -> np.dtype:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def build(self, tmp_dir, use_transformers=True):
allow_different_extents=self.allow_different_extents,
channel_order=self.channel_order,
crs_source=self.crs_source,
raster_transformers=raster_transformers)
raster_transformers=raster_transformers,
extent_crop=self.extent_crop)
return multi_raster_source

def update(self, pipeline=None, scene=None):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Tuple

from rastervision.pipeline.config import Config, register_config, Field
from rastervision.core.data.raster_transformer import RasterTransformerConfig
Expand All @@ -11,6 +11,13 @@ class RasterSourceConfig(Config):
description=
'The sequence of channel indices to use when reading imagery.')
transformers: List[RasterTransformerConfig] = []
extent_crop: Optional[Tuple[float, float, float, float]] = Field(
None,
description='Relative offsets (top, left, bottom, right) for cropping '
'the extent of the raster source. Useful for splitting a scene into '
'different dataset splits. E.g. (0, 0, 0.2, 0) for the training set '
'and (0.8, 0, 0, 0) for the validation set will do a 80-20 split by '
'height. Defaults to None i.e. no cropping.')

def build(self, tmp_dir, use_transformers=True):
raise NotImplementedError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import subprocess
from decimal import Decimal
import tempfile
from typing import Optional, Tuple

import numpy as np
import rasterio
Expand Down Expand Up @@ -73,14 +74,16 @@ def load_window(image_dataset, window=None, is_masked=False):


class RasterioSource(ActivateMixin, RasterSource):
def __init__(self,
uris,
raster_transformers,
tmp_dir,
allow_streaming=False,
channel_order=None,
x_shift=0.0,
y_shift=0.0):
def __init__(
self,
uris,
raster_transformers,
tmp_dir,
allow_streaming=False,
channel_order=None,
x_shift=0.0,
y_shift=0.0,
extent_crop: Optional[Tuple[float, float, float, float]] = None):
"""Constructor.

This RasterSource can read any file that can be opened by Rasterio/GDAL
Expand All @@ -92,6 +95,10 @@ def __init__(self,

Args:
channel_order: list of indices of channels to extract from raw imagery
extent_crop (Tuple[float, float, float, float], optional): Relative
offsets (top, left, bottom, right) for cropping the extent.
Useful for using splitting a scene into different datasets.
Defaults to None i.e. no cropping.
"""
self.uris = uris
self.tmp_dir = tmp_dir
Expand All @@ -101,6 +108,7 @@ def __init__(self,
self.y_shift = y_shift
self.do_shift = self.x_shift != 0.0 or self.y_shift != 0.0
self.allow_streaming = allow_streaming
self.extent_crop = extent_crop

num_channels = None

Expand Down Expand Up @@ -156,7 +164,13 @@ def get_crs_transformer(self):
return self.crs_transformer

def get_extent(self):
return Box(0, 0, self.height, self.width)
h, w = self.height, self.width
if self.extent_crop is not None:
top, left, bottom, right = self.extent_crop
ymin, xmin = int(h * top), int(w * left)
ymax, xmax = h - int(h * bottom), w - int(w * right)
return Box(ymin, xmin, ymax, xmax)
return Box(0, 0, h, w)

def get_dtype(self):
"""Return the numpy.dtype of this scene"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ def build(self, tmp_dir, use_transformers=True):
allow_streaming=self.allow_streaming,
channel_order=self.channel_order,
x_shift=self.x_shift,
y_shift=self.y_shift)
y_shift=self.y_shift,
extent_crop=self.extent_crop)
66 changes: 66 additions & 0 deletions tests/core/data/raster_source/test_multi_raster_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import unittest

from rastervision.core.data import (
RasterioSourceConfig, MultiRasterSourceConfig, SubRasterSourceConfig)
from rastervision.pipeline import rv_config

from tests import data_file_path


def make_cfg(img_path='small-rgb-tile.tif', **kwargs):
img_path = data_file_path(img_path)
r_source = RasterioSourceConfig(uris=[img_path], channel_order=[0])
g_source = RasterioSourceConfig(uris=[img_path], channel_order=[1])
b_source = RasterioSourceConfig(uris=[img_path], channel_order=[2])

cfg = MultiRasterSourceConfig(
raster_sources=[
SubRasterSourceConfig(raster_source=r_source, target_channels=[0]),
SubRasterSourceConfig(raster_source=g_source, target_channels=[1]),
SubRasterSourceConfig(raster_source=b_source, target_channels=[2])
],
**kwargs)
return cfg


class TestRasterioSource(unittest.TestCase):
def setUp(self):
self.tmp_dir_obj = rv_config.get_tmp_dir()
self.tmp_dir = self.tmp_dir_obj.name

def tearDown(self):
self.tmp_dir_obj.cleanup()

def test_extent(self):
cfg = make_cfg('small-rgb-tile.tif')
rs = cfg.build(tmp_dir=self.tmp_dir)
extent = rs.get_extent()
h, w = extent.get_height(), extent.get_width()
ymin, xmin, ymax, xmax = extent
self.assertEqual(h, 256)
self.assertEqual(w, 256)
self.assertEqual(ymin, 0)
self.assertEqual(xmin, 0)
self.assertEqual(ymax, 256)
self.assertEqual(xmax, 256)

def test_extent_crop(self):
f = 1 / 4
cfg_crop = make_cfg('small-rgb-tile.tif', extent_crop=(f, f, f, f))
rs_crop = cfg_crop.build(tmp_dir=self.tmp_dir)
extent_crop = rs_crop.get_extent()

self.assertEqual(extent_crop.ymin, 64)
self.assertEqual(extent_crop.xmin, 64)
self.assertEqual(extent_crop.ymax, 192)
self.assertEqual(extent_crop.xmax, 192)

windows = extent_crop.get_windows(64, 64)
self.assertEqual(windows[0].ymin, 64)
self.assertEqual(windows[0].xmin, 64)
self.assertEqual(windows[-1].ymax, 192)
self.assertEqual(windows[-1].xmax, 192)


if __name__ == '__main__':
unittest.main()
33 changes: 33 additions & 0 deletions tests/core/data/raster_source/test_rasterio_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,39 @@ def test_no_epsg(self):
'Creating RasterioSource with CRS with no EPSG attribute '
'raised an exception when it should not have.')

def test_extent(self):
img_path = data_file_path('small-rgb-tile.tif')
cfg = RasterioSourceConfig(uris=[img_path])
rs = cfg.build(tmp_dir=self.tmp_dir)
extent = rs.get_extent()
h, w = extent.get_height(), extent.get_width()
ymin, xmin, ymax, xmax = extent
self.assertEqual(h, 256)
self.assertEqual(w, 256)
self.assertEqual(ymin, 0)
self.assertEqual(xmin, 0)
self.assertEqual(ymax, 256)
self.assertEqual(xmax, 256)

def test_extent_crop(self):
f = 1 / 4
img_path = data_file_path('small-rgb-tile.tif')
cfg_crop = RasterioSourceConfig(
uris=[img_path], extent_crop=(f, f, f, f))
rs_crop = cfg_crop.build(tmp_dir=self.tmp_dir)
extent_crop = rs_crop.get_extent()

self.assertEqual(extent_crop.ymin, 64)
self.assertEqual(extent_crop.xmin, 64)
self.assertEqual(extent_crop.ymax, 192)
self.assertEqual(extent_crop.xmax, 192)

windows = extent_crop.get_windows(64, 64)
self.assertEqual(windows[0].ymin, 64)
self.assertEqual(windows[0].xmin, 64)
self.assertEqual(windows[-1].ymax, 192)
self.assertEqual(windows[-1].xmax, 192)


if __name__ == '__main__':
unittest.main()
17 changes: 17 additions & 0 deletions tests/core/test_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,23 @@ def test_get_windows(self):
[window.tuple_format() for window in expected_windows])
self.assertSetEqual(windows, expected_windows)

def test_unpacking(self):
box = Box(1, 2, 3, 4)
ymin, xmin, ymax, xmax = box
self.assertEqual((ymin, xmin, ymax, xmax), box.tuple_format())

def test_subscripting(self):
box = Box(1, 2, 3, 4)
self.assertEqual(box[0], 1)
self.assertEqual(box[1], 2)
self.assertEqual(box[2], 3)
self.assertEqual(box[3], 4)
self.assertRaises(IndexError, lambda: box[4])

def test_repr(self):
box = Box(1, 2, 3, 4)
self.assertEqual(box.__repr__(), 'Box(1, 2, 3, 4)')


if __name__ == '__main__':
unittest.main()