From cf1d6df0be1d5824f5cee04fb9f3cc059a173048 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 7 May 2025 16:52:59 +0800 Subject: [PATCH 1/4] Extend the data_kind function to validate the kinds --- pygmt/helpers/utils.py | 41 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index 387d61e03b3..b0209f5b814 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -41,6 +41,11 @@ "ISO-8859-16", ] +# Type hints for the list of data kinds. +Kind = Literal[ + "arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors" +] + def _validate_data_input( # noqa: PLR0912 data=None, x=None, y=None, z=None, required=True, mincols=2, kind=None @@ -272,11 +277,11 @@ def _check_encoding(argstr: str) -> Encoding: return "ISOLatin1+" -def data_kind( - data: Any, required: bool = True -) -> Literal[ - "arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors" -]: +def data_kind( # noqa: PLR0912 + data: Any, + required: bool = True, + check_kind: Kind | Sequence[Kind] | Literal["raster", "vector"] | None = None, +) -> Kind: r""" Check the kind of data that is provided to a module. @@ -307,6 +312,14 @@ def data_kind( required Whether 'data' is required. Set to ``False`` when dealing with optional virtual files. + check_kind + Used to validate the type of data that can be passed in. Valid values are: + + - Any recognized data kind + - A list/tuple of recognized data kinds + - ``"raster"``: shorthand for a sequence of raster-like data kinds + - ``"vector"``: shorthand for a sequence of vector-like data kinds + - ``None``: means no validatation. Returns ------- @@ -414,6 +427,24 @@ def data_kind( kind = "matrix" case _: # Fall back to "vectors" if data is None and required=True. kind = "vectors" + + # Now start to check if the data kind is valid. + if check_kind is not None: + valid_kinds = ("file", "arg") if required is False else ("file",) + match check_kind: + case "raster": + valid_kinds += ("grid", "image") + case "vector": + valid_kinds += ("empty", "matrix", "vectors", "geojson") + case str(): + valid_kinds = (check_kind,) + case list() | tuple(): + valid_kinds = check_kind + + if kind not in valid_kinds: + msg = f"Unrecognized data type: {type(data)}." + raise GMTInvalidInput(msg) + return kind # type: ignore[return-value] From 7cfba7359f968939f3adf01b13f28d2dd754a518 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 7 May 2025 16:57:25 +0800 Subject: [PATCH 2/4] Remove the check_kind code block from virtualfile_in --- pygmt/clib/session.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index eba799660f3..cb369e56064 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1755,7 +1755,7 @@ def virtualfile_from_stringio( @deprecate_parameter( "required_data", "required", "v0.16.0", remove_version="v0.20.0" ) - def virtualfile_in( # noqa: PLR0912 + def virtualfile_in( self, check_kind=None, data=None, @@ -1847,7 +1847,7 @@ def virtualfile_in( # noqa: PLR0912 ) mincols = 3 - kind = data_kind(data, required=required) + kind = data_kind(data, required=required, check_kind=check_kind) _validate_data_input( data=data, x=x, @@ -1858,16 +1858,6 @@ def virtualfile_in( # noqa: PLR0912 kind=kind, ) - if check_kind: - valid_kinds = ("file", "arg") if required is False else ("file",) - if check_kind == "raster": - valid_kinds += ("grid", "image") - elif check_kind == "vector": - valid_kinds += ("empty", "matrix", "vectors", "geojson") - if kind not in valid_kinds: - msg = f"Unrecognized data type for {check_kind}: {type(data)}." - raise GMTInvalidInput(msg) - # Decide which virtualfile_from_ function to use _virtualfile_from = { "arg": contextlib.nullcontext, From ab18175cc02054d97256d9833fd08aa1d9560c36 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Thu, 29 May 2025 14:26:02 +0800 Subject: [PATCH 3/4] Refactor to check kind in data_kind function --- pygmt/src/grdcut.py | 4 ++-- pygmt/src/legend.py | 5 +---- pygmt/src/meca.py | 4 ++-- pygmt/src/plot.py | 4 ++-- pygmt/src/plot3d.py | 4 ++-- pygmt/src/text.py | 4 ++-- pygmt/src/x2sys_cross.py | 2 +- 7 files changed, 12 insertions(+), 15 deletions(-) diff --git a/pygmt/src/grdcut.py b/pygmt/src/grdcut.py index 2d5b1f0e5c9..bd8571412df 100644 --- a/pygmt/src/grdcut.py +++ b/pygmt/src/grdcut.py @@ -117,7 +117,7 @@ def grdcut( raise GMTInvalidInput(msg) # Determine the output data kind based on the input data kind. - match inkind := data_kind(grid): + match inkind := data_kind(grid, check_kind="raster"): case "grid" | "image": outkind = inkind case "file": @@ -128,7 +128,7 @@ def grdcut( with Session() as lib: with ( - lib.virtualfile_in(check_kind="raster", data=grid) as vingrd, + lib.virtualfile_in(data=grid) as vingrd, lib.virtualfile_out(kind=outkind, fname=outgrid) as voutgrd, ): kwargs["G"] = voutgrd diff --git a/pygmt/src/legend.py b/pygmt/src/legend.py index 2cb2eddcf95..21a48712c90 100644 --- a/pygmt/src/legend.py +++ b/pygmt/src/legend.py @@ -89,10 +89,7 @@ def legend( if kwargs.get("F") is None: kwargs["F"] = box - kind = data_kind(spec) - if kind not in {"empty", "file", "stringio"}: - msg = f"Unrecognized data type: {type(spec)}" - raise GMTInvalidInput(msg) + kind = data_kind(spec, check_kind=("empty", "file", "stringio")) if kind == "file" and is_nonstr_iter(spec): msg = "Only one legend specification file is allowed." raise GMTInvalidInput(msg) diff --git a/pygmt/src/meca.py b/pygmt/src/meca.py index 4a576ce0e51..f5f515bb686 100644 --- a/pygmt/src/meca.py +++ b/pygmt/src/meca.py @@ -49,7 +49,7 @@ def _preprocess_spec(spec, colnames, override_cols): Dictionary of column names and values to override in the input data. Only makes sense if ``spec`` is a dict or :class:`pandas.DataFrame`. """ - kind = data_kind(spec) # Determine the kind of the input data. + kind = data_kind(spec, check_kind="vector") # Determine the kind of the input data. # Convert pandas.DataFrame and numpy.ndarray to dict. if isinstance(spec, pd.DataFrame): @@ -359,5 +359,5 @@ def meca( # noqa: PLR0913 kwargs["A"] = _auto_offset(spec) kwargs["S"] = f"{_convention.code}{scale}" with Session() as lib: - with lib.virtualfile_in(check_kind="vector", data=spec) as vintbl: + with lib.virtualfile_in(data=spec) as vintbl: lib.call_module(module="meca", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/plot.py b/pygmt/src/plot.py index 35dab4aa009..6f62ae16223 100644 --- a/pygmt/src/plot.py +++ b/pygmt/src/plot.py @@ -232,7 +232,7 @@ def plot( # noqa: PLR0912 # parameter. self._activate_figure() - kind = data_kind(data) + kind = data_kind(data, check_kind="vector") if kind == "empty": # Data is given via a series of vectors. data = {"x": x, "y": y} # Parameters for vector styles @@ -280,5 +280,5 @@ def plot( # noqa: PLR0912 kwargs["S"] = "s0.2c" with Session() as lib: - with lib.virtualfile_in(check_kind="vector", data=data) as vintbl: + with lib.virtualfile_in(data=data) as vintbl: lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/plot3d.py b/pygmt/src/plot3d.py index 491ebbcf9df..9cb1fdb2d2c 100644 --- a/pygmt/src/plot3d.py +++ b/pygmt/src/plot3d.py @@ -210,7 +210,7 @@ def plot3d( # noqa: PLR0912 # parameter. self._activate_figure() - kind = data_kind(data) + kind = data_kind(data, check_kind="vector") if kind == "empty": # Data is given via a series of vectors. data = {"x": x, "y": y, "z": z} # Parameters for vector styles @@ -259,5 +259,5 @@ def plot3d( # noqa: PLR0912 kwargs["S"] = "u0.2c" with Session() as lib: - with lib.virtualfile_in(check_kind="vector", data=data, mincols=3) as vintbl: + with lib.virtualfile_in(data=data, mincols=3) as vintbl: lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/text.py b/pygmt/src/text.py index 34c70734d52..c1a4940ac12 100644 --- a/pygmt/src/text.py +++ b/pygmt/src/text.py @@ -191,7 +191,7 @@ def text_( # noqa: PLR0912 raise GMTInvalidInput(msg) data_is_required = position is None - kind = data_kind(textfiles, required=data_is_required) + kind = data_kind(textfiles, required=data_is_required, check_kind="vector") if position is not None and (text is None or is_nonstr_iter(text)): msg = "'text' can't be None or array when 'position' is given." @@ -261,7 +261,7 @@ def text_( # noqa: PLR0912 with Session() as lib: with lib.virtualfile_in( - check_kind="vector", data=textfiles or data, required=data_is_required + data=textfiles or data, required=data_is_required ) as vintbl: lib.call_module( module="text", diff --git a/pygmt/src/x2sys_cross.py b/pygmt/src/x2sys_cross.py index d502d72c190..ff21afd0ac0 100644 --- a/pygmt/src/x2sys_cross.py +++ b/pygmt/src/x2sys_cross.py @@ -195,7 +195,7 @@ def x2sys_cross( file_contexts: list[contextlib.AbstractContextManager[Any]] = [] for track in tracks: - match data_kind(track): + match data_kind(track, check_kind="vector"): case "file": file_contexts.append(contextlib.nullcontext(track)) case "vectors": From ff860513ed9ebeda5fd8fe4cfc470a06a501fd5c Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Thu, 29 May 2025 14:39:55 +0800 Subject: [PATCH 4/4] Session.virtualf_in: Allow passing the data kind if it's already known --- pygmt/clib/session.py | 5 ++++- pygmt/src/grdcut.py | 2 +- pygmt/src/legend.py | 2 +- pygmt/src/plot.py | 3 ++- pygmt/src/plot3d.py | 3 ++- pygmt/src/text.py | 7 +++++-- 6 files changed, 15 insertions(+), 7 deletions(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index cb369e56064..12f7df5eb45 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1758,6 +1758,7 @@ def virtualfile_from_stringio( def virtualfile_in( self, check_kind=None, + kind=None, data=None, x=None, y=None, @@ -1847,7 +1848,9 @@ def virtualfile_in( ) mincols = 3 - kind = data_kind(data, required=required, check_kind=check_kind) + # Determine the data kind if not given. + if kind is None: + kind = data_kind(data, required=required, check_kind=check_kind) _validate_data_input( data=data, x=x, diff --git a/pygmt/src/grdcut.py b/pygmt/src/grdcut.py index bd8571412df..b200a5c118b 100644 --- a/pygmt/src/grdcut.py +++ b/pygmt/src/grdcut.py @@ -128,7 +128,7 @@ def grdcut( with Session() as lib: with ( - lib.virtualfile_in(data=grid) as vingrd, + lib.virtualfile_in(data=grid, kind=inkind) as vingrd, lib.virtualfile_out(kind=outkind, fname=outgrid) as voutgrd, ): kwargs["G"] = voutgrd diff --git a/pygmt/src/legend.py b/pygmt/src/legend.py index 21a48712c90..df2636eb108 100644 --- a/pygmt/src/legend.py +++ b/pygmt/src/legend.py @@ -95,5 +95,5 @@ def legend( raise GMTInvalidInput(msg) with Session() as lib: - with lib.virtualfile_in(data=spec, required=False) as vintbl: + with lib.virtualfile_in(data=spec, required=False, kind=kind) as vintbl: lib.call_module(module="legend", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/plot.py b/pygmt/src/plot.py index 6f62ae16223..5fcaffa6aea 100644 --- a/pygmt/src/plot.py +++ b/pygmt/src/plot.py @@ -234,6 +234,7 @@ def plot( # noqa: PLR0912 kind = data_kind(data, check_kind="vector") if kind == "empty": # Data is given via a series of vectors. + kind = "vectors" data = {"x": x, "y": y} # Parameters for vector styles if ( @@ -280,5 +281,5 @@ def plot( # noqa: PLR0912 kwargs["S"] = "s0.2c" with Session() as lib: - with lib.virtualfile_in(data=data) as vintbl: + with lib.virtualfile_in(data=data, kind=kind) as vintbl: lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/plot3d.py b/pygmt/src/plot3d.py index 9cb1fdb2d2c..17028cfba3c 100644 --- a/pygmt/src/plot3d.py +++ b/pygmt/src/plot3d.py @@ -212,6 +212,7 @@ def plot3d( # noqa: PLR0912 kind = data_kind(data, check_kind="vector") if kind == "empty": # Data is given via a series of vectors. + kind = "vectors" data = {"x": x, "y": y, "z": z} # Parameters for vector styles if ( @@ -259,5 +260,5 @@ def plot3d( # noqa: PLR0912 kwargs["S"] = "u0.2c" with Session() as lib: - with lib.virtualfile_in(data=data, mincols=3) as vintbl: + with lib.virtualfile_in(data=data, mincols=3, kind=kind) as vintbl: lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/text.py b/pygmt/src/text.py index c1a4940ac12..e1ff229d4a2 100644 --- a/pygmt/src/text.py +++ b/pygmt/src/text.py @@ -42,7 +42,7 @@ w="wrap", ) @kwargs_to_strings(R="sequence", c="sequence_comma", p="sequence") -def text_( # noqa: PLR0912 +def text_( # noqa: PLR0912, PLR0915 self, textfiles: PathLike | TableLike | None = None, x=None, @@ -225,6 +225,7 @@ def text_( # noqa: PLR0912 confdict = {} data = None if kind == "empty": + kind = "vectors" data = {"x": x, "y": y} for arg, flag, name in array_args: @@ -261,7 +262,9 @@ def text_( # noqa: PLR0912 with Session() as lib: with lib.virtualfile_in( - data=textfiles or data, required=data_is_required + data=textfiles or data, + required=data_is_required, + kind=kind, ) as vintbl: lib.call_module( module="text",