diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 069a762001d..918d26b09b6 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -7,6 +7,7 @@ import ctypes as ctp import pathlib import sys +import warnings from contextlib import contextmanager, nullcontext import numpy as np @@ -26,7 +27,12 @@ GMTInvalidInput, GMTVersionError, ) -from pygmt.helpers import data_kind, fmt_docstring, tempfile_from_geojson +from pygmt.helpers import ( + data_kind, + fmt_docstring, + tempfile_from_geojson, + tempfile_from_image, +) FAMILIES = [ "GMT_IS_DATASET", # Entity is a data table @@ -1540,7 +1546,7 @@ def virtualfile_from_data( if check_kind: valid_kinds = ("file", "arg") if required_data is False else ("file",) if check_kind == "raster": - valid_kinds += ("grid",) + valid_kinds += ("grid", "image") elif check_kind == "vector": valid_kinds += ("matrix", "vectors", "geojson") if kind not in valid_kinds: @@ -1554,6 +1560,7 @@ def virtualfile_from_data( "arg": nullcontext, "geojson": tempfile_from_geojson, "grid": self.virtualfile_from_grid, + "image": tempfile_from_image, # Note: virtualfile_from_matrix is not used because a matrix can be # converted to vectors instead, and using vectors allows for better # handling of string type inputs (e.g. for datetime data types) @@ -1562,7 +1569,16 @@ def virtualfile_from_data( }[kind] # Ensure the data is an iterable (Python list or tuple) - if kind in ("geojson", "grid", "file", "arg"): + if kind in ("geojson", "grid", "image", "file", "arg"): + if kind == "image" and data.dtype != "uint8": + msg = ( + f"Input image has dtype: {data.dtype} which is unsupported, " + "and may result in an incorrect output. Please recast image " + "to a uint8 dtype and/or scale to 0-255 range, e.g. " + "using a histogram equalization function like " + "skimage.exposure.equalize_hist." + ) + warnings.warn(message=msg, category=RuntimeWarning, stacklevel=2) _data = (data,) if not isinstance(data, pathlib.PurePath) else (str(data),) elif kind == "vectors": _data = [np.atleast_1d(x), np.atleast_1d(y)] diff --git a/pygmt/helpers/__init__.py b/pygmt/helpers/__init__.py index efea2845cc7..eabcb87500f 100644 --- a/pygmt/helpers/__init__.py +++ b/pygmt/helpers/__init__.py @@ -7,7 +7,12 @@ kwargs_to_strings, use_alias, ) -from pygmt.helpers.tempfile import GMTTempFile, tempfile_from_geojson, unique_name +from pygmt.helpers.tempfile import ( + GMTTempFile, + tempfile_from_geojson, + tempfile_from_image, + unique_name, +) from pygmt.helpers.utils import ( args_in_kwargs, build_arg_string, diff --git a/pygmt/helpers/tempfile.py b/pygmt/helpers/tempfile.py index c184c5e73d3..8ac63006565 100644 --- a/pygmt/helpers/tempfile.py +++ b/pygmt/helpers/tempfile.py @@ -147,3 +147,34 @@ def tempfile_from_geojson(geojson): geoseries.to_file(**ogrgmt_kwargs) yield tmpfile.name + + +@contextmanager +def tempfile_from_image(image): + """ + Saves a 3-band :class:`xarray.DataArray` to a temporary GeoTIFF file via + rioxarray. + + Parameters + ---------- + image : xarray.DataArray + An xarray.DataArray with three dimensions, having a shape like + (3, Y, X). + + Yields + ------ + tmpfilename : str + A temporary GeoTIFF file holding the image data. E.g. '1a2b3c4d5.tif'. + """ + with GMTTempFile(suffix=".tif") as tmpfile: + os.remove(tmpfile.name) # ensure file is deleted first + try: + image.rio.to_raster(raster_path=tmpfile.name) + except AttributeError as e: # object has no attribute 'rio' + raise ImportError( + "Package `rioxarray` is required to be installed to use this function. " + "Please use `python -m pip install rioxarray` or " + "`mamba install -c conda-forge rioxarray` " + "to install the package." + ) from e + yield tmpfile.name diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index e52197012be..68686ac9759 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -141,8 +141,8 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data Returns ------- kind : str - One of ``'arg'``, ``'file'``, ``'grid'``, ``'geojson'``, ``'matrix'``, - or ``'vectors'``. + One of ``'arg'``, ``'file'``, ``'grid'``, ``image``, ``'geojson'``, + ``'matrix'``, or ``'vectors'``. Examples -------- @@ -166,6 +166,8 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data 'arg' >>> data_kind(data=xr.DataArray(np.random.rand(4, 3))) 'grid' + >>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5))) + 'image' """ # determine the data kind if isinstance(data, (str, pathlib.PurePath)): @@ -173,7 +175,7 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data elif isinstance(data, (bool, int, float)) or (data is None and not required_data): kind = "arg" elif isinstance(data, xr.DataArray): - kind = "grid" + kind = "image" if len(data.dims) == 3 else "grid" elif hasattr(data, "__geo_interface__"): # geo-like Python object that implements ``__geo_interface__`` # (geopandas.GeoDataFrame or shapely.geometry) diff --git a/pygmt/src/grdimage.py b/pygmt/src/grdimage.py index d036453985e..6f200237927 100644 --- a/pygmt/src/grdimage.py +++ b/pygmt/src/grdimage.py @@ -41,11 +41,11 @@ def grdimage(self, grid, **kwargs): instructions to derive intensities from the input data grid. Values outside this range will be clipped. Such intensity files can be created from the grid using :func:`pygmt.grdgradient` and, optionally, modified by - :gmt-docs:`grdmath.html` or :class:`pygmt.grdhisteq`. If GMT is built - with GDAL support, ``grid`` can be an image file (geo-referenced or not). - In this case the image can optionally be illuminated with the file - provided via the ``shading`` parameter. Here, if image has no coordinates - then those of the intensity file will be used. + :gmt-docs:`grdmath.html` or :class:`pygmt.grdhisteq`. Alternatively, pass + *image* which can be an image file (geo-referenced or not). In this case + the image can optionally be illuminated with the file provided via the + ``shading`` parameter. Here, if image has no coordinates then those of the + intensity file will be used. When using map projections, the grid is first resampled on a new rectangular grid with the same dimensions. Higher resolution images can @@ -74,10 +74,7 @@ def grdimage(self, grid, **kwargs): :gmt-docs:`grdimage.html#grid-file-formats`). img_out : str *out_img*\[=\ *driver*]. - Save an image in a raster format instead of PostScript. Use - extension .ppm for a Portable Pixel Map format which is the only - raster format GMT can natively write. For GMT installations - configured with GDAL support there are more choices: Append + Save an image in a raster format instead of PostScript. Append *out_img* to select the image file name and extension. If the extension is one of .bmp, .gif, .jpg, .png, or .tif then no driver information is required. For other output formats you must append @@ -131,8 +128,8 @@ def grdimage(self, grid, **kwargs): :func:`pygmt.grdgradient` separately first. If we should derive intensities from another file than grid, specify the file with suitable modifiers [Default is no illumination]. **Note**: If the - input data is an *image* then an *intensfile* or constant *intensity* - must be provided. + input data represent an *image* then an *intensfile* or constant + *intensity* must be provided. {projection} monochrome : bool Force conversion to monochrome image using the (television) YIQ @@ -144,10 +141,9 @@ def grdimage(self, grid, **kwargs): [**+z**\ *value*][*color*] Make grid nodes with z = NaN transparent, using the color-masking feature in PostScript Level 3 (the PS device must support PS Level - 3). If the input is a grid, use **+z** with a *value* to select - another grid value than NaN. If the input is instead an image, - append an alternate *color* to select another pixel value to be - transparent [Default is ``"black"``]. + 3). If the input is a grid, use **+z** to select another grid value + than NaN. If input is instead an image, append an alternate *color* to + select another pixel value to be transparent [Default is ``"black"``]. {region} {verbose} {panel} @@ -171,6 +167,7 @@ def grdimage(self, grid, **kwargs): >>> fig.show() """ kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access + with Session() as lib: with lib.virtualfile_from_data( check_kind="raster", data=grid diff --git a/pygmt/src/tilemap.py b/pygmt/src/tilemap.py index ae614ae11fd..b0fc0164076 100644 --- a/pygmt/src/tilemap.py +++ b/pygmt/src/tilemap.py @@ -3,13 +3,7 @@ """ from pygmt.clib import Session from pygmt.datasets.tile_map import load_tile_map -from pygmt.helpers import ( - GMTTempFile, - build_arg_string, - fmt_docstring, - kwargs_to_strings, - use_alias, -) +from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias try: import rioxarray @@ -148,9 +142,9 @@ def tilemap( if kwargs.get("N") in [None, False]: kwargs["R"] = "/".join(str(coordinate) for coordinate in region) - with GMTTempFile(suffix=".tif") as tmpfile: - raster.rio.to_raster(raster_path=tmpfile.name) - with Session() as lib: + with Session() as lib: + file_context = lib.virtualfile_from_data(check_kind="raster", data=raster) + with file_context as infile: lib.call_module( - module="grdimage", args=build_arg_string(kwargs, infile=tmpfile.name) + module="grdimage", args=build_arg_string(kwargs, infile=infile) ) diff --git a/pygmt/tests/baseline/test_grdimage_image.png.dvc b/pygmt/tests/baseline/test_grdimage_image.png.dvc new file mode 100644 index 00000000000..4af74249741 --- /dev/null +++ b/pygmt/tests/baseline/test_grdimage_image.png.dvc @@ -0,0 +1,4 @@ +outs: +- md5: 2e919645d5af956ec4f8aa054a86a70a + size: 110214 + path: test_grdimage_image.png diff --git a/pygmt/tests/test_grdimage_image.py b/pygmt/tests/test_grdimage_image.py new file mode 100644 index 00000000000..ade642605df --- /dev/null +++ b/pygmt/tests/test_grdimage_image.py @@ -0,0 +1,79 @@ +""" +Test Figure.grdimage on 3-band RGB images. +""" +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from pygmt import Figure, which + +rasterio = pytest.importorskip("rasterio") +rioxarray = pytest.importorskip("rioxarray") + + +@pytest.fixture(scope="module", name="xr_image") +def fixture_xr_image(): + """ + Load the image data from Blue Marble as an xarray.DataArray with shape + {"band": 3, "y": 180, "x": 360}. + """ + geotiff = which(fname="@earth_day_01d_p", download="c") + with rioxarray.open_rasterio(filename=geotiff) as rda: + if len(rda.band) == 1: + with rasterio.open(fp=geotiff) as src: + df_colormap = pd.DataFrame.from_dict( + data=src.colormap(1), orient="index" + ) + array = src.read() + + red = np.vectorize(df_colormap[0].get)(array) + green = np.vectorize(df_colormap[1].get)(array) + blue = np.vectorize(df_colormap[2].get)(array) + # alpha = np.vectorize(df_colormap[3].get)(array) + + rda.data = red + da_red = rda.astype(dtype=np.uint8).copy() + rda.data = green + da_green = rda.astype(dtype=np.uint8).copy() + rda.data = blue + da_blue = rda.astype(dtype=np.uint8).copy() + + xr_image = xr.concat(objs=[da_red, da_green, da_blue], dim="band") + assert xr_image.sizes == {"band": 3, "y": 180, "x": 360} + return xr_image + + +@pytest.mark.mpl_image_compare +def test_grdimage_image(): + """ + Plot a 3-band RGB image using file input. + """ + fig = Figure() + fig.grdimage(grid="@earth_day_01d") + return fig + + +@pytest.mark.mpl_image_compare(filename="test_grdimage_image.png") +def test_grdimage_image_dataarray(xr_image): + """ + Plot a 3-band RGB image using xarray.DataArray input. + """ + fig = Figure() + fig.grdimage(grid=xr_image) + return fig + + +@pytest.mark.parametrize( + "dtype", + ["int8", "uint16", "int16", "uint32", "int32", "float32", "float64"], +) +def test_grdimage_image_dataarray_unsupported_dtype(dtype, xr_image): + """ + Plot a 3-band RGB image using xarray.DataArray input, with an unsupported + data type. + """ + fig = Figure() + image = xr_image.astype(dtype=dtype) + with pytest.warns(expected_warning=RuntimeWarning) as record: + fig.grdimage(grid=image) + assert len(record) == 1