diff --git a/python/.pylintrc b/python/.pylintrc index 89244f98dae..0aed22e524f 100644 --- a/python/.pylintrc +++ b/python/.pylintrc @@ -95,7 +95,8 @@ disable=raw-checker-failed, file-ignored, suppressed-message, deprecated-pragma, - use-symbolic-message-instead + use-symbolic-message-instead, + duplicate-code # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/python/grass/jupyter/Makefile b/python/grass/jupyter/Makefile index 1fd9f18b40d..102a70cbc80 100644 --- a/python/grass/jupyter/Makefile +++ b/python/grass/jupyter/Makefile @@ -11,6 +11,7 @@ MODULES = \ interactivemap \ region \ map3d \ + seriesmap \ reprojection_renderer \ utils \ timeseriesmap diff --git a/python/grass/jupyter/__init__.py b/python/grass/jupyter/__init__.py index e2fa5dde429..246f6d7b3ab 100644 --- a/python/grass/jupyter/__init__.py +++ b/python/grass/jupyter/__init__.py @@ -103,3 +103,4 @@ from .map3d import Map3D from .setup import init from .timeseriesmap import TimeSeriesMap +from .seriesmap import SeriesMap diff --git a/python/grass/jupyter/region.py b/python/grass/jupyter/region.py index 298d95bc23a..30e71385343 100644 --- a/python/grass/jupyter/region.py +++ b/python/grass/jupyter/region.py @@ -215,6 +215,78 @@ def set_region_from_command(self, module, **kwargs): return +class RegionManagerForSeries: + """Region manager for SeriesMap""" + + def __init__(self, use_region, saved_region, width, height, env): + """Manages region during rendering. + + :param use_region: if True, use either current or provided saved region, + else derive region from rendered layers + :param saved_region: if name of saved_region is provided, + this region is then used for rendering + :param width: rendering width + :param height: rendering height + :param env: environment for rendering + """ + self._env = env + self._width = width + self._height = height + self._use_region = use_region + self._saved_region = saved_region + self._extent_set = False + self._resolution_set = False + + def set_region_from_rasters(self, rasters): + """Sets computational region for rendering from a series of rasters. + + This function sets the region from a series of rasters. If the extent or + resolution has already been set by calling this function previously or by the + set_region_from vectors() function, this function will not modify it. + + If user specified the name of saved region during object's initialization, + the provided region is used. If it's not specified + and use_region=True, current region is used. + """ + if self._saved_region: + self._env["GRASS_REGION"] = gs.region_env( + region=self._saved_region, env=self._env + ) + return + if self._use_region: + # use current + return + if self._resolution_set and self._extent_set: + return + if not self._resolution_set and not self._extent_set: + self._env["GRASS_REGION"] = gs.region_env(raster=rasters, env=self._env) + self._extent_set = True + self._resolution_set = True + elif not self._resolution_set: + self._env["GRASS_REGION"] = gs.region_env(align=rasters[0], env=self._env) + self._resolution_set = True + + def set_region_from_vectors(self, vectors): + """Sets computational region extent for rendering from a series of vectors + + If the extent and resolution has already been set by set_region_from_rasters, + or by using the saved_region or use_region arguments, the region is not modified + """ + if self._saved_region: + self._env["GRASS_REGION"] = gs.region_env( + region=self._saved_region, env=self._env + ) + return + if self._use_region: + # use current + return + if self._resolution_set and self._extent_set: + return + if not self._resolution_set and not self._extent_set: + self._env["GRASS_REGION"] = gs.region_env(vector=vectors, env=self._env) + self._extent_set = True + + class RegionManagerFor3D: """Region manager for 3D displays (gets region from m.nviz.image command)""" diff --git a/python/grass/jupyter/seriesmap.py b/python/grass/jupyter/seriesmap.py new file mode 100644 index 00000000000..2c3a3860130 --- /dev/null +++ b/python/grass/jupyter/seriesmap.py @@ -0,0 +1,355 @@ +# MODULE: grass.jupyter.seriesmap +# +# AUTHOR(S): Caitlin Haedrich +# +# PURPOSE: This module contains functions for visualizing series of rasters in +# Jupyter Notebooks +# +# COPYRIGHT: (C) 2022 Caitlin Haedrich, and by the GRASS Development Team +# +# This program is free software under the GNU General Public +# License (>=v2). Read the file COPYING that comes with GRASS +# for details. +"""Create and display visualizations for a series of rasters.""" + +import tempfile +import os +import weakref +import shutil + +import grass.script as gs +from grass.grassdb.data import map_exists + +from .map import Map +from .region import RegionManagerForSeries +from .utils import save_gif + + +class SeriesMap: + """Creates visualizations from a series of rasters or vectors in Jupyter + Notebooks. + + Basic usage:: + + >>> series = gj.SeriesMap(height = 500) + >>> series.add_rasters(["elevation_shade", "geology", "soils"]) + >>> series.add_vectors(["streams", "streets", "viewpoints"]) + >>> series.d_barscale() + >>> series.show() # Create Slider + >>> series.save("image.gif") + + This class of grass.jupyter is experimental and under development. The API can + change at anytime. + """ + + # pylint: disable=too-many-instance-attributes + # pylint: disable=duplicate-code + + def __init__( + self, + width=None, + height=None, + env=None, + use_region=False, + saved_region=None, + ): + """Creates an instance of the SeriesMap visualizations class. + + :param int width: width of map in pixels + :param int height: height of map in pixels + :param str env: environment + :param use_region: if True, use either current or provided saved region, + else derive region from rendered layers + :param saved_region: if name of saved_region is provided, + this region is then used for rendering + """ + + # Copy Environment + if env: + self._env = env.copy() + else: + self._env = os.environ.copy() + + self._series_length = None + self._base_layer_calls = [] + self._calls = [] + self._series_added = False + self._layers_rendered = False + self._layer_filename_dict = {} + self._names = [] + self._width = width + self._height = height + + # Create a temporary directory for our PNG images + # Resource managed by weakref.finalize. + self._tmpdir = ( + # pylint: disable=consider-using-with + tempfile.TemporaryDirectory() + ) + + def cleanup(tmpdir): + tmpdir.cleanup() + + weakref.finalize(self, cleanup, self._tmpdir) + + # Handle Regions + self._region_manager = RegionManagerForSeries( + use_region=use_region, + saved_region=saved_region, + width=width, + height=height, + env=self._env, + ) + + def add_rasters(self, rasters, **kwargs): + """ + :param list rasters: list of raster layers to add to SeriesMap + """ + for raster in rasters: + if not map_exists(name=raster, element="raster"): + raise NameError(_("Could not find a raster named {}").format(raster)) + # Update region to rasters if not use_region or saved_region + self._region_manager.set_region_from_rasters(rasters) + if self._series_added: + assert self._series_length == len(rasters), _( + "Number of vectors in series must match number of vectors" + ) + for i in range(self._series_length): + kwargs["map"] = rasters[i] + self._calls[i].append(("d.rast", kwargs.copy())) + else: + self._series_length = len(rasters) + for raster in rasters: + kwargs["map"] = raster + self._calls.append([("d.rast", kwargs.copy())]) + self._series_added = True + if not self._names: + self._names = rasters + self._layers_rendered = False + + def add_vectors(self, vectors, **kwargs): + """ + :param list vectors: list of vector layers to add to SeriesMap + """ + for vector in vectors: + if not map_exists(name=vector, element="vector"): + raise NameError(_("Could not find a vector named {}").format(vector)) + # Update region extent to vectors if not use_region or saved_region + self._region_manager.set_region_from_vectors(vectors) + if self._series_added: + assert self._series_length == len(vectors), _( + "Number of rasters in series must match number of vectors" + ) + for i in range(self._series_length): + kwargs["map"] = vectors[i] + self._calls[i].append(("d.vect", kwargs.copy())) + else: + self._series_length = len(vectors) + for vector in vectors: + kwargs["map"] = vector + self._calls.append([("d.vect", kwargs.copy())]) + self._series_added = True + if not self._names: + self._names = vectors + self._layers_rendered = False + + def __getattr__(self, name): + """ + Parse attribute to GRASS display module. Attribute should be in + the form 'd_module_name'. For example, 'd.rast' is called with 'd_rast'. + """ + # Check to make sure format is correct + if not name.startswith("d_"): + raise AttributeError(_("Module must begin with 'd_'")) + # Reformat string + grass_module = name.replace("_", ".") + # Assert module exists + if not shutil.which(grass_module): + raise AttributeError(_("Cannot find GRASS module {}").format(grass_module)) + # if this function is called, the images need to be rendered again + self._layers_rendered = False + + def wrapper(**kwargs): + if not self._series_added: + self._base_layer_calls.append((grass_module, kwargs)) + else: + for row in self._calls: + row.append((grass_module, kwargs)) + + return wrapper + + def add_names(self, names): + """Add list of names associated with layers. + Default will be names of first series added.""" + assert self._series_length == len(names), _( + "Number of vectors in series must match number of vectors" + ) + self._names = names + + def _render_baselayers(self, img): + """Add collected baselayers to Map instance""" + for grass_module, kwargs in self._base_layer_calls: + img.run(grass_module, **kwargs) + + def render(self): + """Renders image for each raster in series. + + Save PNGs to temporary directory. Must be run before creating a visualization + (i.e. show or save). + """ + + if not self._series_added: + raise RuntimeError( + "Cannot render series since none has been added." + "Use SeriesMap.add_rasters() or SeriesMap.add_vectors()" + ) + + # Make base image (background and baselayers) + # Random name needed to avoid potential conflict with layer names + random_name_base = gs.append_random("base", 8) + ".png" + base_file = os.path.join(self._tmpdir.name, random_name_base) + img = Map( + width=self._width, + height=self._height, + filename=base_file, + use_region=True, + env=self._env, + read_file=True, + ) + # We have to call d_erase to ensure the file is created. If there are no + # base layers, then there is nothing to render in random_base_name + img.d_erase() + # Add baselayers + self._render_baselayers(img) + + # Render each layer + for i in range(self._series_length): + # Create file + filename = os.path.join(self._tmpdir.name, f"{i}.png") + # Copying the base_file ensures that previous results are overwritten + shutil.copyfile(base_file, filename) + self._layer_filename_dict[i] = filename + # Render image + img = Map( + width=self._width, + height=self._height, + filename=filename, + use_region=True, + env=self._env, + read_file=True, + ) + for grass_module, kwargs in self._calls[i]: + img.run(grass_module, **kwargs) + + self._layers_rendered = True + + def show(self, slider_width=None): + """Create interactive timeline slider. + + param str slider_width: width of datetime selection slider + + The slider_width parameter sets the width of the slider in the output cell. + It should be formatted as a percentage (%) between 0 and 100 of the cell width + or in pixels (px). Values should be formatted as strings and include the "%" + or "px" suffix. For example, slider_width="80%" or slider_width="500px". + slider_width is passed to ipywidgets in ipywidgets.Layout(width=slider_width). + """ + # Lazy Imports + import ipywidgets as widgets # pylint: disable=import-outside-toplevel + + # Render images if they have not been already + if not self._layers_rendered: + self.render() + + # Set default slider width + if not slider_width: + slider_width = "70%" + + # Create lookup table for slider + lookup = list(zip(self._names, range(self._series_length))) + + # Datetime selection slider + slider = widgets.SelectionSlider( + options=lookup, + value=0, + disabled=False, + continuous_update=True, + orientation="horizontal", + readout=True, + layout=widgets.Layout(width=slider_width), + ) + play = widgets.Play( + interval=500, + value=0, + min=0, + max=self._series_length - 1, + step=1, + description="Press play", + disabled=False, + ) + out_img = widgets.Image(value=b"", format="png") + + def change_slider(change): + slider.value = slider.options[change.new][1] + + play.observe(change_slider, names="value") + + # Display image associated with datetime + def change_image(index): + # Look up layer name for date + filename = self._layer_filename_dict[index] + with open(filename, "rb") as rfile: + out_img.value = rfile.read() + + # Return interact widget with image and slider + widgets.interactive_output(change_image, {"index": slider}) + layout = widgets.Layout( + width="100%", display="inline-flex", flex_flow="row wrap" + ) + return widgets.HBox([play, slider, out_img], layout=layout) + + def save( + self, + filename, + duration=500, + label=True, + font=None, + text_size=12, + text_color="gray", + ): + """ + Creates a GIF animation of rendered layers. + + Text color must be in a format accepted by PIL ImageColor module. For supported + formats, visit: + https://pillow.readthedocs.io/en/stable/reference/ImageColor.html#color-names + + param str filename: name of output GIF file + param int duration: time to display each frame; milliseconds + param bool label: include label on each frame + param str font: font file + param int text_size: size of label text + param str text_color: color to use for the text + """ + + # Render images if they have not been already + if not self._layers_rendered: + self.render() + + tmp_files = [] + for _, file in self._layer_filename_dict.items(): + tmp_files.append(file) + + save_gif( + tmp_files, + filename, + duration=duration, + label=label, + labels=self._names, + font=font, + text_size=text_size, + text_color=text_color, + ) + + # Display the GIF + return filename diff --git a/python/grass/jupyter/tests/seriesmap_test.py b/python/grass/jupyter/tests/seriesmap_test.py new file mode 100644 index 00000000000..79f3b20dbda --- /dev/null +++ b/python/grass/jupyter/tests/seriesmap_test.py @@ -0,0 +1,52 @@ +"""Test SeriesMap functions""" + + +from pathlib import Path +import pytest + +try: + import IPython +except ImportError: + IPython = None + +try: + import ipywidgets +except ImportError: + ipywidgets = None + +import grass.jupyter as gj + + +def test_default_init(space_time_raster_dataset): + """Check that TimeSeriesMap init runs with default parameters""" + img = gj.SeriesMap() + img.add_rasters(space_time_raster_dataset.raster_names) + assert img._names == space_time_raster_dataset.raster_names + + +def test_render_layers(space_time_raster_dataset): + """Check that layers are rendered""" + # create instance of TimeSeriesMap + img = gj.SeriesMap() + # test adding base layer and d_legend here too for efficiency (rendering is + # time-intensive) + img.d_rast(map=space_time_raster_dataset.raster_names[0]) + img.add_rasters(space_time_raster_dataset.raster_names[1:]) + img.d_barscale() + # Render layers + img.render() + # check files exist + # We need to check values which are only in protected attributes + # pylint: disable=protected-access + for unused_layer, filename in img._layer_filename_dict.items(): + assert Path(filename).is_file() + + +@pytest.mark.skipif(IPython is None, reason="IPython package not available") +@pytest.mark.skipif(ipywidgets is None, reason="ipywidgets package not available") +def test_save(space_time_raster_dataset, tmp_path): + """Test returns from animate and time_slider are correct object types""" + img = gj.SeriesMap() + img.add_rasters(space_time_raster_dataset.raster_names) + gif_file = img.save(tmp_path / "image.gif") + assert Path(gif_file).is_file() diff --git a/python/grass/jupyter/timeseriesmap.py b/python/grass/jupyter/timeseriesmap.py index 8d00eba9427..0d9c365b387 100644 --- a/python/grass/jupyter/timeseriesmap.py +++ b/python/grass/jupyter/timeseriesmap.py @@ -16,12 +16,12 @@ import os import weakref import shutil -from pathlib import Path import grass.script as gs from .map import Map from .region import RegionManagerForTimeSeries +from .utils import save_gif def fill_none_values(names): @@ -131,6 +131,7 @@ class TimeSeriesMap: # pylint: disable=too-many-instance-attributes # Need more attributes to build timeseriesmap visuals + # pylint: disable=duplicate-code def __init__( self, @@ -475,42 +476,24 @@ def save( param int text_size: size of date/time text param str text_color: color to use for the text. """ - # Create a GIF from the PNG images - import PIL.Image # pylint: disable=import-outside-toplevel - import PIL.ImageDraw # pylint: disable=import-outside-toplevel - import PIL.ImageFont # pylint: disable=import-outside-toplevel # Render images if they have not been already if not self._layers_rendered: self.render() - # filepath to output GIF - filename = Path(filename) - if filename.suffix.lower() != ".gif": - raise ValueError(_("filename must end in '.gif'")) - - images = [] + input_files = [] for date in self._dates: - img_path = self._date_filename_dict[date] - img = PIL.Image.open(img_path) - img = img.convert("RGBA", dither=None) - draw = PIL.ImageDraw.Draw(img) - if label: - draw.text( - (0, 0), - date, - fill=text_color, - font=PIL.ImageFont.truetype(font, text_size), - ) - images.append(img) - - images[0].save( - fp=filename, - format="GIF", - append_images=images[1:], - save_all=True, + input_files.append(self._date_filename_dict[date]) + + save_gif( + input_files, + filename, duration=duration, - loop=0, + label=label, + labels=self._dates, + font=font, + text_size=text_size, + text_color=text_color, ) # Display the GIF diff --git a/python/grass/jupyter/utils.py b/python/grass/jupyter/utils.py index 92d59e6cd16..e0aa7eb11b2 100644 --- a/python/grass/jupyter/utils.py +++ b/python/grass/jupyter/utils.py @@ -10,7 +10,7 @@ # for details. """Utility functions warpping existing processes in a suitable way""" - +from pathlib import Path import grass.script as gs @@ -200,3 +200,69 @@ def get_rendering_size(region, width, height, default_width=600, default_height= if region_height > region_width: return (round(default_height * region_width / region_height), default_height) return (default_width, round(default_width * region_height / region_width)) + + +def save_gif( + input_files, + output_filename, + duration=500, + label=True, + labels=None, + font=None, + text_size=12, + text_color="gray", +): + """ + Creates a GIF animation + + param list input_files: list of paths to source + param str output_filename: destination gif filename + param int duration: time to display each frame; milliseconds + param bool label: include label stamp on each frame + param list labels: list of labels for each source image + param str font: font file + param int text_size: size of label text + param str text_color: color to use for the text + """ + # Create a GIF from the PNG images + import PIL.Image # pylint: disable=import-outside-toplevel + import PIL.ImageDraw # pylint: disable=import-outside-toplevel + import PIL.ImageFont # pylint: disable=import-outside-toplevel + + # filepath to output GIF + filename = Path(output_filename) + if filename.suffix.lower() != ".gif": + raise ValueError(_("filename must end in '.gif'")) + + images = [] + for i, file in enumerate(input_files): + img = PIL.Image.open(file) + img = img.convert("RGBA", dither=None) + draw = PIL.ImageDraw.Draw(img) + if label: + if font: + font_obj = PIL.ImageFont.truetype(font, text_size) + else: + try: + font_obj = PIL.ImageFont.load_default(size=text_size) + except TypeError: + font_obj = PIL.ImageFont.load_default() + draw.text( + (0, 0), + labels[i], + fill=text_color, + font=font_obj, + ) + images.append(img) + + images[0].save( + fp=filename, + format="GIF", + append_images=images[1:], + save_all=True, + duration=duration, + loop=0, + ) + + # Display the GIF + return filename