From 9ac45fc6d69e5c143001b6ebad3ee77a86698ee7 Mon Sep 17 00:00:00 2001 From: Todd Jennings Date: Fri, 27 Nov 2020 12:07:17 -0500 Subject: [PATCH 01/14] add type hints for the str accessor class --- xarray/core/accessor_str.py | 172 ++++++++++++++++++++++++++++++------ 1 file changed, 143 insertions(+), 29 deletions(-) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 02d8ca00bf9..bfb1e7a0efe 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -40,6 +40,7 @@ import codecs import re import textwrap +from typing import Any, Callable, Mapping, Union import numpy as np @@ -104,7 +105,11 @@ def __getitem__(self, key): else: return self.get(key) - def get(self, i, default=""): + def get( + self, + i: int, + default: str = "", + ) -> Any: """ Extract character number `i` from each string in the array. @@ -129,7 +134,12 @@ def f(x): return self._apply(f) - def slice(self, start=None, stop=None, step=None): + def slice( + self, + start: int = None, + stop: int = None, + step: int = None, + ) -> Any: """ Slice substrings from each string in the array. @@ -150,7 +160,12 @@ def slice(self, start=None, stop=None, step=None): f = lambda x: x[s] return self._apply(f) - def slice_replace(self, start=None, stop=None, repl=""): + def slice_replace( + self, + start: int = None, + stop: int = None, + repl: str = "", + ) -> Any: """ Replace a positional slice of a string with another value. @@ -338,7 +353,11 @@ def isupper(self): """ return self._apply(lambda x: x.isupper(), dtype=bool) - def count(self, pat, flags=0): + def count( + self, + pat: str, + flags: int = 0, + ) -> Any: """ Count occurrences of pattern in each string of the array. @@ -363,7 +382,10 @@ def count(self, pat, flags=0): f = lambda x: len(regex.findall(x)) return self._apply(f, dtype=int) - def startswith(self, pat): + def startswith( + self, + pat: str, + ) -> Any: """ Test if the start of each string in the array matches a pattern. @@ -382,7 +404,10 @@ def startswith(self, pat): f = lambda x: x.startswith(pat) return self._apply(f, dtype=bool) - def endswith(self, pat): + def endswith( + self, + pat: str, + ) -> Any: """ Test if the end of each string in the array matches a pattern. @@ -401,7 +426,12 @@ def endswith(self, pat): f = lambda x: x.endswith(pat) return self._apply(f, dtype=bool) - def pad(self, width, side="left", fillchar=" "): + def pad( + self, + width: int, + side: str = "left", + fillchar: str = " ", + ) -> Any: """ Pad strings in the array up to width. @@ -436,7 +466,11 @@ def pad(self, width, side="left", fillchar=" "): return self._apply(f) - def center(self, width, fillchar=" "): + def center( + self, + width: int, + fillchar: str = " ", + ) -> Any: """ Pad left and right side of each string in the array. @@ -454,7 +488,11 @@ def center(self, width, fillchar=" "): """ return self.pad(width, side="both", fillchar=fillchar) - def ljust(self, width, fillchar=" "): + def ljust( + self, + width: int, + fillchar: str = " ", + ) -> Any: """ Pad right side of each string in the array. @@ -472,7 +510,11 @@ def ljust(self, width, fillchar=" "): """ return self.pad(width, side="right", fillchar=fillchar) - def rjust(self, width, fillchar=" "): + def rjust( + self, + width: int, + fillchar: str = " ", + ) -> Any: """ Pad left side of each string in the array. 
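For illustration, a minimal usage sketch of the accessor methods annotated in the hunks above (the two-element array is hypothetical, not taken from the patch):

import xarray as xr

da = xr.DataArray(["spam", "eggs"])
da.str.len()            # length of each string
da.str.startswith("s")  # boolean array: does each string start with "s"?
da.str.slice(stop=2)    # first two characters of each element
da.str.get(0)           # character number 0 of each element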
@@ -490,7 +532,10 @@ def rjust(self, width, fillchar=" "): """ return self.pad(width, side="left", fillchar=fillchar) - def zfill(self, width): + def zfill( + self, + width: int, + ) -> Any: """ Pad each string in the array by prepending '0' characters. @@ -510,7 +555,13 @@ def zfill(self, width): """ return self.pad(width, side="left", fillchar="0") - def contains(self, pat, case=True, flags=0, regex=True): + def contains( + self, + pat: str, + case: bool = True, + flags: int = 0, + regex: bool = True, + ) -> Any: """ Test if pattern or regex is contained within each string of the array. @@ -542,12 +593,12 @@ def contains(self, pat, case=True, flags=0, regex=True): if not case: flags |= re.IGNORECASE - regex = re.compile(pat, flags=flags) + regex_obj = re.compile(pat, flags=flags) - if regex.groups > 0: # pragma: no cover + if regex_obj.groups > 0: # pragma: no cover raise ValueError("This pattern has match groups.") - f = lambda x: bool(regex.search(x)) + f = lambda x: bool(regex_obj.search(x)) else: if case: f = lambda x: pat in x @@ -557,7 +608,12 @@ def contains(self, pat, case=True, flags=0, regex=True): return self._apply(f, dtype=bool) - def match(self, pat, case=True, flags=0): + def match( + self, + pat: str, + case: bool = True, + flags: int = 0, + ) -> Any: """ Determine if each string in the array matches a regular expression. @@ -582,7 +638,11 @@ def match(self, pat, case=True, flags=0): f = lambda x: bool(regex.match(x)) return self._apply(f, dtype=bool) - def strip(self, to_strip=None, side="both"): + def strip( + self, + to_strip: str = None, + side: str = "both", + ) -> Any: """ Remove leading and trailing characters. @@ -616,7 +676,10 @@ def strip(self, to_strip=None, side="both"): return self._apply(f) - def lstrip(self, to_strip=None): + def lstrip( + self, + to_strip: str = None, + ) -> Any: """ Remove leading characters. @@ -636,7 +699,10 @@ def lstrip(self, to_strip=None): """ return self.strip(to_strip, side="left") - def rstrip(self, to_strip=None): + def rstrip( + self, + to_strip: str = None, + ) -> Any: """ Remove trailing characters. @@ -656,7 +722,11 @@ def rstrip(self, to_strip=None): """ return self.strip(to_strip, side="right") - def wrap(self, width, **kwargs): + def wrap( + self, + width: int, + **kwargs, + ) -> Any: """ Wrap long strings in the array in paragraphs with length less than `width`. @@ -678,7 +748,10 @@ def wrap(self, width, **kwargs): f = lambda x: "\n".join(tw.wrap(x)) return self._apply(f) - def translate(self, table): + def translate( + self, + table: Mapping[str, str], + ) -> Any: """ Map characters of each string through the given mapping table. @@ -697,7 +770,10 @@ def translate(self, table): f = lambda x: x.translate(table) return self._apply(f) - def repeat(self, repeats): + def repeat( + self, + repeats: int, + ) -> Any: """ Duplicate each string in the array. @@ -714,7 +790,13 @@ def repeat(self, repeats): f = lambda x: repeats * x return self._apply(f) - def find(self, sub, start=0, end=None, side="left"): + def find( + self, + sub: str, + start: int = 0, + end: int = None, + side: str = "left", + ) -> Any: """ Return lowest or highest indexes in each strings in the array where the substring is fully contained between [start:end]. 
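A quick sketch of the pattern-matching and stripping methods annotated above; the sample array is hypothetical:

import xarray as xr

da = xr.DataArray(["  Foo  ", "bar"])
da.str.strip()                      # remove surrounding whitespace
da.str.contains("foo", case=False)  # case-insensitive regex search
da.str.match("ba")                  # regex match anchored at the start of each string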
@@ -751,7 +833,12 @@ def find(self, sub, start=0, end=None, side="left"): return self._apply(f, dtype=int) - def rfind(self, sub, start=0, end=None): + def rfind( + self, + sub: str, + start: int = 0, + end: int = None, + ) -> Any: """ Return highest indexes in each strings in the array where the substring is fully contained between [start:end]. @@ -772,7 +859,13 @@ def rfind(self, sub, start=0, end=None): """ return self.find(sub, start=start, end=end, side="right") - def index(self, sub, start=0, end=None, side="left"): + def index( + self, + sub: str, + start: int = 0, + end: int = None, + side: str = "left", + ) -> Any: """ Return lowest or highest indexes in each strings where the substring is fully contained between [start:end]. This is the same as @@ -810,7 +903,12 @@ def index(self, sub, start=0, end=None, side="left"): return self._apply(f, dtype=int) - def rindex(self, sub, start=0, end=None): + def rindex( + self, + sub: str, + start: int = 0, + end: int = None, + ) -> Any: """ Return highest indexes in each strings where the substring is fully contained between [start:end]. This is the same as @@ -832,7 +930,15 @@ def rindex(self, sub, start=0, end=None): """ return self.index(sub, start=start, end=end, side="right") - def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): + def replace( + self, + pat: Union[str, Any], + repl: Union[str, Callable], + n: int = -1, + case: bool = None, + flags: int = 0, + regex: bool = True, + ) -> Any: """ Replace occurrences of pattern/regex in the array with some string. @@ -907,7 +1013,11 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): f = lambda x: x.replace(pat, repl, n) return self._apply(f) - def decode(self, encoding, errors="strict"): + def decode( + self, + encoding: str, + errors: str = "strict", + ) -> Any: """ Decode character string in the array using indicated encoding. @@ -927,7 +1037,11 @@ def decode(self, encoding, errors="strict"): f = lambda x: decoder(x, errors)[0] return self._apply(f, dtype=np.str_) - def encode(self, encoding, errors="strict"): + def encode( + self, + encoding: str, + errors: str = "strict", + ) -> Any: """ Encode character string in the array using indicated encoding. 
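To round off this first patch, a minimal usage sketch of the replace/find/encode methods typed above, with a sample string similar to the ones used in the test suite:

import xarray as xr

da = xr.DataArray(["fooBAD__barBAD"])
da.str.replace("BAD[_]*", "")  # regex replacement of every match -> "foobar"
da.str.rfind("BAD")            # index of the last occurrence, -1 if not found
da.str.encode("utf-8")         # per-element encoding to bytes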
From 7b37c6d0afc93f8d34aa07655740801ea76ddd3f Mon Sep 17 00:00:00 2001 From: Todd Jennings Date: Fri, 27 Nov 2020 12:07:26 -0500 Subject: [PATCH 02/14] allow str accessors to use regular expression objects for regular expressions --- xarray/core/accessor_str.py | 143 ++++++++++++++++++------------ xarray/tests/test_accessor_str.py | 6 +- 2 files changed, 88 insertions(+), 61 deletions(-) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index bfb1e7a0efe..ea68548c306 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -77,6 +77,7 @@ class StringAccessor: """ __slots__ = ("_obj",) + _pattern_type = type(re.compile("")) def __init__(self, obj): self._obj = obj @@ -89,6 +90,30 @@ def _apply(self, f, dtype=None): g = np.vectorize(f, otypes=[dtype]) return apply_ufunc(g, self._obj, dask="parallelized", output_dtypes=[dtype]) + def _check_is_compiled_re(self, pat): + return isinstance(pat, self._pattern_type) + + def _re_compile(self, pat, flags, case=None): + is_compiled_re = self._check_is_compiled_re(pat) + + if is_compiled_re and flags != 0: + raise ValueError("flags cannot be set when pat is a compiled regex") + + if is_compiled_re and case is not None: + raise ValueError("case cannot be set when pat is a compiled regex") + + if is_compiled_re: + return pat + + if case is None: + case = True + + if not case: + flags |= re.IGNORECASE + + pat = self._obj.dtype.type(pat) + return re.compile(pat, flags=flags) + def len(self): """ Compute the length of each string in the array. @@ -355,7 +380,8 @@ def isupper(self): def count( self, - pat: str, + pat: Any, + case: bool = True, flags: int = 0, ) -> Any: """ @@ -367,8 +393,11 @@ def count( Parameters ---------- - pat : str - Valid regular expression. + pat : str or re.Pattern + A string contain a regular expression or + a compiled regular expression object. + case : bool, default: True + If True, case sensitive. flags : int, default: 0 Flags for the `re` module. Use 0 for no flags. For a complete list, `see here `_. @@ -377,9 +406,9 @@ def count( ------- counts : array of int """ - pat = self._obj.dtype.type(pat) - regex = re.compile(pat, flags=flags) - f = lambda x: len(regex.findall(x)) + pat = self._re_compile(pat, flags, case) + + f = lambda x: len(pat.findall(x)) return self._apply(f, dtype=int) def startswith( @@ -557,8 +586,8 @@ def zfill( def contains( self, - pat: str, - case: bool = True, + pat: Any, + case: bool = None, flags: int = 0, regex: bool = True, ) -> Any: @@ -570,8 +599,9 @@ def contains( Parameters ---------- - pat : str - Character sequence or regular expression. + pat : str or re.Pattern + Character sequence, a string containing a regular expression, + or a compiled regular expression object. case : bool, default: True If True, case sensitive. flags : int, default: 0 @@ -588,19 +618,21 @@ def contains( given pattern is contained within the string of each element of the array. """ - pat = self._obj.dtype.type(pat) - if regex: - if not case: - flags |= re.IGNORECASE - - regex_obj = re.compile(pat, flags=flags) + is_compiled_re = self._check_is_compiled_re(pat) + if is_compiled_re and not regex: + raise ValueError( + "Must use regular expression matching for regular expression object." 
+ ) - if regex_obj.groups > 0: # pragma: no cover + if regex: + pat = self._re_compile(pat, flags, case) + if pat.groups > 0: # pragma: no cover raise ValueError("This pattern has match groups.") - f = lambda x: bool(regex_obj.search(x)) + f = lambda x: bool(pat.search(x)) else: - if case: + pat = self._obj.dtype.type(pat) + if case or case is None: f = lambda x: pat in x else: uppered = self._obj.str.upper() @@ -619,8 +651,9 @@ def match( Parameters ---------- - pat : str - Character sequence or regular expression + pat : str or re.Pattern + A string containing a regular expression or + a compiled regular expression object. case : bool, default: True If True, case sensitive flags : int, default: 0 @@ -630,11 +663,23 @@ def match( ------- matched : array of bool """ - if not case: - flags |= re.IGNORECASE + is_compiled_re = self._check_is_compiled_re(pat) + + if is_compiled_re and flags != 0: + raise ValueError("flags cannot be set when pat is a compiled regex") + + if is_compiled_re and case is not None: + raise ValueError("case cannot be set when pat is a compiled regex") + + if case is None: + case = True + + if not is_compiled_re: + if not case: + flags |= re.IGNORECASE + pat = self._obj.dtype.type(pat) + regex = re.compile(pat, flags=flags) - pat = self._obj.dtype.type(pat) - regex = re.compile(pat, flags=flags) f = lambda x: bool(regex.match(x)) return self._apply(f, dtype=bool) @@ -932,7 +977,7 @@ def rindex( def replace( self, - pat: Union[str, Any], + pat: Any, repl: Union[str, Callable], n: int = -1, case: bool = None, @@ -971,45 +1016,27 @@ def replace( A copy of the object with all matching occurrences of `pat` replaced by `repl`. """ - if not (_is_str_like(repl) or callable(repl)): # pragma: no cover + if not _is_str_like(repl) and not callable(repl): # pragma: no cover raise TypeError("repl must be a string or callable") - if _is_str_like(pat): - pat = self._obj.dtype.type(pat) - if _is_str_like(repl): repl = self._obj.dtype.type(repl) - is_compiled_re = isinstance(pat, type(re.compile(""))) + is_compiled_re = self._check_is_compiled_re(pat) + if not regex and is_compiled_re: + raise ValueError( + "Cannot use a compiled regex as replacement pattern with regex=False" + ) + + if not regex and callable(repl): + raise ValueError("Cannot use a callable replacement when regex=False") + if regex: - if is_compiled_re: - if (case is not None) or (flags != 0): - raise ValueError( - "case and flags cannot be set when pat is a compiled regex" - ) - else: - # not a compiled regex - # set default case - if case is None: - case = True - - # add case flag, if provided - if case is False: - flags |= re.IGNORECASE - if is_compiled_re or len(pat) > 1 or flags or callable(repl): - n = n if n >= 0 else 0 - compiled = re.compile(pat, flags=flags) - f = lambda x: compiled.sub(repl=repl, string=x, count=n) - else: - f = lambda x: x.replace(pat, repl, n) + pat = self._re_compile(pat, flags, case) + n = n if n >= 0 else 0 + f = lambda x: pat.sub(repl=repl, string=x, count=n) else: - if is_compiled_re: - raise ValueError( - "Cannot use a compiled regex as replacement " - "pattern with regex=False" - ) - if callable(repl): - raise ValueError("Cannot use a callable replacement when regex=False") + pat = self._obj.dtype.type(pat) f = lambda x: x.replace(pat, repl, n) return self._apply(f) diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index e0cbdb7377a..5b10fe2b553 100644 --- a/xarray/tests/test_accessor_str.py +++ b/xarray/tests/test_accessor_str.py @@ -181,13 +181,13 
@@ def test_replace_compiled_regex(dtype): values = xr.DataArray(["fooBAD__barBAD__bad"]).astype(dtype) pat = re.compile(dtype("BAD[_]*")) - with pytest.raises(ValueError, match="case and flags cannot be"): + with pytest.raises(ValueError, match="flags cannot be set"): result = values.str.replace(pat, "", flags=re.IGNORECASE) - with pytest.raises(ValueError, match="case and flags cannot be"): + with pytest.raises(ValueError, match="case cannot be set"): result = values.str.replace(pat, "", case=False) - with pytest.raises(ValueError, match="case and flags cannot be"): + with pytest.raises(ValueError, match="case cannot be set"): result = values.str.replace(pat, "", case=True) # test with callable From 636c166f7e765522d6bafbe2cc1518214fc49028 Mon Sep 17 00:00:00 2001 From: Todd Jennings Date: Fri, 27 Nov 2020 12:07:30 -0500 Subject: [PATCH 03/14] implement casefold and normalize str accessor functions --- xarray/core/accessor_str.py | 37 +++++++++++++++++++++++++++ xarray/tests/test_accessor_str.py | 42 ++++++++++++++++++++++++++++--- 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index ea68548c306..de86a0435ec 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -41,6 +41,7 @@ import re import textwrap from typing import Any, Callable, Mapping, Union +from unicodedata import normalize import numpy as np @@ -279,6 +280,42 @@ def upper(self): """ return self._apply(lambda x: x.upper()) + def casefold(self): + """ + Convert strings in the array to be casefolded. + + Casefolding is similar to converting to lowercase, + but removes all case distinctions. + This is important in some languages that have more complicated + cases and case conversions. + + Returns + ------- + casefolded : same type as values + """ + return self._apply(lambda x: x.casefold()) + + def normalize( + self, + form: str, + ) -> Any: + """ + Return the Unicode normal form for the strings in the datarray. + + For more information on the forms, see the documentation for + :func:`unicodedata.normalize`. + + Parameters + ---------- + side : {"NFC", "NFKC", "NFD", and "NFKD"} + Unicode form. + + Returns + ------- + normalized : same type as values + """ + return self._apply(lambda x: normalize(form, x)) + def isalnum(self): """ Check whether all characters in each string are alphanumeric. diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index 5b10fe2b553..b8cd8da71f0 100644 --- a/xarray/tests/test_accessor_str.py +++ b/xarray/tests/test_accessor_str.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- + # Tests for the `str` accessor are derived from the original # pandas string accessor tests. 
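A short sketch of the new casefold/normalize accessors added in this patch (the sample strings are hypothetical):

import xarray as xr

da = xr.DataArray(["Fuß", "SPAM Ham"])
da.str.lower()           # plain lowercasing: "fuß", "spam ham"
da.str.casefold()        # also folds "ß" to "ss" for caseless comparison
da.str.normalize("NFC")  # Unicode normal form, via unicodedata.normalize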
@@ -93,18 +95,52 @@ def test_starts_ends_with(dtype): assert_equal(result, expected) -def test_case(dtype): - da = xr.DataArray(["SOme word"]).astype(dtype) +def test_case_bytes(dtype): + dtype = np.bytes_ + + da = xr.DataArray(["SOme wOrd"]).astype(dtype) capitalized = xr.DataArray(["Some word"]).astype(dtype) lowered = xr.DataArray(["some word"]).astype(dtype) - swapped = xr.DataArray(["soME WORD"]).astype(dtype) + swapped = xr.DataArray(["soME WoRD"]).astype(dtype) titled = xr.DataArray(["Some Word"]).astype(dtype) uppered = xr.DataArray(["SOME WORD"]).astype(dtype) + + assert_equal(da.str.capitalize(), capitalized) + assert_equal(da.str.lower(), lowered) + assert_equal(da.str.swapcase(), swapped) + assert_equal(da.str.title(), titled) + assert_equal(da.str.upper(), uppered) + + +def test_case_str(dtype): + dtype = np.str_ + + # This string includes some unicode characters + # that are common case management corner cases + da = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(dtype) + capitalized = xr.DataArray(["Some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(dtype) + lowered = xr.DataArray(["some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(dtype) + swapped = xr.DataArray(["soME WoRD dž SS ᾛ σς FFI⁵å ç ⅰ"]).astype(dtype) + titled = xr.DataArray(["Some Word Dž Ss ᾛ Σς Ffi⁵Å Ç Ⅰ"]).astype(dtype) + uppered = xr.DataArray(["SOME WORD DŽ SS ἫΙ ΣΣ FFI⁵Å Ç Ⅰ"]).astype(dtype) + casefolded = xr.DataArray(["some word dž ss ἣι σσ ffi⁵å ç ⅰ"]).astype(dtype) + + norm_nfc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(dtype) + norm_nfkc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(dtype) + norm_nfd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(dtype) + norm_nfkd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(dtype) + assert_equal(da.str.capitalize(), capitalized) assert_equal(da.str.lower(), lowered) assert_equal(da.str.swapcase(), swapped) assert_equal(da.str.title(), titled) assert_equal(da.str.upper(), uppered) + assert_equal(da.str.casefold(), casefolded) + + assert_equal(da.str.normalize("NFC"), norm_nfc) + assert_equal(da.str.normalize("NFKC"), norm_nfkc) + assert_equal(da.str.normalize("NFD"), norm_nfd) + assert_equal(da.str.normalize("NFKD"), norm_nfkd) def test_replace(dtype): From 9ea202050fcb05919fff00cfbcba2101b0a91890 Mon Sep 17 00:00:00 2001 From: Todd Jennings Date: Fri, 27 Nov 2020 12:07:33 -0500 Subject: [PATCH 04/14] implement one-to-many str accessor functions --- xarray/core/accessor_str.py | 1088 +++++++++++++++-- xarray/tests/test_accessor_str.py | 1842 +++++++++++++++++++++++++++-- 2 files changed, 2727 insertions(+), 203 deletions(-) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index de86a0435ec..5e4d1f1ae05 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -40,7 +40,9 @@ import codecs import re import textwrap -from typing import Any, Callable, Mapping, Union +from functools import reduce +from operator import or_ as set_union +from typing import Any, Callable, Hashable, Mapping, Optional, Pattern, Union from unicodedata import normalize import numpy as np @@ -59,7 +61,7 @@ _cpython_optimized_decoders = _cpython_optimized_encoders + ("utf-16", "utf-32") -def _is_str_like(x): +def _is_str_like(x: Any) -> bool: return isinstance(x, str) or isinstance(x, bytes) @@ -83,19 +85,51 @@ class StringAccessor: def __init__(self, obj): self._obj = obj - def _apply(self, f, dtype=None): - # TODO handling of na values ? 
- if dtype is None: - dtype = self._obj.dtype - - g = np.vectorize(f, otypes=[dtype]) - return apply_ufunc(g, self._obj, dask="parallelized", output_dtypes=[dtype]) + def _stringify( + self, + invar: Any, + ) -> Union[str, bytes]: + """ + Convert a string-like to the correct string/bytes type. - def _check_is_compiled_re(self, pat): - return isinstance(pat, self._pattern_type) + This is mostly here to tell mypy a pattern is a str/bytes not a re.Pattern. + """ + return self._obj.dtype.type(invar) - def _re_compile(self, pat, flags, case=None): - is_compiled_re = self._check_is_compiled_re(pat) + def _apply( + self, + f: Callable, + obj: Any = None, + dtype: Union[str, np.dtype] = None, + output_core_dims: Union[list, tuple] = ((),), + output_sizes: Mapping[Hashable, int] = None, + **kwargs, + ) -> Any: + # TODO handling of na values ? + if obj is None: + obj = self._obj + if dtype is None: + dtype = obj.dtype + + dask_gufunc_kwargs = dict() + if output_sizes is not None: + dask_gufunc_kwargs["output_sizes"] = output_sizes + + return apply_ufunc( + f, + obj, + vectorize=True, + dask="parallelized", + output_dtypes=[dtype], + output_core_dims=output_core_dims, + dask_gufunc_kwargs=dask_gufunc_kwargs, + **kwargs, + ) + + def _re_compile( + self, pat: Union[str, bytes, Pattern], flags: int, case: bool = None + ) -> Pattern: + is_compiled_re = isinstance(pat, self._pattern_type) if is_compiled_re and flags != 0: raise ValueError("flags cannot be set when pat is a compiled regex") @@ -104,18 +138,21 @@ def _re_compile(self, pat, flags, case=None): raise ValueError("case cannot be set when pat is a compiled regex") if is_compiled_re: - return pat + # no-op, needed to tell mypy this isn't a string + return re.compile(pat) if case is None: case = True + # The case is handled by the re flags internally. + # Add it to the flags if necessary. if not case: flags |= re.IGNORECASE - pat = self._obj.dtype.type(pat) + pat = self._stringify(pat) return re.compile(pat, flags=flags) - def len(self): + def len(self) -> Any: """ Compute the length of each string in the array. @@ -125,7 +162,10 @@ def len(self): """ return self._apply(len, dtype=int) - def __getitem__(self, key): + def __getitem__( + self, + key: Union[int, slice], + ) -> Any: if isinstance(key, slice): return self.slice(start=key.start, stop=key.stop, step=key.step) else: @@ -134,7 +174,7 @@ def __getitem__(self, key): def get( self, i: int, - default: str = "", + default: Union[str, bytes] = "", ) -> Any: """ Extract character number `i` from each string in the array. @@ -190,7 +230,7 @@ def slice_replace( self, start: int = None, stop: int = None, - repl: str = "", + repl: Union[str, bytes] = "", ) -> Any: """ Replace a positional slice of a string with another value. @@ -213,14 +253,14 @@ def slice_replace( ------- replaced : same type as values """ - repl = self._obj.dtype.type(repl) + repl = self._stringify(repl) def f(x): if len(x[start:stop]) == 0: local_stop = start else: local_stop = stop - y = self._obj.dtype.type("") + y = self._stringify("") if start is not None: y += x[:start] y += repl @@ -230,7 +270,7 @@ def f(x): return self._apply(f) - def capitalize(self): + def capitalize(self) -> Any: """ Convert strings in the array to be capitalized. @@ -240,7 +280,7 @@ def capitalize(self): """ return self._apply(lambda x: x.capitalize()) - def lower(self): + def lower(self) -> Any: """ Convert strings in the array to lowercase. 
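The `_re_compile` helper above centralizes the rule that `case` and `flags` cannot be combined with an already-compiled pattern. A hedged sketch of what that means for callers (sample data hypothetical):

import re
import xarray as xr

da = xr.DataArray(["fooBAD__barBAD"])
pat = re.compile(r"BAD[_]*")
da.str.replace(pat, "")  # compiled patterns are accepted directly
da.str.count(pat)        # likewise for count/contains/match
# da.str.replace(pat, "", flags=re.IGNORECASE)  # would raise ValueError:
#     "flags cannot be set when pat is a compiled regex"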
@@ -250,7 +290,7 @@ def lower(self): """ return self._apply(lambda x: x.lower()) - def swapcase(self): + def swapcase(self) -> Any: """ Convert strings in the array to be swapcased. @@ -260,7 +300,7 @@ def swapcase(self): """ return self._apply(lambda x: x.swapcase()) - def title(self): + def title(self) -> Any: """ Convert strings in the array to titlecase. @@ -270,7 +310,7 @@ def title(self): """ return self._apply(lambda x: x.title()) - def upper(self): + def upper(self) -> Any: """ Convert strings in the array to uppercase. @@ -280,7 +320,7 @@ def upper(self): """ return self._apply(lambda x: x.upper()) - def casefold(self): + def casefold(self) -> Any: """ Convert strings in the array to be casefolded. @@ -307,16 +347,18 @@ def normalize( Parameters ---------- - side : {"NFC", "NFKC", "NFD", and "NFKD"} + form : {"NFC", "NFKC", "NFD", and "NFKD"} Unicode form. Returns ------- normalized : same type as values + + """ return self._apply(lambda x: normalize(form, x)) - def isalnum(self): + def isalnum(self) -> Any: """ Check whether all characters in each string are alphanumeric. @@ -327,7 +369,7 @@ def isalnum(self): """ return self._apply(lambda x: x.isalnum(), dtype=bool) - def isalpha(self): + def isalpha(self) -> Any: """ Check whether all characters in each string are alphabetic. @@ -338,7 +380,7 @@ def isalpha(self): """ return self._apply(lambda x: x.isalpha(), dtype=bool) - def isdecimal(self): + def isdecimal(self) -> Any: """ Check whether all characters in each string are decimal. @@ -349,7 +391,7 @@ def isdecimal(self): """ return self._apply(lambda x: x.isdecimal(), dtype=bool) - def isdigit(self): + def isdigit(self) -> Any: """ Check whether all characters in each string are digits. @@ -360,7 +402,7 @@ def isdigit(self): """ return self._apply(lambda x: x.isdigit(), dtype=bool) - def islower(self): + def islower(self) -> Any: """ Check whether all characters in each string are lowercase. @@ -371,7 +413,7 @@ def islower(self): """ return self._apply(lambda x: x.islower(), dtype=bool) - def isnumeric(self): + def isnumeric(self) -> Any: """ Check whether all characters in each string are numeric. @@ -382,7 +424,7 @@ def isnumeric(self): """ return self._apply(lambda x: x.isnumeric(), dtype=bool) - def isspace(self): + def isspace(self) -> Any: """ Check whether all characters in each string are spaces. @@ -393,7 +435,7 @@ def isspace(self): """ return self._apply(lambda x: x.isspace(), dtype=bool) - def istitle(self): + def istitle(self) -> Any: """ Check whether all characters in each string are titlecase. @@ -404,7 +446,7 @@ def istitle(self): """ return self._apply(lambda x: x.istitle(), dtype=bool) - def isupper(self): + def isupper(self) -> Any: """ Check whether all characters in each string are uppercase. @@ -417,9 +459,9 @@ def isupper(self): def count( self, - pat: Any, - case: bool = True, + pat: Union[str, bytes, Pattern], flags: int = 0, + case: bool = True, ) -> Any: """ Count occurrences of pattern in each string of the array. @@ -431,13 +473,17 @@ def count( Parameters ---------- pat : str or re.Pattern - A string contain a regular expression or - a compiled regular expression object. + A string containing a regular expression or a compiled regular + expression object. + flags : int, default: 0 + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + Cannot be set if `pat` is a compiled regex. 
case : bool, default: True If True, case sensitive. - flags : int, default: 0 - Flags for the `re` module. Use 0 for no flags. For a complete list, - `see here `_. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. Returns ------- @@ -450,7 +496,7 @@ def count( def startswith( self, - pat: str, + pat: Union[str, bytes], ) -> Any: """ Test if the start of each string in the array matches a pattern. @@ -466,13 +512,13 @@ def startswith( An array of booleans indicating whether the given pattern matches the start of each string element. """ - pat = self._obj.dtype.type(pat) + pat = self._stringify(pat) f = lambda x: x.startswith(pat) return self._apply(f, dtype=bool) def endswith( self, - pat: str, + pat: Union[str, bytes], ) -> Any: """ Test if the end of each string in the array matches a pattern. @@ -488,7 +534,7 @@ def endswith( A Series of booleans indicating whether the given pattern matches the end of each string element. """ - pat = self._obj.dtype.type(pat) + pat = self._stringify(pat) f = lambda x: x.endswith(pat) return self._apply(f, dtype=bool) @@ -496,7 +542,7 @@ def pad( self, width: int, side: str = "left", - fillchar: str = " ", + fillchar: Union[str, bytes] = " ", ) -> Any: """ Pad strings in the array up to width. @@ -517,7 +563,7 @@ def pad( Array with a minimum number of char in each element. """ width = int(width) - fillchar = self._obj.dtype.type(fillchar) + fillchar = self._stringify(fillchar) if len(fillchar) != 1: raise TypeError("fillchar must be a character, not str") @@ -535,7 +581,7 @@ def pad( def center( self, width: int, - fillchar: str = " ", + fillchar: Union[str, bytes] = " ", ) -> Any: """ Pad left and right side of each string in the array. @@ -557,7 +603,7 @@ def center( def ljust( self, width: int, - fillchar: str = " ", + fillchar: Union[str, bytes] = " ", ) -> Any: """ Pad right side of each string in the array. @@ -579,7 +625,7 @@ def ljust( def rjust( self, width: int, - fillchar: str = " ", + fillchar: Union[str, bytes] = " ", ) -> Any: """ Pad left side of each string in the array. @@ -623,8 +669,8 @@ def zfill( def contains( self, - pat: Any, - case: bool = None, + pat: Union[str, bytes, Pattern], + case: bool = True, flags: int = 0, regex: bool = True, ) -> Any: @@ -641,12 +687,17 @@ def contains( or a compiled regular expression object. case : bool, default: True If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. flags : int, default: 0 - Flags to pass through to the re module, e.g. re.IGNORECASE. - ``0`` means no flags. + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + Cannot be set if `pat` is a compiled regex. regex : bool, default: True If True, assumes the pat is a regular expression. If False, treats the pat as a literal string. + Cannot be set to `False` if `pat` is a compiled regex. Returns ------- @@ -655,7 +706,7 @@ def contains( given pattern is contained within the string of each element of the array. """ - is_compiled_re = self._check_is_compiled_re(pat) + is_compiled_re = isinstance(pat, self._pattern_type) if is_compiled_re and not regex: raise ValueError( "Must use regular expression matching for regular expression object." 
@@ -668,7 +719,7 @@ def contains( f = lambda x: bool(pat.search(x)) else: - pat = self._obj.dtype.type(pat) + pat = self._stringify(pat) if case or case is None: f = lambda x: pat in x else: @@ -679,7 +730,7 @@ def contains( def match( self, - pat: str, + pat: Union[str, bytes, Pattern], case: bool = True, flags: int = 0, ) -> Any: @@ -692,37 +743,27 @@ def match( A string containing a regular expression or a compiled regular expression object. case : bool, default: True - If True, case sensitive + If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. flags : int, default: 0 - re module flags, e.g. re.IGNORECASE. ``0`` means no flags + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + Cannot be set if `pat` is a compiled regex. Returns ------- matched : array of bool """ - is_compiled_re = self._check_is_compiled_re(pat) - - if is_compiled_re and flags != 0: - raise ValueError("flags cannot be set when pat is a compiled regex") - - if is_compiled_re and case is not None: - raise ValueError("case cannot be set when pat is a compiled regex") - - if case is None: - case = True - - if not is_compiled_re: - if not case: - flags |= re.IGNORECASE - pat = self._obj.dtype.type(pat) - regex = re.compile(pat, flags=flags) + pat = self._re_compile(pat, flags, case) - f = lambda x: bool(regex.match(x)) + f = lambda x: bool(pat.match(x)) return self._apply(f, dtype=bool) def strip( self, - to_strip: str = None, + to_strip: Union[str, bytes] = None, side: str = "both", ) -> Any: """ @@ -745,7 +786,7 @@ def strip( stripped : same type as values """ if to_strip is not None: - to_strip = self._obj.dtype.type(to_strip) + to_strip = self._stringify(to_strip) if side == "both": f = lambda x: x.strip(to_strip) @@ -760,7 +801,7 @@ def strip( def lstrip( self, - to_strip: str = None, + to_strip: Union[str, bytes] = None, ) -> Any: """ Remove leading characters. @@ -783,7 +824,7 @@ def lstrip( def rstrip( self, - to_strip: str = None, + to_strip: Union[str, bytes] = None, ) -> Any: """ Remove trailing characters. @@ -832,7 +873,7 @@ def wrap( def translate( self, - table: Mapping[str, str], + table: Mapping[Union[str, bytes], Union[str, bytes]], ) -> Any: """ Map characters of each string through the given mapping table. 
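The `translate` method annotated above applies an ordinary Python translation table element-wise; a minimal sketch (sample data hypothetical):

import xarray as xr

da = xr.DataArray(["abc", "cab"])
table = str.maketrans("abc", "xyz")  # standard str.maketrans table
da.str.translate(table)              # "xyz", "zxy"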
@@ -874,7 +915,7 @@ def repeat( def find( self, - sub: str, + sub: Union[str, bytes], start: int = 0, end: int = None, side: str = "left", @@ -899,7 +940,7 @@ def find( ------- found : array of int """ - sub = self._obj.dtype.type(sub) + sub = self._stringify(sub) if side == "left": method = "find" @@ -917,7 +958,7 @@ def find( def rfind( self, - sub: str, + sub: Union[str, bytes], start: int = 0, end: int = None, ) -> Any: @@ -943,7 +984,7 @@ def rfind( def index( self, - sub: str, + sub: Union[str, bytes], start: int = 0, end: int = None, side: str = "left", @@ -968,8 +1009,13 @@ def index( Returns ------- found : array of int + + Raises + ------ + ValueError + substring is not found """ - sub = self._obj.dtype.type(sub) + sub = self._stringify(sub) if side == "left": method = "index" @@ -987,7 +1033,7 @@ def index( def rindex( self, - sub: str, + sub: Union[str, bytes], start: int = 0, end: int = None, ) -> Any: @@ -1009,13 +1055,18 @@ def rindex( Returns ------- found : array of int + + Raises + ------ + ValueError + substring is not found """ return self.index(sub, start=start, end=end, side="right") def replace( self, - pat: Any, - repl: Union[str, Callable], + pat: Union[str, bytes, Pattern], + repl: Union[str, bytes, Callable], n: int = -1, case: bool = None, flags: int = 0, @@ -1034,18 +1085,20 @@ def replace( See :func:`re.sub`. n : int, default: -1 Number of replacements to make from start. Use ``-1`` to replace all. - case : bool, default: None - - If True, case sensitive (the default if `pat` is a string) - - Set to False for case insensitive - - Cannot be set if `pat` is a compiled regex + case : bool, default: True + If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. flags : int, default: 0 - - re module flags, e.g. re.IGNORECASE. Use ``0`` for no flags. - - Cannot be set if `pat` is a compiled regex + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + Cannot be set if `pat` is a compiled regex. regex : bool, default: True - - If True, assumes the passed-in pattern is a regular expression. - - If False, treats the pattern as a literal string - - Cannot be set to False if `pat` is a compiled regex or `repl` is - a callable. + If True, assumes the passed-in pattern is a regular expression. + If False, treats the pattern as a literal string. + Cannot be set to False if `pat` is a compiled regex or `repl` is + a callable. Returns ------- @@ -1057,9 +1110,9 @@ def replace( raise TypeError("repl must be a string or callable") if _is_str_like(repl): - repl = self._obj.dtype.type(repl) + repl = self._stringify(repl) - is_compiled_re = self._check_is_compiled_re(pat) + is_compiled_re = isinstance(pat, self._pattern_type) if not regex and is_compiled_re: raise ValueError( "Cannot use a compiled regex as replacement pattern with regex=False" @@ -1073,10 +1126,835 @@ def replace( n = n if n >= 0 else 0 f = lambda x: pat.sub(repl=repl, string=x, count=n) else: - pat = self._obj.dtype.type(pat) + pat = self._stringify(pat) f = lambda x: x.replace(pat, repl, n) return self._apply(f) + def extract( + self, + pat: Union[str, bytes, Pattern], + dim: Hashable, + case: bool = None, + flags: int = 0, + ) -> Any: + """ + Extract the first match of capture groups in the regex pat as a new + dimension in a DataArray. 
+ + For each string in the DataArray, extract groups from the first match + of regular expression pat. + + Parameters + ---------- + pat : str or re.Pattern + A string containing a regular expression or a compiled regular + expression object. + dim : hashable or `None` + Name of the new dimension to store the captured strings in. + If None, the pattern must have only one capture group and the + resulting DataArray will have the same size as the original. + case : bool, default: True + If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. + flags : int, default: 0 + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + Cannot be set if `pat` is a compiled regex. + + Returns + ------- + extracted : same type as values or object array + + Raises + ------ + ValueError + `pat` has no capture groups. + ValueError + `dim` is `None` and there is more than one capture group. + ValueError + `case` is set when `pat` is a compiled regular expression. + KeyError + The given dimension is already present in the DataArray. + + Examples + -------- + Create a string array + + >>> value = xr.DataArray( + ... [ + ... [ + ... "a_Xy_0", + ... "ab_xY_10-bab_Xy_110-baab_Xy_1100", + ... "abc_Xy_01-cbc_Xy_2210", + ... ], + ... [ + ... "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + ... "", + ... "abcdef_Xy_101-fef_Xy_5543210", + ... ], + ... ], + ... dims=["X", "Y"], + ... ) + + Extract matches + + >>> value.str.extract(r"(\\w+)_Xy_(\\d*)", dim="match") + + array([[['a', '0'], + ['bab', '110'], + ['abc', '01']], + + [['abcd', ''], + ['', ''], + ['abcdef', '101']]], dtype=' Any: + """ + Extract all matches of capture groups in the regex pat as new + dimensions in a DataArray. + + For each string in the DataArray, extract groups from all matches + of regular expression pat. + Equivalent to applying re.findall() to all the elements in the DataArray + and splitting the results across dimensions. + + Parameters + ---------- + pat : str or re.Pattern + A string containing a regular expression or a compiled regular + expression object. + group_dim: hashable + Name of the new dimensions corresponding to the capture groups. + This dimension is added to the new DataArray first. + match_dim: hashable + Name of the new dimensions corresponding to the matches for each group. + This dimension is added to the new DataArray second. + case : bool, default: True + If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. + flags : int, default: 0 + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + Cannot be set if `pat` is a compiled regex. + + Returns + ------- + extracted : same type as values or object array + + Raises + ------ + ValueError + `pat` has no capture groups. + ValueError + `case` is set when `pat` is a compiled regular expression. + KeyError + Either of the given dimensions is already present in the DataArray. + KeyError + The given dimensions names are the same. + + Examples + -------- + Create a string array + + >>> value = xr.DataArray( + ... [ + ... [ + ... "a_Xy_0", + ... "ab_xY_10-bab_Xy_110-baab_Xy_1100", + ... "abc_Xy_01-cbc_Xy_2210", + ... ], + ... [ + ... "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + ... "", + ... "abcdef_Xy_101-fef_Xy_5543210", + ... 
], + ... ], + ... dims=["X", "Y"], + ... ) + + Extract matches + + >>> value.str.extractall( + ... r"(\\w+)_Xy_(\\d*)", group_dim="group", match_dim="match" + ... ) + + array([[[['a', '0'], + ['', ''], + ['', '']], + + [['bab', '110'], + ['baab', '1100'], + ['', '']], + + [['abc', '01'], + ['cbc', '2210'], + ['', '']]], + + + [[['abcd', ''], + ['dcd', '33210'], + ['dccd', '332210']], + + [['', ''], + ['', ''], + ['', '']], + + [['abcdef', '101'], + ['fef', '5543210'], + ['', '']]]], dtype=' Any: + """ + Find all occurrences of pattern or regular expression in the DataArray. + + Equivalent to applying re.findall() to all the elements in the DataArray. + Results in an object array of lists. + If there is only one capture group, the lists will be a sequence of matches. + If there are multiple capture groups, the lists will be a sequence of lists, + each of which contains a sequence of matches. + + Parameters + ---------- + pat : str or re.Pattern + A string containing a regular expression or a compiled regular + expression object. + case : bool, default: True + If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. + flags : int, default: 0 + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + Cannot be set if `pat` is a compiled regex. + + Returns + ------- + extracted : object array + + Raises + ------ + ValueError + `pat` has no capture groups. + ValueError + `case` is set when `pat` is a compiled regular expression. + + Examples + -------- + Create a string array + + >>> value = xr.DataArray( + ... [ + ... [ + ... "a_Xy_0", + ... "ab_xY_10-bab_Xy_110-baab_Xy_1100", + ... "abc_Xy_01-cbc_Xy_2210", + ... ], + ... [ + ... "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + ... "", + ... "abcdef_Xy_101-fef_Xy_5543210", + ... ], + ... ], + ... dims=["X", "Y"], + ... ) + + Extract matches + + >>> value.str.findall(r"(\\w+)_Xy_(\\d*)") + + array([[list([('a', '0')]), list([('bab', '110'), ('baab', '1100')]), + list([('abc', '01'), ('cbc', '2210')])], + [list([('abcd', ''), ('dcd', '33210'), ('dccd', '332210')]), + list([]), list([('abcdef', '101'), ('fef', '5543210')])]], + dtype=object) + Dimensions without coordinates: X, Y + + See Also + -------- + DataArray.str.extract + DataArray.str.extractall + re.compile + re.findall + pandas.Series.str.findall + """ + pat = self._re_compile(pat, flags, case) + + if pat.groups == 0: + raise ValueError("No capture groups found in pattern.") + + return self._apply(pat.findall, dtype=np.object_) + + def _partitioner( + self, + func: Callable, + dim: Hashable, + sep: Optional[Union[str, bytes]], + ) -> Any: + """ + Implements logic for `partition` and `rpartition`. 
+ """ + sep = self._stringify(sep) + + if dim is None: + f = lambda x: list(func(x, sep)) + return self._apply(f, dtype=np.object_) + + # _apply breaks on an empty array in this case + if not self._obj.size: + return self._obj.copy().expand_dims({dim: 0}, -1) + + f = lambda x: np.array(func(x, sep), dtype=self._obj.dtype) + + # dtype MUST be object or strings can be truncated + # See: https://github.com/numpy/numpy/issues/8352 + return self._apply( + f, + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: 3}, + ).astype(self._obj.dtype.kind) + + def partition( + self, + dim: Optional[Hashable], + sep: Union[str, bytes] = " ", + ) -> Any: + """ + Split the strings in the DataArray at the first occurrence of separator `sep`. + + This method splits the string at the first occurrence of `sep`, + and returns 3 elements containing the part before the separator, + the separator itself, and the part after the separator. + If the separator is not found, return 3 elements containing the string itself, + followed by two empty strings. + + This is equivalent to :meth:`str.partion`. + + Parameters + ---------- + dim : Hashable or `None` + Name for the dimension to place the 3 elements in. + If `None`, place the results as list elements in an object DataArray + sep : str, default `" "` + String to split on. + + Returns + ------- + partitioned : same type as values or object array + + See Also + -------- + DataArray.str.rpartition + str.partition + pandas.Series.str.partition + """ + return self._partitioner(func=self._obj.dtype.type.partition, dim=dim, sep=sep) + + def rpartition( + self, + dim: Optional[Hashable], + sep: Union[str, bytes] = " ", + ) -> Any: + """ + Split the strings in the DataArray at the last occurrence of separator `sep`. + + This method splits the string at the last occurrence of `sep`, + and returns 3 elements containing the part before the separator, + the separator itself, and the part after the separator. + If the separator is not found, return 3 elements containing two empty strings, + followed by the string itself. + + This is equivalent to :meth:`str.rpartion`. + + Parameters + ---------- + dim : Hashable or `None` + Name for the dimension to place the 3 elements in. + If `None`, place the results as list elements in an object DataArray + sep : str, default `" "` + String to split on. + + Returns + ------- + rpartitioned : same type as values or object array + + See Also + -------- + DataArray.str.partition + str.rpartition + pandas.Series.str.rpartition + """ + return self._partitioner(func=self._obj.dtype.type.rpartition, dim=dim, sep=sep) + + def _splitter( + self, + func: Callable, + pre: bool, + dim: Hashable, + sep: Optional[Union[str, bytes]], + maxsplit: int, + ) -> Any: + """ + Implements logic for `split` and `rsplit`. 
+ """ + if sep is not None: + sep = self._stringify(sep) + + if dim is None: + f = lambda x: func(x, sep, maxsplit) + return self._apply(f, dtype=np.object_) + + # _apply breaks on an empty array in this case + if not self._obj.size: + return self._obj.copy().expand_dims({dim: 0}, -1) + + f_count = lambda x: max(len(func(x, sep, maxsplit)), 1) + maxsplit = self._apply(f_count, dtype=np.int_).max().data.tolist() - 1 + + def _dosplit(mystr, sep=sep, maxsplit=maxsplit, dtype=self._obj.dtype): + res = func(mystr, sep, maxsplit) + if len(res) < maxsplit + 1: + pad = [""] * (maxsplit + 1 - len(res)) + if pre: + res += pad + else: + res = pad + res + return np.array(res, dtype=dtype) + + # dtype MUST be object or strings can be truncated + # See: https://github.com/numpy/numpy/issues/8352 + return self._apply( + _dosplit, + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: maxsplit}, + ).astype(self._obj.dtype.kind) + + def split( + self, + dim: Optional[Hashable], + sep: Union[str, bytes] = None, + maxsplit: int = -1, + ) -> Any: + """ + Split strings in a DataArray around the given separator/delimiter `sep`. + + Splits the string in the DataArray from the beginning, + at the specified delimiter string. + + This is equivalent to :meth:`str.split`. + + Parameters + ---------- + dim : Hashable or `None` + Name for the dimension to place the results in. + If `None`, place the results as list elements in an object DataArray + sep : str, default is split on any whitespace. + String to split on. + maxsplit : int, default -1 (all) + Limit number of splits in output, starting from the beginning. + -1 will return all splits. + + Returns + ------- + splitted : same type as values or object array + + Examples + -------- + Create a string DataArray + + >>> values = xr.DataArray( + ... [ + ... ["abc def", "spam\\t\\teggs\\tswallow", "red_blue"], + ... ["test0\\ntest1\\ntest2\\n\\ntest3", "", "abra ka\\nda\\tbra"], + ... ], + ... dims=["X", "Y"], + ... ) + + Split once and put the results in a new dimension + + >>> values.str.split(dim="splitted", maxsplit=1) + + array([[['abc', 'def'], + ['spam', 'eggs\\tswallow'], + ['red_blue', '']], + + [['test0', 'test1\\ntest2\\n\\ntest3'], + ['', ''], + ['abra', 'ka\\nda\\tbra']]], dtype='>> values.str.split(dim="splitted") + + array([[['abc', 'def', '', ''], + ['spam', 'eggs', 'swallow', ''], + ['red_blue', '', '', '']], + + [['test0', 'test1', 'test2', 'test3'], + ['', '', '', ''], + ['abra', 'ka', 'da', 'bra']]], dtype='>> values.str.split(dim=None, maxsplit=1) + + array([[list(['abc', 'def']), list(['spam', 'eggs\\tswallow']), + list(['red_blue'])], + [list(['test0', 'test1\\ntest2\\n\\ntest3']), list([]), + list(['abra', 'ka\\nda\\tbra'])]], dtype=object) + Dimensions without coordinates: X, Y + + Split as many times as needed and put the results in a list + + >>> values.str.split(dim=None) + + array([[list(['abc', 'def']), list(['spam', 'eggs', 'swallow']), + list(['red_blue'])], + [list(['test0', 'test1', 'test2', 'test3']), list([]), + list(['abra', 'ka', 'da', 'bra'])]], dtype=object) + Dimensions without coordinates: X, Y + + Split only on spaces + + >>> values.str.split(dim="splitted", sep=" ") + + array([[['abc', 'def', ''], + ['spam\\t\\teggs\\tswallow', '', ''], + ['red_blue', '', '']], + + [['test0\\ntest1\\ntest2\\n\\ntest3', '', ''], + ['', '', ''], + ['abra', '', 'ka\\nda\\tbra']]], dtype=' Any: + """ + Split strings in a DataArray around the given separator/delimiter `sep`. 
+ + Splits the string in the DataArray from the end, + at the specified delimiter string. + + This is equivalent to :meth:`str.rsplit`. + + Parameters + ---------- + dim : Hashable or `None` + Name for the dimension to place the results in. + If `None`, place the results as list elements in an object DataArray + sep : str, default is split on any whitespace. + String to split on. + maxsplit : int, default -1 (all) + Limit number of splits in output, starting from the end. + -1 will return all splits. + The final number of split values may be less than this if there are no + DataArray elements with that many values. + + Returns + ------- + rsplitted : same type as values or object array + + Examples + -------- + Create a string DataArray + + >>> values = xr.DataArray( + ... [ + ... ["abc def", "spam\\t\\teggs\\tswallow", "red_blue"], + ... ["test0\\ntest1\\ntest2\\n\\ntest3", "", "abra ka\\nda\\tbra"], + ... ], + ... dims=["X", "Y"], + ... ) + + Split once and put the results in a new dimension + + >>> values.str.rsplit(dim="splitted", maxsplit=1) + + array([[['abc', 'def'], + ['spam\\t\\teggs', 'swallow'], + ['', 'red_blue']], + + [['test0\\ntest1\\ntest2', 'test3'], + ['', ''], + ['abra ka\\nda', 'bra']]], dtype='>> values.str.rsplit(dim="splitted") + + array([[['', '', 'abc', 'def'], + ['', 'spam', 'eggs', 'swallow'], + ['', '', '', 'red_blue']], + + [['test0', 'test1', 'test2', 'test3'], + ['', '', '', ''], + ['abra', 'ka', 'da', 'bra']]], dtype='>> values.str.rsplit(dim=None, maxsplit=1) + + array([[list(['abc', 'def']), list(['spam\\t\\teggs', 'swallow']), + list(['red_blue'])], + [list(['test0\\ntest1\\ntest2', 'test3']), list([]), + list(['abra ka\\nda', 'bra'])]], dtype=object) + Dimensions without coordinates: X, Y + + Split as many times as needed and put the results in a list + + >>> values.str.rsplit(dim=None) + + array([[list(['abc', 'def']), list(['spam', 'eggs', 'swallow']), + list(['red_blue'])], + [list(['test0', 'test1', 'test2', 'test3']), list([]), + list(['abra', 'ka', 'da', 'bra'])]], dtype=object) + Dimensions without coordinates: X, Y + + Split only on spaces + + >>> values.str.rsplit(dim="splitted", sep=" ") + + array([[['', 'abc', 'def'], + ['', '', 'spam\\t\\teggs\\tswallow'], + ['', '', 'red_blue']], + + [['', '', 'test0\\ntest1\\ntest2\\n\\ntest3'], + ['', '', ''], + ['abra', '', 'ka\\nda\\tbra']]], dtype=' Any: + """ + Return DataArray of dummy/indicator variables. + + Each string in the DataArray is split at `sep`. + A new dimension is created with coordinates for each unique result, + and the corresponding element of that dimension is `True` if + that result is present and `False` if not. + + Parameters + ---------- + dim : Hashable + Name for the dimension to place the results in. + sep : str, default `"|"`. + String to split on. + + Returns + ------- + dummies : array of bool + + Examples + -------- + Create a string array + + >>> values = xr.DataArray( + ... [ + ... ["a|ab~abc|abc", "ab", "a||abc|abcd"], + ... ["abcd|ab|a", "abc|ab~abc", "|a"], + ... ], + ... dims=["X", "Y"], + ... 
) + + Extract dummy values + + >>> values.str.get_dummies(dim="dummies") + + array([[[ True, False, True, False, True], + [False, True, False, False, False], + [ True, False, True, True, False]], + + [[ True, True, False, True, False], + [False, False, True, False, True], + [ True, False, False, False, False]]]) + Coordinates: + * dummies (dummies) \w+) (?P\w+) (?P\w+)" repl = lambda m: m.group("middle").swapcase() result = values.str.replace(pat, repl) exp = xr.DataArray(["bAR"]) + assert result.dtype == exp.dtype assert_equal(result, exp) @@ -197,6 +253,7 @@ def test_replace_unicode(): expected = xr.DataArray([b"abcd, \xc3\xa0".decode("utf-8")]) pat = re.compile(r"(?<=\w),(?=\w)", flags=re.UNICODE) result = values.str.replace(pat, ", ") + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -206,10 +263,12 @@ def test_replace_compiled_regex(dtype): pat = re.compile(dtype("BAD[_]*")) result = values.str.replace(pat, "") expected = xr.DataArray(["foobar"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.replace(pat, "", n=1) expected = xr.DataArray(["foobarBAD"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) # case and flags provided to str.replace will have no effect @@ -232,6 +291,7 @@ def test_replace_compiled_regex(dtype): pat = re.compile(dtype("[a-z][A-Z]{2}")) result = values.str.replace(pat, repl, n=2) expected = xr.DataArray(["foObaD__baRbaD"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -240,10 +300,12 @@ def test_replace_literal(dtype): values = xr.DataArray(["f.o", "foo"]).astype(dtype) expected = xr.DataArray(["bao", "bao"]).astype(dtype) result = values.str.replace("f.", "ba") + assert result.dtype == expected.dtype assert_equal(result, expected) expected = xr.DataArray(["bao", "foo"]).astype(dtype) result = values.str.replace("f.", "ba", regex=False) + assert result.dtype == expected.dtype assert_equal(result, expected) # Cannot do a literal replace if given a callable repl or compiled @@ -260,10 +322,893 @@ def test_replace_literal(dtype): values.str.replace(compiled_pat, "", regex=False) +def test_extract_extractall_findall_empty_raises(dtype): + pat_str = r"a_\w+_b_\d+_c_.*" + pat_re = re.compile(pat_str) + + value = xr.DataArray( + [ + ["a_first_b_1_c_de", "a_second_b_22_c_efh", "a_third_b_333_c_hijk"], + [ + "a_fourth_b_4444_c_klmno", + "a_fifth_b_5555_c_opqr", + "a_sixth_b_66666_c_rst", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + with pytest.raises(ValueError): + value.str.extract(pat=pat_str, dim="ZZ") + + with pytest.raises(ValueError): + value.str.extract(pat=pat_re, dim="ZZ") + + with pytest.raises(ValueError): + value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + + with pytest.raises(ValueError): + value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + + with pytest.raises(ValueError): + value.str.findall(pat=pat_str) + + with pytest.raises(ValueError): + value.str.findall(pat=pat_re) + + +def test_extract_multi_None_raises(dtype): + pat_str = r"a_(\w+)_b_(\d+)_c_.*" + pat_re = re.compile(pat_str) + + value = xr.DataArray( + [ + ["a_first_b_1_c_de", "a_second_b_22_c_efh", "a_third_b_333_c_hijk"], + [ + "a_fourth_b_4444_c_klmno", + "a_fifth_b_5555_c_opqr", + "a_sixth_b_66666_c_rst", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + with pytest.raises(ValueError): + value.str.extract(pat=pat_str, dim=None) + + with pytest.raises(ValueError): + value.str.extract(pat=pat_re, 
dim=None) + + +def test_extract_extractall_findall_case_re_raises(dtype): + pat_str = r"a_\w+_b_\d+_c_.*" + pat_re = re.compile(pat_str) + + value = xr.DataArray( + [ + ["a_first_b_1_c_de", "a_second_b_22_c_efh", "a_third_b_333_c_hijk"], + [ + "a_fourth_b_4444_c_klmno", + "a_fifth_b_5555_c_opqr", + "a_sixth_b_66666_c_rst", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + with pytest.raises(ValueError): + value.str.extract(pat=pat_re, case=True, dim="ZZ") + + with pytest.raises(ValueError): + value.str.extract(pat=pat_re, case=False, dim="ZZ") + + with pytest.raises(ValueError): + value.str.extractall(pat=pat_re, case=True, group_dim="XX", match_dim="YY") + + with pytest.raises(ValueError): + value.str.extractall(pat=pat_re, case=False, group_dim="XX", match_dim="YY") + + with pytest.raises(ValueError): + value.str.findall(pat=pat_re, case=True) + + with pytest.raises(ValueError): + value.str.findall(pat=pat_re, case=False) + + +def test_extract_extractall_name_collision_raises(dtype): + pat_str = r"a_(\w+)_b_\d+_c_.*" + pat_re = re.compile(pat_str) + + value = xr.DataArray( + [ + ["a_first_b_1_c_de", "a_second_b_22_c_efh", "a_third_b_333_c_hijk"], + [ + "a_fourth_b_4444_c_klmno", + "a_fifth_b_5555_c_opqr", + "a_sixth_b_66666_c_rst", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + with pytest.raises(KeyError): + value.str.extract(pat=pat_str, dim="X") + + with pytest.raises(KeyError): + value.str.extract(pat=pat_re, dim="X") + + with pytest.raises(KeyError): + value.str.extractall(pat=pat_str, group_dim="X", match_dim="ZZ") + + with pytest.raises(KeyError): + value.str.extractall(pat=pat_re, group_dim="X", match_dim="YY") + + with pytest.raises(KeyError): + value.str.extractall(pat=pat_str, group_dim="XX", match_dim="Y") + + with pytest.raises(KeyError): + value.str.extractall(pat=pat_re, group_dim="XX", match_dim="Y") + + with pytest.raises(KeyError): + value.str.extractall(pat=pat_str, group_dim="ZZ", match_dim="ZZ") + + with pytest.raises(KeyError): + value.str.extractall(pat=pat_re, group_dim="ZZ", match_dim="ZZ") + + +def test_extract_single_case(dtype): + pat_str = r"(\w+)_Xy_\d*" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str)) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ_none = xr.DataArray( + [["a", "bab", "abc"], ["abcd", "", "abcdef"]], dims=["X", "Y"] + ).astype(dtype) + targ_dim = xr.DataArray( + [[["a"], ["bab"], ["abc"]], [["abcd"], [""], ["abcdef"]]], dims=["X", "Y", "XX"] + ).astype(dtype) + + res_str_none = value.str.extract(pat=pat_str, dim=None) + res_str_dim = value.str.extract(pat=pat_str, dim="XX") + res_str_none_case = value.str.extract(pat=pat_str, dim=None, case=True) + res_str_dim_case = value.str.extract(pat=pat_str, dim="XX", case=True) + res_re_none = value.str.extract(pat=pat_re, dim=None) + res_re_dim = value.str.extract(pat=pat_re, dim="XX") + + assert res_str_none.dtype == targ_none.dtype + assert res_str_dim.dtype == targ_dim.dtype + assert res_str_none_case.dtype == targ_none.dtype + assert res_str_dim_case.dtype == targ_dim.dtype + assert res_re_none.dtype == targ_none.dtype + assert res_re_dim.dtype == targ_dim.dtype + + assert_equal(res_str_none, targ_none) + assert_equal(res_str_dim, targ_dim) + assert_equal(res_str_none_case, targ_none) + assert_equal(res_str_dim_case, 
targ_dim) + assert_equal(res_re_none, targ_none) + assert_equal(res_re_dim, targ_dim) + + +def test_extract_single_nocase(dtype): + pat_str = r"(\w+)_Xy_\d*" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str), flags=re.IGNORECASE) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ_none = xr.DataArray( + [["a", "ab", "abc"], ["abcd", "", "abcdef"]], dims=["X", "Y"] + ).astype(dtype) + targ_dim = xr.DataArray( + [[["a"], ["ab"], ["abc"]], [["abcd"], [""], ["abcdef"]]], dims=["X", "Y", "XX"] + ).astype(dtype) + + res_str_none = value.str.extract(pat=pat_str, dim=None, case=False) + res_str_dim = value.str.extract(pat=pat_str, dim="XX", case=False) + res_re_none = value.str.extract(pat=pat_re, dim=None) + res_re_dim = value.str.extract(pat=pat_re, dim="XX") + + assert res_re_dim.dtype == targ_none.dtype + assert res_str_dim.dtype == targ_dim.dtype + assert res_re_none.dtype == targ_none.dtype + assert res_re_dim.dtype == targ_dim.dtype + + assert_equal(res_str_none, targ_none) + assert_equal(res_str_dim, targ_dim) + assert_equal(res_re_none, targ_none) + assert_equal(res_re_dim, targ_dim) + + +def test_extract_multi_case(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str)) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ = xr.DataArray( + [ + [["a", "0"], ["bab", "110"], ["abc", "01"]], + [["abcd", ""], ["", ""], ["abcdef", "101"]], + ], + dims=["X", "Y", "XX"], + ).astype(dtype) + + res_str = value.str.extract(pat=pat_str, dim="XX") + res_re = value.str.extract(pat=pat_re, dim="XX") + res_str_case = value.str.extract(pat=pat_str, dim="XX", case=True) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + assert res_str_case.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + assert_equal(res_str_case, targ) + + +def test_extract_multi_nocase(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str), flags=re.IGNORECASE) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ = xr.DataArray( + [ + [["a", "0"], ["ab", "10"], ["abc", "01"]], + [["abcd", ""], ["", ""], ["abcdef", "101"]], + ], + dims=["X", "Y", "XX"], + ).astype(dtype) + + res_str = value.str.extract(pat=pat_str, dim="XX", case=False) + res_re = value.str.extract(pat=pat_re, dim="XX") + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + + +def test_extractall_single_single_case(dtype): + pat_str = r"(\w+)_Xy_\d*" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str)) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], 
+ ).astype(dtype) + + targ = xr.DataArray( + [[[["a"]], [[""]], [["abc"]]], [[["abcd"]], [[""]], [["abcdef"]]]], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + res_str_case = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=True + ) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + assert res_str_case.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + assert_equal(res_str_case, targ) + + +def test_extractall_single_single_nocase(dtype): + pat_str = r"(\w+)_Xy_\d*" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str), flags=re.I) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + targ = xr.DataArray( + [[[["a"]], [["ab"]], [["abc"]]], [[["abcd"]], [[""]], [["abcdef"]]]], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=False + ) + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + + +def test_extractall_single_multi_case(dtype): + pat_str = r"(\w+)_Xy_\d*" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str)) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ = xr.DataArray( + [ + [[["a"], [""], [""]], [["bab"], ["baab"], [""]], [["abc"], ["cbc"], [""]]], + [ + [["abcd"], ["dcd"], ["dccd"]], + [[""], [""], [""]], + [["abcdef"], ["fef"], [""]], + ], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + res_str_case = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=True + ) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + assert res_str_case.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + assert_equal(res_str_case, targ) + + +def test_extractall_single_multi_nocase(dtype): + pat_str = r"(\w+)_Xy_\d*" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str), flags=re.I) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ = xr.DataArray( + [ + [ + [["a"], [""], [""]], + [["ab"], ["bab"], ["baab"]], + [["abc"], ["cbc"], [""]], + ], + [ + [["abcd"], ["dcd"], ["dccd"]], + [[""], [""], [""]], + [["abcdef"], ["fef"], [""]], + ], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=False + ) + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + + assert res_str.dtype == targ.dtype + assert 
res_re.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + + +def test_extractall_multi_single_case(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str)) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + targ = xr.DataArray( + [ + [[["a", "0"]], [["", ""]], [["abc", "01"]]], + [[["abcd", ""]], [["", ""]], [["abcdef", "101"]]], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + res_str_case = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=True + ) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + assert res_str_case.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + assert_equal(res_str_case, targ) + + +def test_extractall_multi_single_nocase(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str), flags=re.I) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + targ = xr.DataArray( + [ + [[["a", "0"]], [["ab", "10"]], [["abc", "01"]]], + [[["abcd", ""]], [["", ""]], [["abcdef", "101"]]], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=False + ) + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + + +def test_extractall_multi_multi_case(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str)) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ = xr.DataArray( + [ + [ + [["a", "0"], ["", ""], ["", ""]], + [["bab", "110"], ["baab", "1100"], ["", ""]], + [["abc", "01"], ["cbc", "2210"], ["", ""]], + ], + [ + [["abcd", ""], ["dcd", "33210"], ["dccd", "332210"]], + [["", ""], ["", ""], ["", ""]], + [["abcdef", "101"], ["fef", "5543210"], ["", ""]], + ], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + res_str_case = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=True + ) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + assert res_str_case.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + assert_equal(res_str_case, targ) + + +def test_extractall_multi_multi_nocase(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str), flags=re.I) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + 
"abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ = xr.DataArray( + [ + [ + [["a", "0"], ["", ""], ["", ""]], + [["ab", "10"], ["bab", "110"], ["baab", "1100"]], + [["abc", "01"], ["cbc", "2210"], ["", ""]], + ], + [ + [["abcd", ""], ["dcd", "33210"], ["dccd", "332210"]], + [["", ""], ["", ""], ["", ""]], + [["abcdef", "101"], ["fef", "5543210"], ["", ""]], + ], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=False + ) + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + + +def test_findall_single_single_case(dtype): + pat_str = r"(\w+)_Xy_\d*" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str)) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + targ = [[["a"], [], ["abc"]], [["abcd"], [], ["abcdef"]]] + targ = [[[conv(x) for x in y] for y in z] for z in targ] + targ = np.array(targ, dtype=np.object_) + targ = xr.DataArray(targ, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + res_str_case = value.str.findall(pat=pat_str, case=True) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + assert res_str_case.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + assert_equal(res_str_case, targ) + + +def test_findall_single_single_nocase(dtype): + pat_str = r"(\w+)_Xy_\d*" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str), flags=re.I) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + targ = [[["a"], ["ab"], ["abc"]], [["abcd"], [], ["abcdef"]]] + targ = [[[conv(x) for x in y] for y in z] for z in targ] + targ = np.array(targ, dtype=np.object_) + print(targ) + targ = xr.DataArray(targ, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str, case=False) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + + +def test_findall_single_multi_case(dtype): + pat_str = r"(\w+)_Xy_\d*" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str)) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ = [ + [["a"], ["bab", "baab"], ["abc", "cbc"]], + [ + ["abcd", "dcd", "dccd"], + [], + ["abcdef", "fef"], + ], + ] + targ = [[[conv(x) for x in y] for y in z] for z in targ] + targ = np.array(targ, dtype=np.object_) + targ = xr.DataArray(targ, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + res_str_case = value.str.findall(pat=pat_str, case=True) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + assert res_str_case.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, 
targ) + assert_equal(res_str_case, targ) + + +def test_findall_single_multi_nocase(dtype): + pat_str = r"(\w+)_Xy_\d*" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str), flags=re.I) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ = [ + [ + ["a"], + ["ab", "bab", "baab"], + ["abc", "cbc"], + ], + [ + ["abcd", "dcd", "dccd"], + [], + ["abcdef", "fef"], + ], + ] + targ = [[[conv(x) for x in y] for y in z] for z in targ] + targ = np.array(targ, dtype=np.object_) + targ = xr.DataArray(targ, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str, case=False) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + + +def test_findall_multi_single_case(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str)) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + targ = [ + [[["a", "0"]], [], [["abc", "01"]]], + [[["abcd", ""]], [], [["abcdef", "101"]]], + ] + targ = [[[tuple(conv(x) for x in y) for y in z] for z in w] for w in targ] + targ = np.array(targ, dtype=np.object_) + targ = xr.DataArray(targ, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + res_str_case = value.str.findall(pat=pat_str, case=True) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + assert res_str_case.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + assert_equal(res_str_case, targ) + + +def test_findall_multi_single_nocase(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str), flags=re.I) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + targ = [ + [[["a", "0"]], [["ab", "10"]], [["abc", "01"]]], + [[["abcd", ""]], [], [["abcdef", "101"]]], + ] + targ = [[[tuple(conv(x) for x in y) for y in z] for z in w] for w in targ] + targ = np.array(targ, dtype=np.object_) + targ = xr.DataArray(targ, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str, case=False) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + + +def test_findall_multi_multi_case(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str)) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ = [ + [ + [["a", "0"]], + [["bab", "110"], ["baab", "1100"]], + [["abc", "01"], ["cbc", "2210"]], + ], + [ + [["abcd", ""], ["dcd", "33210"], ["dccd", "332210"]], + [], + [["abcdef", "101"], ["fef", "5543210"]], + ], + ] + targ = [[[tuple(conv(x) for x in y) for y in z] for 
z in w] for w in targ] + targ = np.array(targ, dtype=np.object_) + targ = xr.DataArray(targ, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + res_str_case = value.str.findall(pat=pat_str, case=True) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + assert res_str_case.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + assert_equal(res_str_case, targ) + + +def test_findall_multi_multi_nocase(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + pat_re = re.compile(conv(pat_str), flags=re.I) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ = [ + [ + [["a", "0"]], + [["ab", "10"], ["bab", "110"], ["baab", "1100"]], + [["abc", "01"], ["cbc", "2210"]], + ], + [ + [["abcd", ""], ["dcd", "33210"], ["dccd", "332210"]], + [], + [["abcdef", "101"], ["fef", "5543210"]], + ], + ] + targ = [[[tuple(conv(x) for x in y) for y in z] for z in w] for w in targ] + targ = np.array(targ, dtype=np.object_) + targ = xr.DataArray(targ, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str, case=False) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == targ.dtype + assert res_re.dtype == targ.dtype + + assert_equal(res_str, targ) + assert_equal(res_re, targ) + + def test_repeat(dtype): values = xr.DataArray(["a", "b", "c", "d"]).astype(dtype) result = values.str.repeat(3) expected = xr.DataArray(["aaa", "bbb", "ccc", "ddd"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -272,11 +1217,13 @@ def test_match(dtype): values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) result = values.str.match(".*(BAD[_]+).*(BAD)") expected = xr.DataArray([True, False]) + assert result.dtype == expected.dtype assert_equal(result, expected) values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) result = values.str.match(".*BAD[_]+.*BAD") expected = xr.DataArray([True, False]) + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -287,63 +1234,129 @@ def test_empty_str_methods(): empty_bool = xr.DataArray(np.empty(shape=(0,), dtype=bool)) empty_bytes = xr.DataArray(np.empty(shape=(0,), dtype="S")) - assert_equal(empty_str, empty.str.title()) - assert_equal(empty_int, empty.str.count("a")) + # TODO: Determine why U and S dtype sizes don't match and figure + # out a reliable way to predict what they should be + + assert empty_bool.dtype == empty.str.contains("a").dtype + assert empty_bool.dtype == empty.str.endswith("a").dtype + assert empty_bool.dtype == empty.str.match("^a").dtype + assert empty_bool.dtype == empty.str.startswith("a").dtype + assert empty_bool.dtype == empty.str.isalnum().dtype + assert empty_bool.dtype == empty.str.isalpha().dtype + assert empty_bool.dtype == empty.str.isdecimal().dtype + assert empty_bool.dtype == empty.str.isdigit().dtype + assert empty_bool.dtype == empty.str.islower().dtype + assert empty_bool.dtype == empty.str.isnumeric().dtype + assert empty_bool.dtype == empty.str.isspace().dtype + assert empty_bool.dtype == empty.str.istitle().dtype + assert empty_bool.dtype == empty.str.isupper().dtype + assert empty_bytes.dtype.kind == empty.str.encode("ascii").dtype.kind + assert empty_int.dtype.kind == 
empty.str.count("a").dtype.kind + assert empty_int.dtype.kind == empty.str.find("a").dtype.kind + assert empty_int.dtype.kind == empty.str.len().dtype.kind + assert empty_int.dtype.kind == empty.str.rfind("a").dtype.kind + assert empty_str.dtype.kind == empty.str.capitalize().dtype.kind + assert empty_str.dtype.kind == empty.str.center(42).dtype.kind + assert empty_str.dtype.kind == empty.str.get(0).dtype.kind + assert empty_str.dtype.kind == empty.str.lower().dtype.kind + assert empty_str.dtype.kind == empty.str.lstrip().dtype.kind + assert empty_str.dtype.kind == empty.str.pad(42).dtype.kind + assert empty_str.dtype.kind == empty.str.repeat(3).dtype.kind + assert empty_str.dtype.kind == empty.str.rstrip().dtype.kind + assert empty_str.dtype.kind == empty.str.slice(step=1).dtype.kind + assert empty_str.dtype.kind == empty.str.slice(stop=1).dtype.kind + assert empty_str.dtype.kind == empty.str.strip().dtype.kind + assert empty_str.dtype.kind == empty.str.swapcase().dtype.kind + assert empty_str.dtype.kind == empty.str.title().dtype.kind + assert empty_str.dtype.kind == empty.str.upper().dtype.kind + assert empty_str.dtype.kind == empty.str.wrap(42).dtype.kind + assert empty_str.dtype.kind == empty_bytes.str.decode("ascii").dtype.kind + assert_equal(empty_bool, empty.str.contains("a")) - assert_equal(empty_bool, empty.str.startswith("a")) assert_equal(empty_bool, empty.str.endswith("a")) - assert_equal(empty_str, empty.str.lower()) - assert_equal(empty_str, empty.str.upper()) - assert_equal(empty_str, empty.str.replace("a", "b")) - assert_equal(empty_str, empty.str.repeat(3)) assert_equal(empty_bool, empty.str.match("^a")) - assert_equal(empty_int, empty.str.len()) + assert_equal(empty_bool, empty.str.startswith("a")) + assert_equal(empty_bool, empty.str.isalnum()) + assert_equal(empty_bool, empty.str.isalpha()) + assert_equal(empty_bool, empty.str.isdecimal()) + assert_equal(empty_bool, empty.str.isdigit()) + assert_equal(empty_bool, empty.str.islower()) + assert_equal(empty_bool, empty.str.isnumeric()) + assert_equal(empty_bool, empty.str.isspace()) + assert_equal(empty_bool, empty.str.istitle()) + assert_equal(empty_bool, empty.str.isupper()) + assert_equal(empty_bytes, empty.str.encode("ascii")) + assert_equal(empty_int, empty.str.count("a")) assert_equal(empty_int, empty.str.find("a")) + assert_equal(empty_int, empty.str.len()) assert_equal(empty_int, empty.str.rfind("a")) - assert_equal(empty_str, empty.str.pad(42)) + assert_equal(empty_str, empty.str.capitalize()) assert_equal(empty_str, empty.str.center(42)) - assert_equal(empty_str, empty.str.slice(stop=1)) - assert_equal(empty_str, empty.str.slice(step=1)) - assert_equal(empty_str, empty.str.strip()) + assert_equal(empty_str, empty.str.get(0)) + assert_equal(empty_str, empty.str.lower()) assert_equal(empty_str, empty.str.lstrip()) + assert_equal(empty_str, empty.str.pad(42)) + assert_equal(empty_str, empty.str.repeat(3)) + assert_equal(empty_str, empty.str.replace("a", "b")) assert_equal(empty_str, empty.str.rstrip()) + assert_equal(empty_str, empty.str.slice(step=1)) + assert_equal(empty_str, empty.str.slice(stop=1)) + assert_equal(empty_str, empty.str.strip()) + assert_equal(empty_str, empty.str.swapcase()) + assert_equal(empty_str, empty.str.title()) + assert_equal(empty_str, empty.str.upper()) assert_equal(empty_str, empty.str.wrap(42)) - assert_equal(empty_str, empty.str.get(0)) assert_equal(empty_str, empty_bytes.str.decode("ascii")) - assert_equal(empty_bytes, empty.str.encode("ascii")) - assert_equal(empty_str, 
empty.str.isalnum()) - assert_equal(empty_str, empty.str.isalpha()) - assert_equal(empty_str, empty.str.isdigit()) - assert_equal(empty_str, empty.str.isspace()) - assert_equal(empty_str, empty.str.islower()) - assert_equal(empty_str, empty.str.isupper()) - assert_equal(empty_str, empty.str.istitle()) - assert_equal(empty_str, empty.str.isnumeric()) - assert_equal(empty_str, empty.str.isdecimal()) - assert_equal(empty_str, empty.str.capitalize()) - assert_equal(empty_str, empty.str.swapcase()) + table = str.maketrans("a", "b") + assert empty_str.dtype.kind == empty.str.translate(table).dtype.kind assert_equal(empty_str, empty.str.translate(table)) def test_ismethods(dtype): values = ["A", "b", "Xy", "4", "3A", "", "TT", "55", "-", " "] - str_s = xr.DataArray(values).astype(dtype) - alnum_e = [True, True, True, True, True, False, True, True, False, False] - alpha_e = [True, True, True, False, False, False, True, False, False, False] - digit_e = [False, False, False, True, False, False, False, True, False, False] - space_e = [False, False, False, False, False, False, False, False, False, True] - lower_e = [False, True, False, False, False, False, False, False, False, False] - upper_e = [True, False, False, False, True, False, True, False, False, False] - title_e = [True, False, True, False, True, False, False, False, False, False] - - assert_equal(str_s.str.isalnum(), xr.DataArray(alnum_e)) - assert_equal(str_s.str.isalpha(), xr.DataArray(alpha_e)) - assert_equal(str_s.str.isdigit(), xr.DataArray(digit_e)) - assert_equal(str_s.str.isspace(), xr.DataArray(space_e)) - assert_equal(str_s.str.islower(), xr.DataArray(lower_e)) - assert_equal(str_s.str.isupper(), xr.DataArray(upper_e)) - assert_equal(str_s.str.istitle(), xr.DataArray(title_e)) + + exp_alnum = [True, True, True, True, True, False, True, True, False, False] + exp_alpha = [True, True, True, False, False, False, True, False, False, False] + exp_digit = [False, False, False, True, False, False, False, True, False, False] + exp_space = [False, False, False, False, False, False, False, False, False, True] + exp_lower = [False, True, False, False, False, False, False, False, False, False] + exp_upper = [True, False, False, False, True, False, True, False, False, False] + exp_title = [True, False, True, False, True, False, False, False, False, False] + + values = xr.DataArray(values).astype(dtype) + + exp_alnum = xr.DataArray(exp_alnum) + exp_alpha = xr.DataArray(exp_alpha) + exp_digit = xr.DataArray(exp_digit) + exp_space = xr.DataArray(exp_space) + exp_lower = xr.DataArray(exp_lower) + exp_upper = xr.DataArray(exp_upper) + exp_title = xr.DataArray(exp_title) + + res_alnum = values.str.isalnum() + res_alpha = values.str.isalpha() + res_digit = values.str.isdigit() + res_lower = values.str.islower() + res_space = values.str.isspace() + res_title = values.str.istitle() + res_upper = values.str.isupper() + + assert res_alnum.dtype == exp_alnum.dtype + assert res_alpha.dtype == exp_alpha.dtype + assert res_digit.dtype == exp_digit.dtype + assert res_lower.dtype == exp_lower.dtype + assert res_space.dtype == exp_space.dtype + assert res_title.dtype == exp_title.dtype + assert res_upper.dtype == exp_upper.dtype + + assert_equal(res_alnum, exp_alnum) + assert_equal(res_alpha, exp_alpha) + assert_equal(res_digit, exp_digit) + assert_equal(res_lower, exp_lower) + assert_equal(res_space, exp_space) + assert_equal(res_title, exp_title) + assert_equal(res_upper, exp_upper) def test_isnumeric(): @@ -352,17 +1365,28 @@ def test_isnumeric(): # 0x1378: ፸ 
ETHIOPIC NUMBER SEVENTY # 0xFF13: 3 Em 3 values = ["A", "3", "¼", "★", "፸", "3", "four"] - s = xr.DataArray(values) - numeric_e = [False, True, True, False, True, True, False] - decimal_e = [False, True, False, False, False, True, False] - assert_equal(s.str.isnumeric(), xr.DataArray(numeric_e)) - assert_equal(s.str.isdecimal(), xr.DataArray(decimal_e)) + exp_numeric = [False, True, True, False, True, True, False] + exp_decimal = [False, True, False, False, False, True, False] + + values = xr.DataArray(values) + exp_numeric = xr.DataArray(exp_numeric) + exp_decimal = xr.DataArray(exp_decimal) + + res_numeric = values.str.isnumeric() + res_decimal = values.str.isdecimal() + + assert res_numeric.dtype == exp_numeric.dtype + assert res_decimal.dtype == exp_decimal.dtype + + assert_equal(res_numeric, exp_numeric) + assert_equal(res_decimal, exp_decimal) def test_len(dtype): values = ["foo", "fooo", "fooooo", "fooooooo"] result = xr.DataArray(values).astype(dtype).str.len() expected = xr.DataArray([len(x) for x in values]) + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -370,33 +1394,51 @@ def test_find(dtype): values = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"]) values = values.astype(dtype) result = values.str.find("EF") - assert_equal(result, xr.DataArray([4, 3, 1, 0, -1])) + expected = xr.DataArray([4, 3, 1, 0, -1]) + assert result.dtype == expected.dtype + assert_equal(result, expected) expected = xr.DataArray([v.find(dtype("EF")) for v in values.values]) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.rfind("EF") - assert_equal(result, xr.DataArray([4, 5, 7, 4, -1])) + expected = xr.DataArray([4, 5, 7, 4, -1]) + assert result.dtype == expected.dtype + assert_equal(result, expected) expected = xr.DataArray([v.rfind(dtype("EF")) for v in values.values]) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.find("EF", 3) - assert_equal(result, xr.DataArray([4, 3, 7, 4, -1])) + expected = xr.DataArray([4, 3, 7, 4, -1]) + assert result.dtype == expected.dtype + assert_equal(result, expected) expected = xr.DataArray([v.find(dtype("EF"), 3) for v in values.values]) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.rfind("EF", 3) - assert_equal(result, xr.DataArray([4, 5, 7, 4, -1])) + expected = xr.DataArray([4, 5, 7, 4, -1]) + assert result.dtype == expected.dtype + assert_equal(result, expected) expected = xr.DataArray([v.rfind(dtype("EF"), 3) for v in values.values]) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.find("EF", 3, 6) - assert_equal(result, xr.DataArray([4, 3, -1, 4, -1])) + expected = xr.DataArray([4, 3, -1, 4, -1]) + assert result.dtype == expected.dtype + assert_equal(result, expected) expected = xr.DataArray([v.find(dtype("EF"), 3, 6) for v in values.values]) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.rfind("EF", 3, 6) - assert_equal(result, xr.DataArray([4, 3, -1, 4, -1])) + expected = xr.DataArray([4, 3, -1, 4, -1]) + assert result.dtype == expected.dtype + assert_equal(result, expected) xp = xr.DataArray([v.rfind(dtype("EF"), 3, 6) for v in values.values]) + assert result.dtype == xp.dtype assert_equal(result, xp) @@ -404,22 +1446,34 @@ def test_index(dtype): s = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"]).astype(dtype) result = s.str.index("EF") - assert_equal(result, xr.DataArray([4, 3, 1, 0])) + 
expected = xr.DataArray([4, 3, 1, 0]) + assert result.dtype == expected.dtype + assert_equal(result, expected) result = s.str.rindex("EF") - assert_equal(result, xr.DataArray([4, 5, 7, 4])) + expected = xr.DataArray([4, 5, 7, 4]) + assert result.dtype == expected.dtype + assert_equal(result, expected) result = s.str.index("EF", 3) - assert_equal(result, xr.DataArray([4, 3, 7, 4])) + expected = xr.DataArray([4, 3, 7, 4]) + assert result.dtype == expected.dtype + assert_equal(result, expected) result = s.str.rindex("EF", 3) - assert_equal(result, xr.DataArray([4, 5, 7, 4])) + expected = xr.DataArray([4, 5, 7, 4]) + assert result.dtype == expected.dtype + assert_equal(result, expected) result = s.str.index("E", 4, 8) - assert_equal(result, xr.DataArray([4, 5, 7, 4])) + expected = xr.DataArray([4, 5, 7, 4]) + assert result.dtype == expected.dtype + assert_equal(result, expected) result = s.str.rindex("E", 0, 5) - assert_equal(result, xr.DataArray([4, 3, 1, 4])) + expected = xr.DataArray([4, 3, 1, 4]) + assert result.dtype == expected.dtype + assert_equal(result, expected) with pytest.raises(ValueError): result = s.str.index("DE") @@ -430,14 +1484,17 @@ def test_pad(dtype): result = values.str.pad(5, side="left") expected = xr.DataArray([" a", " b", " c", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.pad(5, side="right") expected = xr.DataArray(["a ", "b ", "c ", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.pad(5, side="both") expected = xr.DataArray([" a ", " b ", " c ", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -446,14 +1503,17 @@ def test_pad_fillchar(dtype): result = values.str.pad(5, side="left", fillchar="X") expected = xr.DataArray(["XXXXa", "XXXXb", "XXXXc", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.pad(5, side="right", fillchar="X") expected = xr.DataArray(["aXXXX", "bXXXX", "cXXXX", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.pad(5, side="both", fillchar="X") expected = xr.DataArray(["XXaXX", "XXbXX", "XXcXX", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) msg = "fillchar must be a character, not str" @@ -466,6 +1526,7 @@ def test_translate(): table = str.maketrans("abc", "cde") result = values.str.translate(table) expected = xr.DataArray(["cdedefg", "cdee", "edddfg", "edefggg"]) + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -474,14 +1535,17 @@ def test_center_ljust_rjust(dtype): result = values.str.center(5) expected = xr.DataArray([" a ", " b ", " c ", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.ljust(5) expected = xr.DataArray(["a ", "b ", "c ", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.rjust(5) expected = xr.DataArray([" a", " b", " c", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -489,14 +1553,17 @@ def test_center_ljust_rjust_fillchar(dtype): values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"]).astype(dtype) result = values.str.center(5, fillchar="X") expected = xr.DataArray(["XXaXX", "XXbbX", "Xcccc", "ddddd", "eeeeee"]) + assert result.dtype == 
expected.astype(dtype).dtype assert_equal(result, expected.astype(dtype)) result = values.str.ljust(5, fillchar="X") expected = xr.DataArray(["aXXXX", "bbXXX", "ccccX", "ddddd", "eeeeee"]) + assert result.dtype == expected.astype(dtype).dtype assert_equal(result, expected.astype(dtype)) result = values.str.rjust(5, fillchar="X") expected = xr.DataArray(["XXXXa", "XXXbb", "Xcccc", "ddddd", "eeeeee"]) + assert result.dtype == expected.astype(dtype).dtype assert_equal(result, expected.astype(dtype)) # If fillchar is not a charatter, normal str raises TypeError @@ -519,10 +1586,12 @@ def test_zfill(dtype): result = values.str.zfill(5) expected = xr.DataArray(["00001", "00022", "00aaa", "00333", "45678"]) + assert result.dtype == expected.astype(dtype).dtype assert_equal(result, expected.astype(dtype)) result = values.str.zfill(3) expected = xr.DataArray(["001", "022", "aaa", "333", "45678"]) + assert result.dtype == expected.astype(dtype).dtype assert_equal(result, expected.astype(dtype)) @@ -531,6 +1600,7 @@ def test_slice(dtype): result = arr.str.slice(2, 5) exp = xr.DataArray(["foo", "bar", "baz"]).astype(dtype) + assert result.dtype == exp.dtype assert_equal(result, exp) for start, stop, step in [(0, 3, -1), (None, None, -1), (3, 10, 2), (3, 0, -1)]: @@ -549,34 +1619,42 @@ def test_slice_replace(dtype): expected = da(["shrt", "a it longer", "evnlongerthanthat", ""]) result = values.str.slice_replace(2, 3) + assert result.dtype == expected.dtype assert_equal(result, expected) expected = da(["shzrt", "a zit longer", "evznlongerthanthat", "z"]) result = values.str.slice_replace(2, 3, "z") + assert result.dtype == expected.dtype assert_equal(result, expected) expected = da(["shzort", "a zbit longer", "evzenlongerthanthat", "z"]) result = values.str.slice_replace(2, 2, "z") + assert result.dtype == expected.dtype assert_equal(result, expected) expected = da(["shzort", "a zbit longer", "evzenlongerthanthat", "z"]) result = values.str.slice_replace(2, 1, "z") + assert result.dtype == expected.dtype assert_equal(result, expected) expected = da(["shorz", "a bit longez", "evenlongerthanthaz", "z"]) result = values.str.slice_replace(-1, None, "z") + assert result.dtype == expected.dtype assert_equal(result, expected) expected = da(["zrt", "zer", "zat", "z"]) result = values.str.slice_replace(None, -2, "z") + assert result.dtype == expected.dtype assert_equal(result, expected) expected = da(["shortz", "a bit znger", "evenlozerthanthat", "z"]) result = values.str.slice_replace(6, 8, "z") + assert result.dtype == expected.dtype assert_equal(result, expected) expected = da(["zrt", "a zit longer", "evenlongzerthanthat", "z"]) result = values.str.slice_replace(-10, 3, "z") + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -585,14 +1663,17 @@ def test_strip_lstrip_rstrip(dtype): result = values.str.strip() expected = xr.DataArray(["aa", "bb", "cc"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.lstrip() expected = xr.DataArray(["aa ", "bb \n", "cc "]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.rstrip() expected = xr.DataArray([" aa", " bb", "cc"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -601,14 +1682,17 @@ def test_strip_lstrip_rstrip_args(dtype): rs = values.str.strip("x") xp = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) + assert rs.dtype == xp.dtype assert_equal(rs, xp) rs = values.str.lstrip("x") xp = 
xr.DataArray(["ABCxx", " BNSD", "LDFJH xx"]).astype(dtype) + assert rs.dtype == xp.dtype assert_equal(rs, xp) rs = values.str.rstrip("x") xp = xr.DataArray(["xxABC", "xx BNSD", "LDFJH "]).astype(dtype) + assert rs.dtype == xp.dtype assert_equal(rs, xp) @@ -647,6 +1731,7 @@ def test_wrap(): ) result = values.str.wrap(12, break_long_words=True) + assert result.dtype == expected.dtype assert_equal(result, expected) # test with pre and post whitespace (non-unicode), NaN, and non-ascii @@ -654,6 +1739,7 @@ def test_wrap(): values = xr.DataArray([" pre ", "\xac\u20ac\U00008000 abadcafe"]) expected = xr.DataArray([" pre", "\xac\u20ac\U00008000 ab\nadcafe"]) result = values.str.wrap(6) + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -664,10 +1750,12 @@ def test_wrap_kwargs_passed(): result = values.str.wrap(7) expected = xr.DataArray(" hello\nworld") + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.wrap(7, drop_whitespace=False) expected = xr.DataArray(" hello\n world\n ") + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -676,6 +1764,7 @@ def test_get(dtype): result = values.str[2] expected = xr.DataArray(["b", "d", "g"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) # bounds testing @@ -684,11 +1773,13 @@ def test_get(dtype): # positive index result = values.str[5] expected = xr.DataArray(["_", "_", ""]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) # negative index result = values.str[-6] expected = xr.DataArray(["_", "8", ""]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -698,6 +1789,7 @@ def test_get_default(dtype): result = values.str.get(2, "default") expected = xr.DataArray(["b", "default", "default"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) @@ -705,6 +1797,7 @@ def test_encode_decode(): data = xr.DataArray(["a", "b", "a\xe4"]) encoded = data.str.encode("utf-8") decoded = encoded.str.decode("utf-8") + assert data.dtype == decoded.dtype assert_equal(data, decoded) @@ -721,6 +1814,8 @@ def test_encode_decode_errors(): f = lambda x: x.encode("cp1252", "ignore") result = encodeBase.str.encode("cp1252", "ignore") expected = xr.DataArray([f(x) for x in encodeBase.values.tolist()]) + + assert result.dtype == expected.dtype assert_equal(result, expected) decodeBase = xr.DataArray([b"a", b"b", b"a\x9d"]) @@ -735,4 +1830,555 @@ def test_encode_decode_errors(): f = lambda x: x.decode("cp1252", "ignore") result = decodeBase.str.decode("cp1252", "ignore") expected = xr.DataArray([f(x) for x in decodeBase.values.tolist()]) + + assert result.dtype == expected.dtype assert_equal(result, expected) + + +def test_partition_whitespace(dtype): + values = xr.DataArray( + [ + ["abc def", "spam eggs swallow", "red_blue"], + ["test0 test1 test2 test3", "", "abra ka da bra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + exp_part_dim = [ + [ + ["abc", " ", "def"], + ["spam", " ", "eggs swallow"], + ["red_blue", "", ""], + ], + [ + ["test0", " ", "test1 test2 test3"], + ["", "", ""], + ["abra", " ", "ka da bra"], + ], + ] + + exp_rpart_dim = [ + [ + ["abc", " ", "def"], + ["spam eggs", " ", "swallow"], + ["", "", "red_blue"], + ], + [ + ["test0 test1 test2", " ", "test3"], + ["", "", ""], + ["abra ka da", " ", "bra"], + ], + ] + + exp_part_dim = xr.DataArray(exp_part_dim, dims=["X", "Y", "ZZ"]).astype(dtype) + exp_rpart_dim = xr.DataArray(exp_rpart_dim, 
dims=["X", "Y", "ZZ"]).astype(dtype) + + res_part_dim = values.str.partition(dim="ZZ") + res_rpart_dim = values.str.rpartition(dim="ZZ") + + assert res_part_dim.dtype == exp_part_dim.dtype + assert res_rpart_dim.dtype == exp_rpart_dim.dtype + + assert_equal(res_part_dim, exp_part_dim) + assert_equal(res_rpart_dim, exp_rpart_dim) + + +def test_partition_comma(dtype): + values = xr.DataArray( + [ + ["abc, def", "spam, eggs, swallow", "red_blue"], + ["test0, test1, test2, test3", "", "abra, ka, da, bra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + exp_part_dim = [ + [ + ["abc", ", ", "def"], + ["spam", ", ", "eggs, swallow"], + ["red_blue", "", ""], + ], + [ + ["test0", ", ", "test1, test2, test3"], + ["", "", ""], + ["abra", ", ", "ka, da, bra"], + ], + ] + + exp_rpart_dim = [ + [ + ["abc", ", ", "def"], + ["spam, eggs", ", ", "swallow"], + ["", "", "red_blue"], + ], + [ + ["test0, test1, test2", ", ", "test3"], + ["", "", ""], + ["abra, ka, da", ", ", "bra"], + ], + ] + + exp_part_dim = xr.DataArray(exp_part_dim, dims=["X", "Y", "ZZ"]).astype(dtype) + exp_rpart_dim = xr.DataArray(exp_rpart_dim, dims=["X", "Y", "ZZ"]).astype(dtype) + + res_part_dim = values.str.partition(sep=", ", dim="ZZ") + res_rpart_dim = values.str.rpartition(sep=", ", dim="ZZ") + + assert res_part_dim.dtype == exp_part_dim.dtype + assert res_rpart_dim.dtype == exp_rpart_dim.dtype + + assert_equal(res_part_dim, exp_part_dim) + assert_equal(res_rpart_dim, exp_rpart_dim) + + +def test_split_whitespace(dtype): + values = xr.DataArray( + [ + ["abc def", "spam\t\teggs\tswallow", "red_blue"], + ["test0\ntest1\ntest2\n\ntest3", "", "abra ka\nda\tbra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + exp_split_dim_full = [ + [ + ["abc", "def", "", ""], + ["spam", "eggs", "swallow", ""], + ["red_blue", "", "", ""], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ] + + exp_rsplit_dim_full = [ + [ + ["", "", "abc", "def"], + ["", "spam", "eggs", "swallow"], + ["", "", "", "red_blue"], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ] + + exp_split_dim_1 = [ + [["abc", "def"], ["spam", "eggs\tswallow"], ["red_blue", ""]], + [["test0", "test1\ntest2\n\ntest3"], ["", ""], ["abra", "ka\nda\tbra"]], + ] + + exp_rsplit_dim_1 = [ + [["abc", "def"], ["spam\t\teggs", "swallow"], ["", "red_blue"]], + [["test0\ntest1\ntest2", "test3"], ["", ""], ["abra ka\nda", "bra"]], + ] + + exp_split_none_full = [ + [["abc", "def"], ["spam", "eggs", "swallow"], ["red_blue"]], + [["test0", "test1", "test2", "test3"], [], ["abra", "ka", "da", "bra"]], + ] + + exp_rsplit_none_full = [ + [["abc", "def"], ["spam", "eggs", "swallow"], ["red_blue"]], + [["test0", "test1", "test2", "test3"], [], ["abra", "ka", "da", "bra"]], + ] + + exp_split_none_1 = [ + [["abc", "def"], ["spam", "eggs\tswallow"], ["red_blue"]], + [["test0", "test1\ntest2\n\ntest3"], [], ["abra", "ka\nda\tbra"]], + ] + + exp_rsplit_none_1 = [ + [["abc", "def"], ["spam\t\teggs", "swallow"], ["red_blue"]], + [["test0\ntest1\ntest2", "test3"], [], ["abra ka\nda", "bra"]], + ] + + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + + exp_split_none_full = [ + [[conv(x) for x in y] for y in z] for z in exp_split_none_full + ] + exp_rsplit_none_full = [ + [[conv(x) for x in y] for y in z] for z in exp_rsplit_none_full + ] + exp_split_none_1 = [[[conv(x) for x in y] for y in z] for z in exp_split_none_1] + exp_rsplit_none_1 = [[[conv(x) for x in y] for y 
in z] for z in exp_rsplit_none_1] + + exp_split_none_full = np.array(exp_split_none_full, dtype=np.object_) + exp_rsplit_none_full = np.array(exp_rsplit_none_full, dtype=np.object_) + exp_split_none_1 = np.array(exp_split_none_1, dtype=np.object_) + exp_rsplit_none_1 = np.array(exp_rsplit_none_1, dtype=np.object_) + + exp_split_dim_full = xr.DataArray(exp_split_dim_full, dims=["X", "Y", "ZZ"]).astype( + dtype + ) + exp_rsplit_dim_full = xr.DataArray( + exp_rsplit_dim_full, dims=["X", "Y", "ZZ"] + ).astype(dtype) + exp_split_dim_1 = xr.DataArray(exp_split_dim_1, dims=["X", "Y", "ZZ"]).astype(dtype) + exp_rsplit_dim_1 = xr.DataArray(exp_rsplit_dim_1, dims=["X", "Y", "ZZ"]).astype( + dtype + ) + + exp_split_none_full = xr.DataArray(exp_split_none_full, dims=["X", "Y"]) + exp_rsplit_none_full = xr.DataArray(exp_rsplit_none_full, dims=["X", "Y"]) + exp_split_none_1 = xr.DataArray(exp_split_none_1, dims=["X", "Y"]) + exp_rsplit_none_1 = xr.DataArray(exp_rsplit_none_1, dims=["X", "Y"]) + + res_split_dim_full = values.str.split(dim="ZZ") + res_rsplit_dim_full = values.str.rsplit(dim="ZZ") + res_split_dim_1 = values.str.split(dim="ZZ", maxsplit=1) + res_rsplit_dim_1 = values.str.rsplit(dim="ZZ", maxsplit=1) + res_split_dim_10 = values.str.split(dim="ZZ", maxsplit=10) + res_rsplit_dim_10 = values.str.rsplit(dim="ZZ", maxsplit=10) + + res_split_none_full = values.str.split(dim=None) + res_rsplit_none_full = values.str.rsplit(dim=None) + res_split_none_1 = values.str.split(dim=None, maxsplit=1) + res_rsplit_none_1 = values.str.rsplit(dim=None, maxsplit=1) + res_split_none_10 = values.str.split(dim=None, maxsplit=10) + res_rsplit_none_10 = values.str.rsplit(dim=None, maxsplit=10) + + assert res_split_dim_full.dtype == exp_split_dim_full.dtype + assert res_rsplit_dim_full.dtype == exp_rsplit_dim_full.dtype + assert res_split_dim_1.dtype == exp_split_dim_1.dtype + assert res_rsplit_dim_1.dtype == exp_rsplit_dim_1.dtype + assert res_split_dim_10.dtype == exp_split_dim_full.dtype + assert res_rsplit_dim_10.dtype == exp_rsplit_dim_full.dtype + + assert res_split_none_full.dtype == exp_split_none_full.dtype + assert res_rsplit_none_full.dtype == exp_rsplit_none_full.dtype + assert res_split_none_1.dtype == exp_split_none_1.dtype + assert res_rsplit_none_1.dtype == exp_rsplit_none_1.dtype + assert res_split_none_10.dtype == exp_split_none_full.dtype + assert res_rsplit_none_10.dtype == exp_rsplit_none_full.dtype + + assert_equal(res_split_dim_full, exp_split_dim_full) + assert_equal(res_rsplit_dim_full, exp_rsplit_dim_full) + assert_equal(res_split_dim_1, exp_split_dim_1) + assert_equal(res_rsplit_dim_1, exp_rsplit_dim_1) + assert_equal(res_split_dim_10, exp_split_dim_full) + assert_equal(res_rsplit_dim_10, exp_rsplit_dim_full) + + assert_equal(res_split_none_full, exp_split_none_full) + assert_equal(res_rsplit_none_full, exp_rsplit_none_full) + assert_equal(res_split_none_1, exp_split_none_1) + assert_equal(res_rsplit_none_1, exp_rsplit_none_1) + assert_equal(res_split_none_10, exp_split_none_full) + assert_equal(res_rsplit_none_10, exp_rsplit_none_full) + + +def test_split_comma(dtype): + values = xr.DataArray( + [ + ["abc,def", "spam,,eggs,swallow", "red_blue"], + ["test0,test1,test2,test3", "", "abra,ka,da,bra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + exp_split_dim_full = [ + [ + ["abc", "def", "", ""], + ["spam", "", "eggs", "swallow"], + ["red_blue", "", "", ""], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ] + + exp_rsplit_dim_full = 
[ + [ + ["", "", "abc", "def"], + ["spam", "", "eggs", "swallow"], + ["", "", "", "red_blue"], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ] + + exp_split_dim_1 = [ + [["abc", "def"], ["spam", ",eggs,swallow"], ["red_blue", ""]], + [["test0", "test1,test2,test3"], ["", ""], ["abra", "ka,da,bra"]], + ] + + exp_rsplit_dim_1 = [ + [["abc", "def"], ["spam,,eggs", "swallow"], ["", "red_blue"]], + [["test0,test1,test2", "test3"], ["", ""], ["abra,ka,da", "bra"]], + ] + + exp_split_none_full = [ + [["abc", "def"], ["spam", "", "eggs", "swallow"], ["red_blue"]], + [["test0", "test1", "test2", "test3"], [""], ["abra", "ka", "da", "bra"]], + ] + + exp_rsplit_none_full = [ + [["abc", "def"], ["spam", "", "eggs", "swallow"], ["red_blue"]], + [["test0", "test1", "test2", "test3"], [""], ["abra", "ka", "da", "bra"]], + ] + + exp_split_none_1 = [ + [["abc", "def"], ["spam", ",eggs,swallow"], ["red_blue"]], + [["test0", "test1,test2,test3"], [""], ["abra", "ka,da,bra"]], + ] + + exp_rsplit_none_1 = [ + [["abc", "def"], ["spam,,eggs", "swallow"], ["red_blue"]], + [["test0,test1,test2", "test3"], [""], ["abra,ka,da", "bra"]], + ] + + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + + exp_split_none_full = [ + [[conv(x) for x in y] for y in z] for z in exp_split_none_full + ] + exp_rsplit_none_full = [ + [[conv(x) for x in y] for y in z] for z in exp_rsplit_none_full + ] + exp_split_none_1 = [[[conv(x) for x in y] for y in z] for z in exp_split_none_1] + exp_rsplit_none_1 = [[[conv(x) for x in y] for y in z] for z in exp_rsplit_none_1] + + exp_split_none_full = np.array(exp_split_none_full, dtype=np.object_) + exp_rsplit_none_full = np.array(exp_rsplit_none_full, dtype=np.object_) + exp_split_none_1 = np.array(exp_split_none_1, dtype=np.object_) + exp_rsplit_none_1 = np.array(exp_rsplit_none_1, dtype=np.object_) + + exp_split_dim_full = xr.DataArray(exp_split_dim_full, dims=["X", "Y", "ZZ"]).astype( + dtype + ) + exp_rsplit_dim_full = xr.DataArray( + exp_rsplit_dim_full, dims=["X", "Y", "ZZ"] + ).astype(dtype) + exp_split_dim_1 = xr.DataArray(exp_split_dim_1, dims=["X", "Y", "ZZ"]).astype(dtype) + exp_rsplit_dim_1 = xr.DataArray(exp_rsplit_dim_1, dims=["X", "Y", "ZZ"]).astype( + dtype + ) + + exp_split_none_full = xr.DataArray(exp_split_none_full, dims=["X", "Y"]) + exp_rsplit_none_full = xr.DataArray(exp_rsplit_none_full, dims=["X", "Y"]) + exp_split_none_1 = xr.DataArray(exp_split_none_1, dims=["X", "Y"]) + exp_rsplit_none_1 = xr.DataArray(exp_rsplit_none_1, dims=["X", "Y"]) + + res_split_dim_full = values.str.split(sep=",", dim="ZZ") + res_rsplit_dim_full = values.str.rsplit(sep=",", dim="ZZ") + res_split_dim_1 = values.str.split(sep=",", dim="ZZ", maxsplit=1) + res_rsplit_dim_1 = values.str.rsplit(sep=",", dim="ZZ", maxsplit=1) + res_split_dim_10 = values.str.split(sep=",", dim="ZZ", maxsplit=10) + res_rsplit_dim_10 = values.str.rsplit(sep=",", dim="ZZ", maxsplit=10) + + res_split_none_full = values.str.split(sep=",", dim=None) + res_rsplit_none_full = values.str.rsplit(sep=",", dim=None) + res_split_none_1 = values.str.split(sep=",", dim=None, maxsplit=1) + res_rsplit_none_1 = values.str.rsplit(sep=",", dim=None, maxsplit=1) + res_split_none_10 = values.str.split(sep=",", dim=None, maxsplit=10) + res_rsplit_none_10 = values.str.rsplit(sep=",", dim=None, maxsplit=10) + + assert res_split_dim_full.dtype == exp_split_dim_full.dtype + assert res_rsplit_dim_full.dtype == exp_rsplit_dim_full.dtype + assert 
res_split_dim_1.dtype == exp_split_dim_1.dtype + assert res_rsplit_dim_1.dtype == exp_rsplit_dim_1.dtype + assert res_split_dim_10.dtype == exp_split_dim_full.dtype + assert res_rsplit_dim_10.dtype == exp_rsplit_dim_full.dtype + + assert res_split_none_full.dtype == exp_split_none_full.dtype + assert res_rsplit_none_full.dtype == exp_rsplit_none_full.dtype + assert res_split_none_1.dtype == exp_split_none_1.dtype + assert res_rsplit_none_1.dtype == exp_rsplit_none_1.dtype + assert res_split_none_10.dtype == exp_split_none_full.dtype + assert res_rsplit_none_10.dtype == exp_rsplit_none_full.dtype + + assert_equal(res_split_dim_full, exp_split_dim_full) + assert_equal(res_rsplit_dim_full, exp_rsplit_dim_full) + assert_equal(res_split_dim_1, exp_split_dim_1) + assert_equal(res_rsplit_dim_1, exp_rsplit_dim_1) + assert_equal(res_split_dim_10, exp_split_dim_full) + assert_equal(res_rsplit_dim_10, exp_rsplit_dim_full) + + assert_equal(res_split_none_full, exp_split_none_full) + assert_equal(res_rsplit_none_full, exp_rsplit_none_full) + assert_equal(res_split_none_1, exp_split_none_1) + assert_equal(res_rsplit_none_1, exp_rsplit_none_1) + assert_equal(res_split_none_10, exp_split_none_full) + assert_equal(res_rsplit_none_10, exp_rsplit_none_full) + + +def test_get_dummies(dtype): + values_line = xr.DataArray( + [["a|ab~abc|abc", "ab", "a||abc|abcd"], ["abcd|ab|a", "abc|ab~abc", "|a"]], + dims=["X", "Y"], + ).astype(dtype) + values_comma = xr.DataArray( + [["a~ab|abc~~abc", "ab", "a~abc~abcd"], ["abcd~ab~a", "abc~ab|abc", "~a"]], + dims=["X", "Y"], + ).astype(dtype) + + vals_line = np.array(["a", "ab", "abc", "abcd", "ab~abc"]).astype(dtype) + vals_comma = np.array(["a", "ab", "abc", "abcd", "ab|abc"]).astype(dtype) + targ = [ + [ + [True, False, True, False, True], + [False, True, False, False, False], + [True, False, True, True, False], + ], + [ + [True, True, False, True, False], + [False, False, True, False, True], + [True, False, False, False, False], + ], + ] + targ = np.array(targ) + targ = xr.DataArray(targ, dims=["X", "Y", "ZZ"]) + targ_line = targ.copy() + targ_comma = targ.copy() + targ_line.coords["ZZ"] = vals_line + targ_comma.coords["ZZ"] = vals_comma + + res_default = values_line.str.get_dummies(dim="ZZ") + res_line = values_line.str.get_dummies(dim="ZZ", sep="|") + res_comma = values_comma.str.get_dummies(dim="ZZ", sep="~") + + assert res_default.dtype == targ_line.dtype + assert res_line.dtype == targ_line.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_default, targ_line) + assert_equal(res_line, targ_line) + assert_equal(res_comma, targ_comma) + + +def test_splitters_empty_str(dtype): + values = xr.DataArray( + [["", "", ""], ["", "", ""]], + dims=["X", "Y"], + ).astype(dtype) + + conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] + + targ_partition_dim = xr.DataArray( + [ + [["", "", ""], ["", "", ""], ["", "", ""]], + [["", "", ""], ["", "", ""], ["", "", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + + targ_partition_none = [ + [["", "", ""], ["", "", ""], ["", "", ""]], + [["", "", ""], ["", "", ""], ["", "", "", ""]], + ] + targ_partition_none = [ + [[conv(x) for x in y] for y in z] for z in targ_partition_none + ] + targ_partition_none = np.array(targ_partition_none, dtype=np.object_) + del targ_partition_none[-1, -1][-1] + targ_partition_none = xr.DataArray( + targ_partition_none, + dims=["X", "Y"], + ) + + targ_split_dim = xr.DataArray( + [[[""], [""], [""]], [[""], [""], [""]]], + dims=["X", "Y", "ZZ"], + 
).astype(dtype) + targ_split_none = xr.DataArray( + np.array([[[], [], []], [[], [], [""]]], dtype=np.object_), + dims=["X", "Y"], + ) + del targ_split_none.data[-1, -1][-1] + + res_partition_dim = values.str.partition(dim="ZZ") + res_rpartition_dim = values.str.rpartition(dim="ZZ") + res_partition_none = values.str.partition(dim=None) + res_rpartition_none = values.str.rpartition(dim=None) + + res_split_dim = values.str.split(dim="ZZ") + res_rsplit_dim = values.str.rsplit(dim="ZZ") + res_split_none = values.str.split(dim=None) + res_rsplit_none = values.str.rsplit(dim=None) + + res_dummies = values.str.rsplit(dim="ZZ") + + assert res_partition_dim.dtype == targ_partition_dim.dtype + assert res_rpartition_dim.dtype == targ_partition_dim.dtype + assert res_partition_none.dtype == targ_partition_none.dtype + assert res_rpartition_none.dtype == targ_partition_none.dtype + + assert res_split_dim.dtype == targ_split_dim.dtype + assert res_rsplit_dim.dtype == targ_split_dim.dtype + assert res_split_none.dtype == targ_split_none.dtype + assert res_rsplit_none.dtype == targ_split_none.dtype + + assert res_dummies.dtype == targ_split_dim.dtype + + assert_equal(res_partition_dim, targ_partition_dim) + assert_equal(res_rpartition_dim, targ_partition_dim) + assert_equal(res_partition_none, targ_partition_none) + assert_equal(res_rpartition_none, targ_partition_none) + + assert_equal(res_split_dim, targ_split_dim) + assert_equal(res_rsplit_dim, targ_split_dim) + assert_equal(res_split_none, targ_split_none) + assert_equal(res_rsplit_none, targ_split_none) + + assert_equal(res_dummies, targ_split_dim) + + +def test_splitters_empty_array(dtype): + values = xr.DataArray( + [[], []], + dims=["X", "Y"], + ).astype(dtype) + + targ_dim = xr.DataArray( + np.empty([2, 0, 0]), + dims=["X", "Y", "ZZ"], + ).astype(dtype) + targ_none = xr.DataArray( + np.empty([2, 0]), + dims=["X", "Y"], + ).astype(np.object_) + + res_part_dim = values.str.partition(dim="ZZ") + res_rpart_dim = values.str.rpartition(dim="ZZ") + res_part_none = values.str.partition(dim=None) + res_rpart_none = values.str.rpartition(dim=None) + + res_split_dim = values.str.split(dim="ZZ") + res_rsplit_dim = values.str.rsplit(dim="ZZ") + res_split_none = values.str.split(dim=None) + res_rsplit_none = values.str.rsplit(dim=None) + + res_dummies = values.str.get_dummies(dim="ZZ") + + assert res_part_dim.dtype == targ_dim.dtype + assert res_rpart_dim.dtype == targ_dim.dtype + assert res_part_none.dtype == targ_none.dtype + assert res_rpart_none.dtype == targ_none.dtype + + assert res_split_dim.dtype == targ_dim.dtype + assert res_rsplit_dim.dtype == targ_dim.dtype + assert res_split_none.dtype == targ_none.dtype + assert res_rsplit_none.dtype == targ_none.dtype + + assert res_dummies.dtype == targ_dim.dtype + + assert_equal(res_part_dim, targ_dim) + assert_equal(res_rpart_dim, targ_dim) + assert_equal(res_part_none, targ_none) + assert_equal(res_rpart_none, targ_none) + + assert_equal(res_split_dim, targ_dim) + assert_equal(res_rsplit_dim, targ_dim) + assert_equal(res_split_none, targ_none) + assert_equal(res_rsplit_none, targ_none) + + assert_equal(res_dummies, targ_dim) From 787216a598a556d882b97a875938ba4eccc64511 Mon Sep 17 00:00:00 2001 From: Todd Jennings Date: Mon, 7 Dec 2020 00:14:07 -0500 Subject: [PATCH 05/14] implement cat, join, format, +, *, and % --- xarray/core/accessor_str.py | 276 ++++++++++++++++- xarray/tests/test_accessor_str.py | 481 +++++++++++++++++++++++++++++- 2 files changed, 752 insertions(+), 5 deletions(-) diff --git 
a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 5e4d1f1ae05..39cbd6d53fc 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -77,6 +77,8 @@ class StringAccessor: array([4, 4, 2, 2, 5]) Dimensions without coordinates: dim_0 + It also implements `+`, `*`, and `%`, which operate as elementwise + versions of the corresponding `str` methods. """ __slots__ = ("_obj",) @@ -88,17 +90,21 @@ def __init__(self, obj): def _stringify( self, invar: Any, - ) -> Union[str, bytes]: + ) -> Union[str, bytes, Any]: """ Convert a string-like to the correct string/bytes type. This is mostly here to tell mypy a pattern is a str/bytes not a re.Pattern. """ - return self._obj.dtype.type(invar) + if hasattr(invar, "astype"): + return invar.astype(self._obj.dtype.kind) + else: + return self._obj.dtype.type(invar) def _apply( self, - f: Callable, + func: Callable, + *args, obj: Any = None, dtype: Union[str, np.dtype] = None, output_core_dims: Union[list, tuple] = ((),), @@ -116,8 +122,9 @@ def _apply( dask_gufunc_kwargs["output_sizes"] = output_sizes return apply_ufunc( - f, + func, obj, + *args, vectorize=True, dask="parallelized", output_dtypes=[dtype], @@ -171,6 +178,29 @@ def __getitem__( else: return self.get(key) + def __add__( + self, + other: Any, + ) -> Any: + return self.cat(other, sep="") + + def __mul__( + self, + num: int, + ) -> Any: + if num <= 0: + return self[:0] + if num == 1: + return self._obj.copy() + else: + return self.repeat(num) + + def __mod__( + self, + other: Any, + ) -> Any: + return self._apply(lambda x: x % other) + def get( self, i: int, @@ -270,6 +300,244 @@ def f(x): return self._apply(f) + def cat( + self, + *others, + sep: Any = "", + ) -> Any: + """ + Concatenate strings elementwise in the DataArray with other strings. + + The other strings can either be string scalars or other array-like. + Dimensions are automatically broadcast together. + + An optional separator can also be specified. + + Parameters + ---------- + *others : str or array-like of str + Strings or array-like of strings to concatenate elementwise with + the current DataArray. + sep : str or array-like, default `""`. + Seperator to use between strings. + It is broadcast in the same way as the other input strings. + If array-like, its dimensions will be placed at the end of the output array dimensions. + + Returns + ------- + concatenated : same type as values + + Examples + -------- + Create a string array + + >>> myarray = xr.DataArray( + ... ["11111", "4"], + ... dims=["X"], + ... ) + + Create some arrays to concatenate with it + + >>> values_1 = xr.DataArray( + ... ["a", "bb", "cccc"], + ... dims=["Y"], + ... ) + >>> values_2 = np.array(3.4) + >>> values_3 = "" + >>> values_4 = np.array("test", dtype=np.unicode_) + + Determine the separator to use + + >>> seps = xr.DataArray( + ... [" ", ", "], + ... dims=["ZZ"], + ... ) + + Concatenate the arrays using the separator + + >>> myarray.str.cat(values_1, values_2, values_3, values_4, sep=seps) + + array([[['11111 a 3.4 test', '11111, a, 3.4, , test'], + ['11111 bb 3.4 test', '11111, bb, 3.4, , test'], + ['11111 cccc 3.4 test', '11111, cccc, 3.4, , test']], + + [['4 a 3.4 test', '4, a, 3.4, , test'], + ['4 bb 3.4 test', '4, bb, 3.4, , test'], + ['4 cccc 3.4 test', '4, cccc, 3.4, , test']]], dtype=' Any: + """ + Concatenate strings in a DataArray along a particular dimension. + + An optional separator can also be specified. 
+ + Parameters + ---------- + dim : Hashable, optional + Dimension along which the strings should be concatenated. + Optional for 0D or 1D DataArrays, required for multidimensional DataArrays. + sep : str or array-like, default `""`. + Seperator to use between strings. + It is broadcast in the same way as the other input strings. + If array-like, its dimensions will be placed at the end of the output array dimensions. + + Returns + ------- + joined : same type as values + + Examples + -------- + Create an array + + >>> values = xr.DataArray( + ... [["a", "bab", "abc"], ["abcd", "", "abcdef"]], + ... dims=["X", "Y"], + ... ) + + Determine the separator + + >>> seps = xr.DataArray( + ... ["-", "_"], + ... dims=["ZZ"], + ... ) + + Join the strings along a given dimension + + >>> values.str.join(dim="Y", sep=seps) + + array([['a-bab-abc', 'a_bab_abc'], + ['abcd--abcdef', 'abcd__abcdef']], dtype=' 1 and dim is None: + raise ValueError("Dimension must be specified for multidimensional arrays.") + + if self._obj.ndim > 1: + # Move the target dimension to the start and split along it + dimshifted = list(self._obj.transpose(dim, ...)) + elif self._obj.ndim == 1: + dimshifted = list(self._obj) + else: + dimshifted = [self._obj] + + start, *others = dimshifted + + # concatenate the resulting arrays + return start.str.cat(*others, sep=sep) + + def format( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """ + Perform python string formatting on each element of the DataArray. + + This is equivalent to calling `str.format` on every element of the + DataArray. The replacement values can either be a string-like + scalar or an array-like of string-like values. If array-like, + the values will be broadcast and applied elementwiseto the input + DataArray. + + .. note:: + Array-like values provided as `*args` will have their + dimensions added even if those arguments are not used in any + string formatting. + + .. warning:: + Array-like arguments are only applied elementwise for `*args`. + For `**kwargs`, values are used as-is. + + Parameters + ---------- + *args : str or bytes or array-like of str or bytes + Values for positional formatting. + If array-like, the values are broadcast and applied elementwise. + The dimensions will be placed at the end of the output array dimensions + in the order they are provided. + **kwargs : str or bytes or array-like of str or bytes + Values for keyword-based formatting. + These are **not** broadcast or applied elementwise. + + Returns + ------- + formatted : same type as values + + Examples + -------- + Create an array to format. + + >>> values = xr.DataArray( + ... ["{} is {adj0}", "{} and {} are {adj1}"], + ... dims=["X"], + ... ) + + Set the values to fill. + + >>> noun0 = xr.DataArray( + ... ["spam", "egg"], + ... dims=["Y"], + ... ) + >>> noun1 = xr.DataArray( + ... ["lancelot", "arthur"], + ... dims=["ZZ"], + ... ) + >>> adj0 = "unexpected" + >>> adj1 = "like a duck" + + Insert the values into the array + + >>> values.str.format(noun0, noun1, adj0=adj0, adj1=adj1) + + array([[['spam is unexpected', 'spam is unexpected'], + ['egg is unexpected', 'egg is unexpected']], + + [['spam and lancelot are like a duck', + 'spam and arthur are like a duck'], + ['egg and lancelot are like a duck', + 'egg and arthur are like a duck']]], dtype=' Any: """ Convert strings in the array to be capitalized. 
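The hunks above add ``cat``, ``join``, ``format`` and the ``+``, ``*``, and ``%`` operators to the accessor; the test changes below exercise them. As a rough usage sketch distilled from the docstrings above (the array values here are illustrative only, not part of the diff):

    import xarray as xr

    values = xr.DataArray(["a", "bb", "cccc"], dims=["Y"])
    seps = xr.DataArray(["-", "_"], dims=["ZZ"])

    # cat: elementwise concatenation, broadcasting the other operand and sep
    spaced = values.str.cat("111", sep=" ")      # ["a 111", "bb 111", "cccc 111"]

    # join: reduce along one dimension; an array-like sep appends its dims
    joined = values.str.join(dim="Y", sep=seps)  # ["a-bb-cccc", "a_bb_cccc"]

    # format: positional replacement values are applied elementwise
    fmt = xr.DataArray(["{} and {}"], dims=["X"])
    greeting = fmt.str.format("spam", "eggs")    # ["spam and eggs"]

    # the new operators delegate to the methods (+ is cat with sep="")
    banged = values.str + "!"                    # ["a!", "bb!", "cccc!"]
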
diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index b2b5fe883a7..bbc9659668c 100644 --- a/xarray/tests/test_accessor_str.py +++ b/xarray/tests/test_accessor_str.py @@ -46,7 +46,7 @@ import xarray as xr -from . import assert_equal, requires_dask +from . import assert_equal, assert_identical, requires_dask @pytest.fixture(params=[np.str_, np.bytes_]) @@ -1206,9 +1206,16 @@ def test_findall_multi_multi_nocase(dtype): def test_repeat(dtype): values = xr.DataArray(["a", "b", "c", "d"]).astype(dtype) + result = values.str.repeat(3) + result_mul = values.str * 3 + expected = xr.DataArray(["aaa", "bbb", "ccc", "ddd"]).astype(dtype) + assert result.dtype == expected.dtype + assert result_mul.dtype == expected.dtype + + assert_equal(result_mul, expected) assert_equal(result, expected) @@ -2382,3 +2389,475 @@ def test_splitters_empty_array(dtype): assert_equal(res_rsplit_none, targ_none) assert_equal(res_dummies, targ_dim) + + +def test_cat_str(dtype): + values_1 = xr.DataArray( + [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], + dims=["X", "Y"], + ).astype(dtype) + values_2 = "111" + + targ_blank = xr.DataArray( + [["a111", "bb111", "cccc111"], ["ddddd111", "eeee111", "fff111"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_space = xr.DataArray( + [["a 111", "bb 111", "cccc 111"], ["ddddd 111", "eeee 111", "fff 111"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_bars = xr.DataArray( + [["a||111", "bb||111", "cccc||111"], ["ddddd||111", "eeee||111", "fff||111"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_comma = xr.DataArray( + [["a, 111", "bb, 111", "cccc, 111"], ["ddddd, 111", "eeee, 111", "fff, 111"]], + dims=["X", "Y"], + ).astype(dtype) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_uniform(dtype): + values_1 = xr.DataArray( + [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], + dims=["X", "Y"], + ).astype(dtype) + values_2 = xr.DataArray( + [["11111", "222", "33"], ["4", "5555", "66"]], + dims=["X", "Y"], + ) + + targ_blank = xr.DataArray( + [["a11111", "bb222", "cccc33"], ["ddddd4", "eeee5555", "fff66"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_space = xr.DataArray( + [["a 11111", "bb 222", "cccc 33"], ["ddddd 4", "eeee 5555", "fff 66"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_bars = xr.DataArray( + [["a||11111", "bb||222", "cccc||33"], ["ddddd||4", "eeee||5555", "fff||66"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_comma = xr.DataArray( + [["a, 11111", "bb, 222", "cccc, 33"], ["ddddd, 4", "eeee, 5555", "fff, 66"]], + dims=["X", "Y"], + ).astype(dtype) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype 
== targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_broadcast_right(dtype): + values_1 = xr.DataArray( + [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], + dims=["X", "Y"], + ).astype(dtype) + values_2 = xr.DataArray( + ["11111", "222", "33"], + dims=["Y"], + ) + + targ_blank = xr.DataArray( + [["a11111", "bb222", "cccc33"], ["ddddd11111", "eeee222", "fff33"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_space = xr.DataArray( + [["a 11111", "bb 222", "cccc 33"], ["ddddd 11111", "eeee 222", "fff 33"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_bars = xr.DataArray( + [["a||11111", "bb||222", "cccc||33"], ["ddddd||11111", "eeee||222", "fff||33"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_comma = xr.DataArray( + [["a, 11111", "bb, 222", "cccc, 33"], ["ddddd, 11111", "eeee, 222", "fff, 33"]], + dims=["X", "Y"], + ).astype(dtype) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_broadcast_left(dtype): + values_1 = xr.DataArray( + ["a", "bb", "cccc"], + dims=["Y"], + ).astype(dtype) + values_2 = xr.DataArray( + [["11111", "222", "33"], ["4", "5555", "66"]], + dims=["X", "Y"], + ) + + targ_blank = ( + xr.DataArray( + [["a11111", "bb222", "cccc33"], ["a4", "bb5555", "cccc66"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_space = ( + xr.DataArray( + [["a 11111", "bb 222", "cccc 33"], ["a 4", "bb 5555", "cccc 66"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_bars = ( + xr.DataArray( + [["a||11111", "bb||222", "cccc||33"], ["a||4", "bb||5555", "cccc||66"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_comma = ( + xr.DataArray( + [["a, 11111", "bb, 222", "cccc, 33"], ["a, 4", "bb, 5555", "cccc, 66"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_broadcast_both(dtype): + values_1 = xr.DataArray( + ["a", "bb", "cccc"], + dims=["Y"], + ).astype(dtype) + values_2 = xr.DataArray( + ["11111", "4"], + dims=["X"], + ) + + targ_blank = ( + xr.DataArray( + [["a11111", "bb11111", "cccc11111"], ["a4", "bb4", "cccc4"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_space = ( + xr.DataArray( 
+ [["a 11111", "bb 11111", "cccc 11111"], ["a 4", "bb 4", "cccc 4"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_bars = ( + xr.DataArray( + [["a||11111", "bb||11111", "cccc||11111"], ["a||4", "bb||4", "cccc||4"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_comma = ( + xr.DataArray( + [["a, 11111", "bb, 11111", "cccc, 11111"], ["a, 4", "bb, 4", "cccc, 4"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_multi(): + dtype = np.unicode_ + values_1 = xr.DataArray( + ["11111", "4"], + dims=["X"], + ) + + values_2 = xr.DataArray( + ["a", "bb", "cccc"], + dims=["Y"], + ).astype(np.bytes_) + + values_3 = np.array(3.4) + + values_4 = "" + + values_5 = np.array("", dtype=np.unicode_) + + sep = xr.DataArray( + [" ", ", "], + dims=["ZZ"], + ).astype(dtype) + + targ = xr.DataArray( + [ + [ + ["11111 a 3.4 ", "11111, a, 3.4, , "], + ["11111 bb 3.4 ", "11111, bb, 3.4, , "], + ["11111 cccc 3.4 ", "11111, cccc, 3.4, , "], + ], + [ + ["4 a 3.4 ", "4, a, 3.4, , "], + ["4 bb 3.4 ", "4, bb, 3.4, , "], + ["4 cccc 3.4 ", "4, cccc, 3.4, , "], + ], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + + res = values_1.str.cat(values_2, values_3, values_4, values_5, sep=sep) + + assert res.dtype == targ.dtype + assert_equal(res, targ) + + +def test_join_vector(dtype): + values = xr.DataArray( + ["a", "bb", "cccc"], + dims=["Y"], + ).astype(dtype) + + targ_blank = xr.DataArray("abbcccc").astype(dtype) + targ_space = xr.DataArray("a bb cccc").astype(dtype) + + res_blank_none = values.str.join() + res_blank_y = values.str.join(dim="Y") + + res_space_none = values.str.join(sep=" ") + res_space_y = values.str.join(sep=" ", dim="Y") + + assert res_blank_none.dtype == targ_blank.dtype + assert res_blank_y.dtype == targ_blank.dtype + assert res_space_none.dtype == targ_space.dtype + assert res_space_y.dtype == targ_space.dtype + + assert_identical(res_blank_none, targ_blank) + assert_identical(res_blank_y, targ_blank) + assert_identical(res_space_none, targ_space) + assert_identical(res_space_y, targ_space) + + +def test_join_2d(dtype): + values = xr.DataArray( + [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_blank_x = xr.DataArray( + ["addddd", "bbeeee", "ccccfff"], + dims=["Y"], + ).astype(dtype) + targ_space_x = xr.DataArray( + ["a ddddd", "bb eeee", "cccc fff"], + dims=["Y"], + ).astype(dtype) + + targ_blank_y = xr.DataArray( + ["abbcccc", "dddddeeeefff"], + dims=["X"], + ).astype(dtype) + targ_space_y = xr.DataArray( + ["a bb cccc", "ddddd eeee fff"], + dims=["X"], + ).astype(dtype) + + res_blank_x = values.str.join(dim="X") + res_blank_y = values.str.join(dim="Y") + + res_space_x = values.str.join(dim="X", sep=" ") + res_space_y = values.str.join(sep=" ", dim="Y") + + assert res_blank_x.dtype == targ_blank_x.dtype + assert res_blank_y.dtype == targ_blank_y.dtype + assert res_space_x.dtype == 
targ_space_x.dtype + assert res_space_y.dtype == targ_space_y.dtype + + assert_identical(res_blank_x, targ_blank_x) + assert_identical(res_blank_y, targ_blank_y) + assert_identical(res_space_x, targ_space_x) + assert_identical(res_space_y, targ_space_y) + + with pytest.raises(ValueError): + values.str.join() + + +def test_join_broadcast(dtype): + values = xr.DataArray( + ["a", "bb", "cccc"], + dims=["X"], + ).astype(dtype) + + sep = xr.DataArray( + [" ", ", "], + dims=["ZZ"], + ).astype(dtype) + + targ = xr.DataArray( + ["a bb cccc", "a, bb, cccc"], + dims=["ZZ"], + ).astype(dtype) + + res = values.str.join(sep=sep) + + assert res.dtype == targ.dtype + assert_identical(res, targ) + + +def test_format_scalar(): + dtype = np.unicode_ + values = xr.DataArray( + ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], + dims=["X"], + ).astype(dtype) + + pos0 = 1 + pos1 = 1.2 + pos2 = "2.3" + X = "'test'" + Y = "X" + ZZ = None + W = "NO!" + + targ = xr.DataArray( + ["1.X.None", "1,1.2,'test','test'", "'test'-X-None"], + dims=["X"], + ).astype(dtype) + + res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) + + assert res.dtype == targ.dtype + assert_equal(res, targ) + + +def test_format_broadcast(dtype): + dtype = np.unicode_ + values = xr.DataArray( + ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], + dims=["X"], + ).astype(dtype) + + pos0 = 1 + pos1 = 1.2 + + pos2 = xr.DataArray( + ["2.3", "3.44444"], + dims=["YY"], + ) + + X = "'test'" + Y = "X" + ZZ = None + W = "NO!" + + targ = xr.DataArray( + [ + ["1.X.None", "1.X.None"], + ["1,1.2,'test','test'", "1,1.2,'test','test'"], + ["'test'-X-None", "'test'-X-None"], + ], + dims=["X", "YY"], + ).astype(dtype) + + res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) + + assert res.dtype == targ.dtype + assert_equal(res, targ) From 0e8011500dae04abe652212b771d7fcc74ff75ba Mon Sep 17 00:00:00 2001 From: Todd Jennings Date: Mon, 7 Dec 2020 23:52:56 -0500 Subject: [PATCH 06/14] support elementwise operations in many str accessor functions --- xarray/core/accessor_str.py | 1067 ++++++++++------- xarray/tests/test_accessor_str.py | 1786 ++++++++++++++++++++--------- 2 files changed, 1952 insertions(+), 901 deletions(-) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 39cbd6d53fc..81ba87b12ca 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -42,7 +42,7 @@ import textwrap from functools import reduce from operator import or_ as set_union -from typing import Any, Callable, Hashable, Mapping, Optional, Pattern, Union +from typing import Any, Callable, Hashable, Mapping, Optional, Pattern, Tuple, Union from unicodedata import normalize import numpy as np @@ -61,8 +61,72 @@ _cpython_optimized_decoders = _cpython_optimized_encoders + ("utf-16", "utf-32") -def _is_str_like(x: Any) -> bool: - return isinstance(x, str) or isinstance(x, bytes) +def _contains_obj_type(*, pat: Any, checker: Any) -> bool: + """Determine if the object fits some rule or is array of objects that do so.""" + if isinstance(checker, type): + targtype = checker + checker = lambda x: isinstance(x, targtype) + + if checker(pat): + return True + + # If it is not an object array it can't contain compiled re + if not getattr(pat, "dtype", "no") == np.object_: + return False + + return _apply_str_ufunc(func=checker, obj=pat).all() + + +def _contains_str_like(pat: Any) -> bool: + """Determine if the object is a str-like or array of str-like.""" + if isinstance(pat, (str, bytes)): + return True + + if not hasattr(pat, "dtype"): 
+ return False + + return pat.dtype.kind == "U" or pat.dtype.kind == "S" + + +def _contains_compiled_re(pat: Any) -> bool: + """Determine if the object is a compiled re or array of compiled re.""" + return _contains_obj_type(pat=pat, checker=re.Pattern) + + +def _contains_callable(pat: Any) -> bool: + """Determine if the object is a callable or array of callables.""" + return _contains_obj_type(pat=pat, checker=callable) + + +def _apply_str_ufunc( + *, + func: Callable, + obj: Any, + dtype: Union[str, np.dtype] = None, + output_core_dims: Union[list, tuple] = ((),), + output_sizes: Mapping[Hashable, int] = None, + func_args: Tuple = (), + func_kwargs: Mapping = {}, +) -> Any: + # TODO handling of na values ? + if dtype is None: + dtype = obj.dtype + + dask_gufunc_kwargs = dict() + if output_sizes is not None: + dask_gufunc_kwargs["output_sizes"] = output_sizes + + return apply_ufunc( + func, + obj, + *func_args, + vectorize=True, + dask="parallelized", + output_dtypes=[dtype], + output_core_dims=output_core_dims, + dask_gufunc_kwargs=dask_gufunc_kwargs, + **func_kwargs, + ) class StringAccessor: @@ -77,12 +141,58 @@ class StringAccessor: array([4, 4, 2, 2, 5]) Dimensions without coordinates: dim_0 - It also implements `+`, `*`, and `%`, which operate as elementwise - versions of the corresponding `str` methods. + It also implements ``+``, ``*``, and ``%``, which operate as elementwise + versions of the corresponding ``str`` methods. These will automatically + broadcast for array-like inputs. + + >>> da1 = xr.DataArray(["first", "second", "third"], dims=["X"]) + >>> da2 = xr.DataArray([1, 2, 3], dims=["Y"]) + >>> da1.str + da2 + + array([['first1', 'first2', 'first3'], + ['second1', 'second2', 'second3'], + ['third1', 'third2', 'third3']], dtype='>> da1 = xr.DataArray(["a", "b", "c", "d"], dims=["X"]) + >>> reps = xr.DataArray([3, 4], dims=["Y"]) + >>> da1.str * reps + + array([['aaa', 'aaaa'], + ['bbb', 'bbbb'], + ['ccc', 'cccc'], + ['ddd', 'dddd']], dtype='>> da1 = xr.DataArray(["%s_%s", "%s-%s", "%s|%s"], dims=["X"]) + >>> da2 = xr.DataArray([1, 2], dims=["Y"]) + >>> da3 = xr.DataArray([0.1, 0.2], dims=["Z"]) + >>> da1.str % (da2, da3) + + array([[['1_0.1', '1_0.2'], + ['2_0.1', '2_0.2']], + + [['1-0.1', '1-0.2'], + ['2-0.1', '2-0.2']], + + [['1|0.1', '1|0.2'], + ['2|0.1', '2|0.2']]], dtype='>> da1 = xr.DataArray(["%(a)s"], dims=["X"]) + >>> da2 = xr.DataArray([1, 2, 3], dims=["Y"]) + >>> da1 % {"a": da2} + + array(['\\narray([1, 2, 3])\\nDimensions without coordinates: Y'], + dtype=object) + Dimensions without coordinates: X """ __slots__ = ("_obj",) - _pattern_type = type(re.compile("")) def __init__(self, obj): self._obj = obj @@ -103,46 +213,38 @@ def _stringify( def _apply( self, + *, func: Callable, - *args, - obj: Any = None, dtype: Union[str, np.dtype] = None, output_core_dims: Union[list, tuple] = ((),), output_sizes: Mapping[Hashable, int] = None, - **kwargs, + func_args: Tuple = (), + func_kwargs: Mapping = {}, ) -> Any: - # TODO handling of na values ? 
- if obj is None: - obj = self._obj - if dtype is None: - dtype = obj.dtype - - dask_gufunc_kwargs = dict() - if output_sizes is not None: - dask_gufunc_kwargs["output_sizes"] = output_sizes - - return apply_ufunc( - func, - obj, - *args, - vectorize=True, - dask="parallelized", - output_dtypes=[dtype], + return _apply_str_ufunc( + obj=self._obj, + func=func, + dtype=dtype, output_core_dims=output_core_dims, - dask_gufunc_kwargs=dask_gufunc_kwargs, - **kwargs, + output_sizes=output_sizes, + func_args=func_args, + func_kwargs=func_kwargs, ) def _re_compile( - self, pat: Union[str, bytes, Pattern], flags: int, case: bool = None - ) -> Pattern: - is_compiled_re = isinstance(pat, self._pattern_type) + self, + *, + pat: Union[str, bytes, Pattern, Any], + flags: int = 0, + case: bool = None, + ) -> Union[Pattern, Any]: + is_compiled_re = isinstance(pat, re.Pattern) if is_compiled_re and flags != 0: - raise ValueError("flags cannot be set when pat is a compiled regex") + raise ValueError("Flags cannot be set when pat is a compiled regex.") if is_compiled_re and case is not None: - raise ValueError("case cannot be set when pat is a compiled regex") + raise ValueError("Case cannot be set when pat is a compiled regex.") if is_compiled_re: # no-op, needed to tell mypy this isn't a string @@ -156,8 +258,15 @@ def _re_compile( if not case: flags |= re.IGNORECASE - pat = self._stringify(pat) - return re.compile(pat, flags=flags) + if getattr(pat, "dtype", None) != np.object_: + pat = self._stringify(pat) + func = lambda x: re.compile(x, flags=flags) + if isinstance(pat, np.ndarray): + # apply_ufunc doesn't work for numpy arrays with output object dtypes + func = np.vectorize(func) + return func(pat) + else: + return _apply_str_ufunc(func=func, obj=pat, dtype=np.object_) def len(self) -> Any: """ @@ -167,7 +276,7 @@ def len(self) -> Any: ------- lengths array : array of int """ - return self._apply(len, dtype=int) + return self._apply(func=len, dtype=int) def __getitem__( self, @@ -186,33 +295,39 @@ def __add__( def __mul__( self, - num: int, + num: Union[int, Any], ) -> Any: - if num <= 0: - return self[:0] - if num == 1: - return self._obj.copy() - else: - return self.repeat(num) + return self.repeat(num) def __mod__( self, other: Any, ) -> Any: - return self._apply(lambda x: x % other) + if isinstance(other, dict): + other = {key: self._stringify(val) for key, val in other.items()} + return self._apply(func=lambda x: x % other) + elif isinstance(other, tuple): + other = tuple(self._stringify(x) for x in other) + return self._apply(func=lambda x, *y: x % y, func_args=other) + else: + return self._apply(func=lambda x, y: x % y, func_args=(other,)) def get( self, - i: int, + i: Union[int, Any], default: Union[str, bytes] = "", ) -> Any: """ Extract character number `i` from each string in the array. + If `i` is array-like, they are broadcast against the array and + applied elementwise. + Parameters ---------- - i : int + i : int or array-like of int Position of element to extract. + If array-like, it is broadcast. default : optional Value for out-of-range index. If not specified (None) defaults to an empty string. 
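With the ``func_args`` plumbing introduced above, accessor parameters that used to be scalar-only now accept array-likes and are broadcast elementwise against the array. A minimal sketch of the intended behaviour (values are illustrative, not taken from the patch):

    import xarray as xr

    values = xr.DataArray(["abcdef", "123456"], dims=["X"])
    indices = xr.DataArray([0, 3], dims=["Y"])

    # the index argument is broadcast against the array, giving dims ("X", "Y")
    chars = values.str.get(indices)             # [["a", "d"], ["1", "4"]]

    # __mod__ now handles tuples (and dicts) of replacement values elementwise
    templates = xr.DataArray(["%s-%s"], dims=["X"])
    filled = templates.str % ("left", "right")  # ["left-right"]
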
@@ -221,63 +336,71 @@ def get( ------- items : array of object """ - s = slice(-1, None) if i == -1 else slice(i, i + 1) - def f(x): - item = x[s] + def f(x, iind): + islice = slice(-1, None) if iind == -1 else slice(iind, iind + 1) + item = x[islice] return item if item else default - return self._apply(f) + return self._apply(func=f, func_args=(i,)) def slice( self, - start: int = None, - stop: int = None, - step: int = None, + start: Union[int, Any] = None, + stop: Union[int, Any] = None, + step: Union[int, Any] = None, ) -> Any: """ Slice substrings from each string in the array. + If `start`, `stop`, or 'step` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - start : int, optional + start : int or array-like of int, optional Start position for slice operation. - stop : int, optional + If array-like, it is broadcast. + stop : int or array-like of int, optional Stop position for slice operation. - step : int, optional + If array-like, it is broadcast. + step : int or array-like of int, optional Step size for slice operation. + If array-like, it is broadcast. Returns ------- sliced strings : same type as values """ - s = slice(start, stop, step) - f = lambda x: x[s] - return self._apply(f) + f = lambda x, istart, istop, istep: x[slice(istart, istop, istep)] + return self._apply(func=f, func_args=(start, stop, step)) def slice_replace( self, - start: int = None, - stop: int = None, - repl: Union[str, bytes] = "", + start: Union[int, Any] = None, + stop: Union[int, Any] = None, + repl: Union[str, bytes, Any] = "", ) -> Any: """ Replace a positional slice of a string with another value. + If `start`, `stop`, or 'repl` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - start : int, optional + start : int or array-like of int, optional Left index position to use for the slice. If not specified (None), the slice is unbounded on the left, i.e. slice from the start - of the string. - stop : int, optional + of the string. If array-like, it is broadcast. + stop : int or array-like of int, optional Right index position to use for the slice. If not specified (None), the slice is unbounded on the right, i.e. slice until the - end of the string. - repl : str, optional + end of the string. If array-like, it is broadcast. + repl : str or array-like of str, optional String for replacement. If not specified, the sliced region - is replaced with an empty string. + is replaced with an empty string. If array-like, it is broadcast. Returns ------- @@ -285,25 +408,25 @@ def slice_replace( """ repl = self._stringify(repl) - def f(x): - if len(x[start:stop]) == 0: - local_stop = start + def func(x, istart, istop, irepl): + if len(x[istart:istop]) == 0: + local_stop = istart else: - local_stop = stop + local_stop = istop y = self._stringify("") - if start is not None: - y += x[:start] - y += repl - if stop is not None: + if istart is not None: + y += x[:istart] + y += irepl + if istop is not None: y += x[local_stop:] return y - return self._apply(f) + return self._apply(func=func, func_args=(start, stop, repl)) def cat( self, *others, - sep: Any = "", + sep: Union[str, bytes, Any] = "", ) -> Any: """ Concatenate strings elementwise in the DataArray with other strings. @@ -311,14 +434,15 @@ def cat( The other strings can either be string scalars or other array-like. Dimensions are automatically broadcast together. - An optional separator can also be specified. 
+ An optional separator `sep` can also be specified. If `sep` is + array-like, it is broadcast against the array and applied elementwise. Parameters ---------- *others : str or array-like of str Strings or array-like of strings to concatenate elementwise with the current DataArray. - sep : str or array-like, default `""`. + sep : str or array-like of str, default: "". Seperator to use between strings. It is broadcast in the same way as the other input strings. If array-like, its dimensions will be placed at the end of the output array dimensions. @@ -366,7 +490,6 @@ def cat( ['4 cccc 3.4 test', '4, cccc, 3.4, , test']]], dtype=' Any: """ Concatenate strings in a DataArray along a particular dimension. - An optional separator can also be specified. + An optional separator `sep` can also be specified. If `sep` is + array-like, it is broadcast against the array and applied elementwise. Parameters ---------- - dim : Hashable, optional + dim : hashable, optional Dimension along which the strings should be concatenated. + Only one dimension is allowed at a time. Optional for 0D or 1D DataArrays, required for multidimensional DataArrays. - sep : str or array-like, default `""`. + sep : str or array-like, default: "". Seperator to use between strings. It is broadcast in the same way as the other input strings. If array-like, its dimensions will be placed at the end of the output array dimensions. @@ -433,7 +558,6 @@ def join( ['abcd--abcdef', 'abcd__abcdef']], dtype=' - [['spam and lancelot are like a duck', - 'spam and arthur are like a duck'], + [['spam and lancelot are like a duck', + 'spam and arthur are like a duck'], ['egg and lancelot are like a duck', - 'egg and arthur are like a duck']]], dtype=' Any: """ @@ -546,7 +671,7 @@ def capitalize(self) -> Any: ------- capitalized : same type as values """ - return self._apply(lambda x: x.capitalize()) + return self._apply(func=lambda x: x.capitalize()) def lower(self) -> Any: """ @@ -556,7 +681,7 @@ def lower(self) -> Any: ------- lowerd : same type as values """ - return self._apply(lambda x: x.lower()) + return self._apply(func=lambda x: x.lower()) def swapcase(self) -> Any: """ @@ -566,7 +691,7 @@ def swapcase(self) -> Any: ------- swapcased : same type as values """ - return self._apply(lambda x: x.swapcase()) + return self._apply(func=lambda x: x.swapcase()) def title(self) -> Any: """ @@ -576,7 +701,7 @@ def title(self) -> Any: ------- titled : same type as values """ - return self._apply(lambda x: x.title()) + return self._apply(func=lambda x: x.title()) def upper(self) -> Any: """ @@ -586,7 +711,7 @@ def upper(self) -> Any: ------- uppered : same type as values """ - return self._apply(lambda x: x.upper()) + return self._apply(func=lambda x: x.upper()) def casefold(self) -> Any: """ @@ -601,7 +726,7 @@ def casefold(self) -> Any: ------- casefolded : same type as values """ - return self._apply(lambda x: x.casefold()) + return self._apply(func=lambda x: x.casefold()) def normalize( self, @@ -615,16 +740,15 @@ def normalize( Parameters ---------- - form : {"NFC", "NFKC", "NFD", and "NFKD"} + form : {"NFC", "NFKC", "NFD", "NFKD"} Unicode form. Returns ------- normalized : same type as values - """ - return self._apply(lambda x: normalize(form, x)) + return self._apply(func=lambda x: normalize(form, x)) def isalnum(self) -> Any: """ @@ -635,7 +759,7 @@ def isalnum(self) -> Any: isalnum : array of bool Array of boolean values with the same shape as the original array. 
""" - return self._apply(lambda x: x.isalnum(), dtype=bool) + return self._apply(func=lambda x: x.isalnum(), dtype=bool) def isalpha(self) -> Any: """ @@ -646,7 +770,7 @@ def isalpha(self) -> Any: isalpha : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isalpha(), dtype=bool) + return self._apply(func=lambda x: x.isalpha(), dtype=bool) def isdecimal(self) -> Any: """ @@ -657,7 +781,7 @@ def isdecimal(self) -> Any: isdecimal : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isdecimal(), dtype=bool) + return self._apply(func=lambda x: x.isdecimal(), dtype=bool) def isdigit(self) -> Any: """ @@ -668,7 +792,7 @@ def isdigit(self) -> Any: isdigit : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isdigit(), dtype=bool) + return self._apply(func=lambda x: x.isdigit(), dtype=bool) def islower(self) -> Any: """ @@ -679,7 +803,7 @@ def islower(self) -> Any: islower : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.islower(), dtype=bool) + return self._apply(func=lambda x: x.islower(), dtype=bool) def isnumeric(self) -> Any: """ @@ -690,7 +814,7 @@ def isnumeric(self) -> Any: isnumeric : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isnumeric(), dtype=bool) + return self._apply(func=lambda x: x.isnumeric(), dtype=bool) def isspace(self) -> Any: """ @@ -701,7 +825,7 @@ def isspace(self) -> Any: isspace : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isspace(), dtype=bool) + return self._apply(func=lambda x: x.isspace(), dtype=bool) def istitle(self) -> Any: """ @@ -712,7 +836,7 @@ def istitle(self) -> Any: istitle : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.istitle(), dtype=bool) + return self._apply(func=lambda x: x.istitle(), dtype=bool) def isupper(self) -> Any: """ @@ -723,13 +847,13 @@ def isupper(self) -> Any: isupper : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isupper(), dtype=bool) + return self._apply(func=lambda x: x.isupper(), dtype=bool) def count( self, - pat: Union[str, bytes, Pattern], + pat: Union[str, bytes, Pattern, Any], flags: int = 0, - case: bool = True, + case: bool = None, ) -> Any: """ Count occurrences of pattern in each string of the array. @@ -738,15 +862,19 @@ def count( pattern is repeated in each of the string elements of the :class:`~xarray.DataArray`. + The pattern `pat` can either be a single ``str`` or `re.Pattern` or + array-like of ``str`` or `re.Pattern`. If array-like, it is broadcast + against the array and applied elementwise. + Parameters ---------- - pat : str or re.Pattern + pat : str or re.Pattern or array-like of str or re.Pattern A string containing a regular expression or a compiled regular - expression object. + expression object. If array-like, it is broadcast. flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. 
case : bool, default: True If True, case sensitive. @@ -757,22 +885,26 @@ def count( ------- counts : array of int """ - pat = self._re_compile(pat, flags, case) + pat = self._re_compile(pat=pat, flags=flags, case=case) - f = lambda x: len(pat.findall(x)) - return self._apply(f, dtype=int) + func = lambda x, ipat: len(ipat.findall(x)) + return self._apply(func=func, func_args=(pat,), dtype=int) def startswith( self, - pat: Union[str, bytes], + pat: Union[str, bytes, Any], ) -> Any: """ Test if the start of each string in the array matches a pattern. + The pattern `pat` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- pat : str Character sequence. Regular expressions are not accepted. + If array-like, it is broadcast. Returns ------- @@ -781,20 +913,24 @@ def startswith( the start of each string element. """ pat = self._stringify(pat) - f = lambda x: x.startswith(pat) - return self._apply(f, dtype=bool) + func = lambda x, y: x.startswith(y) + return self._apply(func=func, func_args=(pat,), dtype=bool) def endswith( self, - pat: Union[str, bytes], + pat: Union[str, bytes, Any], ) -> Any: """ Test if the end of each string in the array matches a pattern. + The pattern `pat` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- pat : str Character sequence. Regular expressions are not accepted. + If array-like, it is broadcast. Returns ------- @@ -803,119 +939,150 @@ def endswith( the end of each string element. """ pat = self._stringify(pat) - f = lambda x: x.endswith(pat) - return self._apply(f, dtype=bool) + func = lambda x, y: x.endswith(y) + return self._apply(func=func, func_args=(pat,), dtype=bool) def pad( self, - width: int, + width: Union[int, Any], side: str = "left", - fillchar: Union[str, bytes] = " ", + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Pad strings in the array up to width. + If `width` or 'fillchar` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with character defined in `fillchar`. + filled with character defined in ``fillchar``. + If array-like, it is broadcast. side : {"left", "right", "both"}, default: "left" Side from which to fill resulting string. - fillchar : str, default: " " - Additional character for filling, default is whitespace. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- filled : same type as values Array with a minimum number of char in each element. 
""" - width = int(width) - fillchar = self._stringify(fillchar) - if len(fillchar) != 1: - raise TypeError("fillchar must be a character, not str") - if side == "left": - f = lambda s: s.rjust(width, fillchar) + func = self.rjust elif side == "right": - f = lambda s: s.ljust(width, fillchar) + func = self.ljust elif side == "both": - f = lambda s: s.center(width, fillchar) + func = self.center else: # pragma: no cover raise ValueError("Invalid side") - return self._apply(f) + return func(width=width, fillchar=fillchar) + + def _padder( + self, + *, + func: Callable, + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", + ) -> Any: + """ + Wrapper function to handle padding operations + """ + fillchar = self._stringify(fillchar) + + def overfunc(x, iwidth, ifillchar): + if len(ifillchar) != 1: + raise TypeError("fillchar must be a character, not str") + return func(x, int(iwidth), ifillchar) + + return self._apply(func=overfunc, func_args=(width, fillchar)) def center( self, - width: int, - fillchar: Union[str, bytes] = " ", + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Pad left and right side of each string in the array. + If `width` or 'fillchar` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with ``fillchar`` - fillchar : str, default: " " - Additional character for filling, default is whitespace + filled with ``fillchar``. If array-like, it is broadcast. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- filled : same type as values """ - return self.pad(width, side="both", fillchar=fillchar) + func = self._obj.dtype.type.center + return self._padder(func=func, width=width, fillchar=fillchar) def ljust( self, - width: int, - fillchar: Union[str, bytes] = " ", + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Pad right side of each string in the array. + If `width` or 'fillchar` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with ``fillchar`` - fillchar : str, default: " " - Additional character for filling, default is whitespace + filled with ``fillchar``. If array-like, it is broadcast. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- filled : same type as values """ - return self.pad(width, side="right", fillchar=fillchar) + func = self._obj.dtype.type.ljust + return self._padder(func=func, width=width, fillchar=fillchar) def rjust( self, - width: int, - fillchar: Union[str, bytes] = " ", + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Pad left side of each string in the array. + If `width` or 'fillchar` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with ``fillchar`` - fillchar : str, default: " " - Additional character for filling, default is whitespace + filled with ``fillchar``. If array-like, it is broadcast. 
+ fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- filled : same type as values """ - return self.pad(width, side="left", fillchar=fillchar) + func = self._obj.dtype.type.rjust + return self._padder(func=func, width=width, fillchar=fillchar) - def zfill( - self, - width: int, - ) -> Any: + def zfill(self, width: Union[int, Any]) -> Any: """ Pad each string in the array by prepending '0' characters. @@ -923,22 +1090,25 @@ def zfill( left of the string to reach a total string length `width`. Strings in the array with length greater or equal to `width` are unchanged. + If `width` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum length of resulting string; strings with length less - than `width` be prepended with '0' characters. + than `width` be prepended with '0' characters. If array-like, it is broadcast. Returns ------- filled : same type as values """ - return self.pad(width, side="left", fillchar="0") + return self.rjust(width, fillchar="0") def contains( self, - pat: Union[str, bytes, Pattern], - case: bool = True, + pat: Union[str, bytes, Pattern, Any], + case: bool = None, flags: int = 0, regex: bool = True, ) -> Any: @@ -948,11 +1118,15 @@ def contains( Return boolean array based on whether a given pattern or regex is contained within a string of the array. + The pattern `pat` can either be a single ``str`` or `re.Pattern` or + array-like of ``str`` or `re.Pattern`. If array-like, it is broadcast + against the array and applied elementwise. + Parameters ---------- - pat : str or re.Pattern + pat : str or re.Pattern or array-like of str or re.Pattern Character sequence, a string containing a regular expression, - or a compiled regular expression object. + or a compiled regular expression object. If array-like, it is broadcast. case : bool, default: True If True, case sensitive. Cannot be set if `pat` is a compiled regex. @@ -960,7 +1134,7 @@ def contains( flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. regex : bool, default: True If True, assumes the pat is a regular expression. @@ -974,42 +1148,54 @@ def contains( given pattern is contained within the string of each element of the array. """ - is_compiled_re = isinstance(pat, self._pattern_type) + is_compiled_re = _contains_compiled_re(pat) if is_compiled_re and not regex: raise ValueError( "Must use regular expression matching for regular expression object." 
) if regex: - pat = self._re_compile(pat, flags, case) - if pat.groups > 0: # pragma: no cover - raise ValueError("This pattern has match groups.") + if not is_compiled_re: + pat = self._re_compile(pat=pat, flags=flags, case=case) + + def func(x, ipat): + if ipat.groups > 0: # pragma: no cover + raise ValueError("This pattern has match groups.") + return bool(ipat.search(x)) - f = lambda x: bool(pat.search(x)) else: pat = self._stringify(pat) if case or case is None: - f = lambda x: pat in x + func = lambda x, ipat: ipat in x + elif self._obj.dtype.char == "U": + uppered = self._obj.str.casefold() + uppat = StringAccessor(pat).casefold() + return uppered.str.contains(uppat, regex=False) else: uppered = self._obj.str.upper() - return uppered.str.contains(pat.upper(), regex=False) + uppat = StringAccessor(pat).upper() + return uppered.str.contains(uppat, regex=False) - return self._apply(f, dtype=bool) + return self._apply(func=func, func_args=(pat,), dtype=bool) def match( self, - pat: Union[str, bytes, Pattern], - case: bool = True, + pat: Union[str, bytes, Pattern, Any], + case: bool = None, flags: int = 0, ) -> Any: """ Determine if each string in the array matches a regular expression. + The pattern `pat` can either be a single ``str`` or `re.Pattern` or + array-like of ``str`` or `re.Pattern`. If array-like, it is broadcast + against the array and applied elementwise. + Parameters ---------- - pat : str or re.Pattern + pat : str or re.Pattern or array-like of str or re.Pattern A string containing a regular expression or - a compiled regular expression object. + a compiled regular expression object. If array-like, it is broadcast. case : bool, default: True If True, case sensitive. Cannot be set if `pat` is a compiled regex. @@ -1017,21 +1203,21 @@ def match( flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. Returns ------- matched : array of bool """ - pat = self._re_compile(pat, flags, case) + pat = self._re_compile(pat=pat, flags=flags, case=case) - f = lambda x: bool(pat.match(x)) - return self._apply(f, dtype=bool) + func = lambda x, ipat: bool(ipat.match(x)) + return self._apply(func=func, func_args=(pat,), dtype=bool) def strip( self, - to_strip: Union[str, bytes] = None, + to_strip: Union[str, bytes, Any] = None, side: str = "both", ) -> Any: """ @@ -1040,13 +1226,16 @@ def strip( Strip whitespaces (including newlines) or a set of specified characters from each string in the array from left and/or right sides. + `to_strip` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- - to_strip : str or None, default: None + to_strip : str or array-like of str or None, default: None Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. - If None then whitespaces are removed. - side : {"left", "right", "both"}, default: "left" + If None then whitespaces are removed. If array-like, it is broadcast. + side : {"left", "right", "both"}, default: "both" Side from which to strip. 
Returns @@ -1057,19 +1246,19 @@ def strip( to_strip = self._stringify(to_strip) if side == "both": - f = lambda x: x.strip(to_strip) + func = lambda x, y: x.strip(y) elif side == "left": - f = lambda x: x.lstrip(to_strip) + func = lambda x, y: x.lstrip(y) elif side == "right": - f = lambda x: x.rstrip(to_strip) + func = lambda x, y: x.rstrip(y) else: # pragma: no cover raise ValueError("Invalid side") - return self._apply(f) + return self._apply(func=func, func_args=(to_strip,)) def lstrip( self, - to_strip: Union[str, bytes] = None, + to_strip: Union[str, bytes, Any] = None, ) -> Any: """ Remove leading characters. @@ -1077,12 +1266,15 @@ def lstrip( Strip whitespaces (including newlines) or a set of specified characters from each string in the array from the left side. + `to_strip` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- - to_strip : str or None, default: None + to_strip : str or array-like of str or None, default: None Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. - If None then whitespaces are removed. + If None then whitespaces are removed. If array-like, it is broadcast. Returns ------- @@ -1092,7 +1284,7 @@ def lstrip( def rstrip( self, - to_strip: Union[str, bytes] = None, + to_strip: Union[str, bytes, Any] = None, ) -> Any: """ Remove trailing characters. @@ -1100,12 +1292,15 @@ def rstrip( Strip whitespaces (including newlines) or a set of specified characters from each string in the array from the right side. + `to_strip` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- - to_strip : str or None, default: None + to_strip : str or array-like of str or None, default: None Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. - If None then whitespaces are removed. + If None then whitespaces are removed. If array-like, it is broadcast. Returns ------- @@ -1115,7 +1310,7 @@ def rstrip( def wrap( self, - width: int, + width: Union[int, Any], **kwargs, ) -> Any: """ @@ -1124,10 +1319,14 @@ def wrap( This method has the same keyword parameters and defaults as :class:`textwrap.TextWrapper`. + If `width` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- - width : int - Maximum line-width + width : int or array-like of int + Maximum line-width. + If array-like, it is broadcast. **kwargs keyword arguments passed into :class:`textwrap.TextWrapper`. @@ -1135,9 +1334,10 @@ def wrap( ------- wrapped : same type as values """ - tw = textwrap.TextWrapper(width=width, **kwargs) - f = lambda x: "\n".join(tw.wrap(x)) - return self._apply(f) + ifunc = lambda x: textwrap.TextWrapper(width=x, **kwargs) + tw = StringAccessor(width)._apply(func=ifunc, dtype=np.object_) + func = lambda x, itw: "\n".join(itw.wrap(x)) + return self._apply(func=func, func_args=(tw,)) def translate( self, @@ -1158,34 +1358,38 @@ def translate( ------- translated : same type as values """ - f = lambda x: x.translate(table) - return self._apply(f) + func = lambda x: x.translate(table) + return self._apply(func=func) def repeat( self, - repeats: int, + repeats: Union[int, Any], ) -> Any: """ - Duplicate each string in the array. + Repeat each string in the array. + + If `repeats` is array-like, it is broadcast against the array and applied + elementwise. 
Parameters ---------- - repeats : int + repeats : int or array-like of int Number of repetitions. + If array-like, it is broadcast. Returns ------- repeated : same type as values Array of repeated string objects. """ - f = lambda x: repeats * x - return self._apply(f) + func = lambda x, y: x * y + return self._apply(func=func, func_args=(repeats,)) def find( self, - sub: Union[str, bytes], - start: int = 0, - end: int = None, + sub: Union[str, bytes, Any], + start: Union[int, Any] = 0, + end: Union[int, Any] = None, side: str = "left", ) -> Any: """ @@ -1193,14 +1397,20 @@ def find( where the substring is fully contained between [start:end]. Return -1 on failure. + If `start`, `end`, or 'sub` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - sub : str - Substring being searched - start : int - Left edge index - end : int - Right edge index + sub : str or array-like of str + Substring being searched. + If array-like, it is broadcast. + start : int or array-like of int + Left edge index. + If array-like, it is broadcast. + end : int or array-like of int + Right edge index. + If array-like, it is broadcast. side : {"left", "right"}, default: "left" Starting side for search. @@ -1217,32 +1427,34 @@ def find( else: # pragma: no cover raise ValueError("Invalid side") - if end is None: - f = lambda x: getattr(x, method)(sub, start) - else: - f = lambda x: getattr(x, method)(sub, start, end) - - return self._apply(f, dtype=int) + func = lambda x, isub, istart, iend: getattr(x, method)(isub, istart, iend) + return self._apply(func=func, func_args=(sub, start, end), dtype=int) def rfind( self, - sub: Union[str, bytes], - start: int = 0, - end: int = None, + sub: Union[str, bytes, Any], + start: Union[int, Any] = 0, + end: Union[int, Any] = None, ) -> Any: """ Return highest indexes in each strings in the array where the substring is fully contained between [start:end]. Return -1 on failure. + If `start`, `end`, or 'sub` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - sub : str - Substring being searched - start : int - Left edge index - end : int - Right edge index + sub : str or array-like of str + Substring being searched. + If array-like, it is broadcast. + start : int or array-like of int + Left edge index. + If array-like, it is broadcast. + end : int or array-like of int + Right edge index. + If array-like, it is broadcast. Returns ------- @@ -1252,9 +1464,9 @@ def rfind( def index( self, - sub: Union[str, bytes], - start: int = 0, - end: int = None, + sub: Union[str, bytes, Any], + start: Union[int, Any] = 0, + end: Union[int, Any] = None, side: str = "left", ) -> Any: """ @@ -1263,14 +1475,20 @@ def index( ``str.find`` except instead of returning -1, it raises a ValueError when the substring is not found. + If `start`, `end`, or 'sub` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - sub : str - Substring being searched - start : int - Left edge index - end : int - Right edge index + sub : str or array-like of str + Substring being searched. + If array-like, it is broadcast. + start : int or array-like of int + Left edge index. + If array-like, it is broadcast. + end : int or array-like of int + Right edge index. + If array-like, it is broadcast. side : {"left", "right"}, default: "left" Starting side for search. 
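# Usage sketch (illustrative, not part of this diff): per the updated
# docstrings above, ``sub``, ``start`` and ``end`` may each be array-like and
# are broadcast elementwise before the underlying ``str.find``/``str.rfind``
# call; the example data here is assumed.
import xarray as xr

values = xr.DataArray(["abcdef", "cdefab"], dims="X")
sub = xr.DataArray(["ab", "cd"], dims="Y")

values.str.find(sub)                 # lowest match index, result dims ("X", "Y")
values.str.find(sub, side="right")   # same as rfind; -1 where the substring is absent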
@@ -1292,18 +1510,14 @@ def index( else: # pragma: no cover raise ValueError("Invalid side") - if end is None: - f = lambda x: getattr(x, method)(sub, start) - else: - f = lambda x: getattr(x, method)(sub, start, end) - - return self._apply(f, dtype=int) + func = lambda x, isub, istart, iend: getattr(x, method)(isub, istart, iend) + return self._apply(func=func, func_args=(sub, start, end), dtype=int) def rindex( self, - sub: Union[str, bytes], - start: int = 0, - end: int = None, + sub: Union[str, bytes, Any], + start: Union[int, Any] = 0, + end: Union[int, Any] = None, ) -> Any: """ Return highest indexes in each strings where the substring is @@ -1311,14 +1525,20 @@ def rindex( ``str.rfind`` except instead of returning -1, it raises a ValueError when the substring is not found. + If `start`, `end`, or 'sub` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - sub : str - Substring being searched - start : int - Left edge index - end : int - Right edge index + sub : str or array-like of str + Substring being searched. + If array-like, it is broadcast. + start : int or array-like of int + Left edge index. + If array-like, it is broadcast. + end : int or array-like of int + Right edge index. + If array-like, it is broadcast. Returns ------- @@ -1333,9 +1553,9 @@ def rindex( def replace( self, - pat: Union[str, bytes, Pattern], - repl: Union[str, bytes, Callable], - n: int = -1, + pat: Union[str, bytes, Pattern, Any], + repl: Union[str, bytes, Callable, Any], + n: Union[int, Any] = -1, case: bool = None, flags: int = 0, regex: bool = True, @@ -1343,16 +1563,22 @@ def replace( """ Replace occurrences of pattern/regex in the array with some string. + If `pat`, `repl`, or 'n` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - pat : str or re.Pattern + pat : str or re.Pattern or array-like of str or re.Pattern String can be a character sequence or regular expression. - repl : str or callable + If array-like, it is broadcast. + repl : str or callable or array-like of str or callable Replacement string or a callable. The callable is passed the regex match object and must return a replacement string to be used. See :func:`re.sub`. - n : int, default: -1 + If array-like, it is broadcast. + n : int or array of int, default: -1 Number of replacements to make from start. Use ``-1`` to replace all. + If array-like, it is broadcast. case : bool, default: True If True, case sensitive. Cannot be set if `pat` is a compiled regex. @@ -1360,7 +1586,7 @@ def replace( flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. regex : bool, default: True If True, assumes the passed-in pattern is a regular expression. @@ -1374,13 +1600,12 @@ def replace( A copy of the object with all matching occurrences of `pat` replaced by `repl`. 
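# Usage sketch (illustrative, not part of this diff): the broadcast-aware
# ``replace`` accepts array-like ``pat``, ``repl`` and ``n``; the values below
# mirror the broadcast cases added to ``test_replace`` further down.
import xarray as xr

values = xr.DataArray(["fooBAD__barBAD"], dims="x")
pat = xr.DataArray(["BAD[_]*", "AD[_]*"], dims="y")
repl = xr.DataArray(["", "spam"], dims="y")

values.str.replace(pat, "")          # -> [["foobar", "fooBbarB"]], dims ("x", "y")
values.str.replace(pat, repl, n=1)   # -> [["foobarBAD", "fooBspambarBAD"]]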
""" - if not _is_str_like(repl) and not callable(repl): # pragma: no cover - raise TypeError("repl must be a string or callable") - - if _is_str_like(repl): + if _contains_str_like(repl): repl = self._stringify(repl) + elif not _contains_callable(repl): # pragma: no cover + raise TypeError("repl must be a string or callable") - is_compiled_re = isinstance(pat, self._pattern_type) + is_compiled_re = _contains_compiled_re(pat) if not regex and is_compiled_re: raise ValueError( "Cannot use a compiled regex as replacement pattern with regex=False" @@ -1390,17 +1615,18 @@ def replace( raise ValueError("Cannot use a callable replacement when regex=False") if regex: - pat = self._re_compile(pat, flags, case) - n = n if n >= 0 else 0 - f = lambda x: pat.sub(repl=repl, string=x, count=n) + pat = self._re_compile(pat=pat, flags=flags, case=case) + func = lambda x, ipat, irepl, i_n: ipat.sub( + repl=irepl, string=x, count=i_n if i_n >= 0 else 0 + ) else: pat = self._stringify(pat) - f = lambda x: x.replace(pat, repl, n) - return self._apply(f) + func = lambda x, ipat, irepl, i_n: x.replace(ipat, irepl, i_n) + return self._apply(func=func, func_args=(pat, repl, n)) def extract( self, - pat: Union[str, bytes, Pattern], + pat: Union[str, bytes, Pattern, Any], dim: Hashable, case: bool = None, flags: int = 0, @@ -1412,12 +1638,15 @@ def extract( For each string in the DataArray, extract groups from the first match of regular expression pat. + If `pat` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- - pat : str or re.Pattern + pat : str or re.Pattern or array-like of str or re.Pattern A string containing a regular expression or a compiled regular - expression object. - dim : hashable or `None` + expression object. If array-like, it is broadcast. + dim : hashable or None Name of the new dimension to store the captured strings in. If None, the pattern must have only one capture group and the resulting DataArray will have the same size as the original. @@ -1428,7 +1657,7 @@ def extract( flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. Returns @@ -1440,7 +1669,7 @@ def extract( ValueError `pat` has no capture groups. ValueError - `dim` is `None` and there is more than one capture group. + `dim` is None and there is more than one capture group. ValueError `case` is set when `pat` is a compiled regular expression. KeyError @@ -1487,20 +1716,29 @@ def extract( re.search pandas.Series.str.extract """ - pat = self._re_compile(pat, flags, case) + pat = self._re_compile(pat=pat, flags=flags, case=case) + + if isinstance(pat, re.Pattern): + maxgroups = pat.groups + else: + maxgroups = ( + _apply_str_ufunc(obj=pat, func=lambda x: x.groups, dtype=np.int_) + .max() + .data.tolist() + ) - if pat.groups == 0: + if maxgroups == 0: raise ValueError("No capture groups found in pattern.") - if dim is None and pat.groups != 1: + if dim is None and maxgroups != 1: raise ValueError( - "dim must be specified if more than one capture group is given." + "Dimension must be specified if more than one capture group is given." 
) if dim is not None and dim in self._obj.dims: - raise KeyError(f"Dimension {dim} already present in DataArray.") + raise KeyError(f"Dimension '{dim}' already present in DataArray.") - def _get_res_single(val, pat=pat): + def _get_res_single(val, pat): match = pat.search(val) if match is None: return "" @@ -1509,7 +1747,7 @@ def _get_res_single(val, pat=pat): res = "" return res - def _get_res_multi(val, pat=pat): + def _get_res_multi(val, pat): match = pat.search(val) if match is None: return np.array([""], val.dtype) @@ -1518,20 +1756,21 @@ def _get_res_multi(val, pat=pat): return np.array(match, val.dtype) if dim is None: - return self._apply(_get_res_single) + return self._apply(func=_get_res_single, func_args=(pat,)) else: # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 return self._apply( - _get_res_multi, + func=_get_res_multi, + func_args=(pat,), dtype=np.object_, output_core_dims=[[dim]], - output_sizes={dim: pat.groups}, + output_sizes={dim: maxgroups}, ).astype(self._obj.dtype.kind) def extractall( self, - pat: Union[str, bytes, Pattern], + pat: Union[str, bytes, Pattern, Any], group_dim: Hashable, match_dim: Hashable, case: bool = None, @@ -1546,15 +1785,18 @@ def extractall( Equivalent to applying re.findall() to all the elements in the DataArray and splitting the results across dimensions. + If `pat` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- pat : str or re.Pattern A string containing a regular expression or a compiled regular - expression object. - group_dim: hashable + expression object. If array-like, it is broadcast. + group_dim : hashable Name of the new dimensions corresponding to the capture groups. This dimension is added to the new DataArray first. - match_dim: hashable + match_dim : hashable Name of the new dimensions corresponding to the matches for each group. This dimension is added to the new DataArray second. case : bool, default: True @@ -1564,7 +1806,7 @@ def extractall( flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. Returns @@ -1634,7 +1876,6 @@ def extractall( ['', '']]]], dtype=' Any: @@ -1700,11 +1958,14 @@ def findall( If there are multiple capture groups, the lists will be a sequence of lists, each of which contains a sequence of matches. + If `pat` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- pat : str or re.Pattern A string containing a regular expression or a compiled regular - expression object. + expression object. If array-like, it is broadcast. case : bool, default: True If True, case sensitive. Cannot be set if `pat` is a compiled regex. @@ -1712,7 +1973,7 @@ def findall( flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. 
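# Usage sketch (illustrative, not part of this diff): ``extract`` (and likewise
# ``extractall``/``findall``) accepts an array-like of patterns after this
# change; the data below mirrors ``test_extract_broadcast`` in the test changes
# further down.
import xarray as xr

value = xr.DataArray(["a_Xy_0", "ab_xY_10", "abc_Xy_01"], dims="X")
pat = xr.DataArray([r"(\w+)_Xy_(\d*)", r"(\w+)_xY_(\d*)"], dims="Y")

value.str.extract(pat=pat, dim="Zz")  # capture groups along "Zz", dims ("X", "Y", "Zz")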
Returns @@ -1765,18 +2026,22 @@ def findall( re.findall pandas.Series.str.findall """ - pat = self._re_compile(pat, flags, case) + pat = self._re_compile(pat=pat, flags=flags, case=case) - if pat.groups == 0: - raise ValueError("No capture groups found in pattern.") + def func(x, ipat): + if ipat.groups == 0: + raise ValueError("No capture groups found in pattern.") + + return ipat.findall(x) - return self._apply(pat.findall, dtype=np.object_) + return self._apply(func=func, func_args=(pat,), dtype=np.object_) def _partitioner( self, + *, func: Callable, dim: Hashable, - sep: Optional[Union[str, bytes]], + sep: Optional[Union[str, bytes, Any]], ) -> Any: """ Implements logic for `partition` and `rpartition`. @@ -1784,19 +2049,20 @@ def _partitioner( sep = self._stringify(sep) if dim is None: - f = lambda x: list(func(x, sep)) - return self._apply(f, dtype=np.object_) + listfunc = lambda x, isep: list(func(x, isep)) + return self._apply(func=listfunc, func_args=(sep,), dtype=np.object_) # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, -1) + return self._obj.copy().expand_dims({dim: 0}, axis=-1) - f = lambda x: np.array(func(x, sep), dtype=self._obj.dtype) + arrfunc = lambda x, isep: np.array(func(x, isep), dtype=self._obj.dtype) # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 return self._apply( - f, + func=arrfunc, + func_args=(sep,), dtype=np.object_, output_core_dims=[[dim]], output_sizes={dim: 3}, @@ -1805,7 +2071,7 @@ def _partitioner( def partition( self, dim: Optional[Hashable], - sep: Union[str, bytes] = " ", + sep: Union[str, bytes, Any] = " ", ) -> Any: """ Split the strings in the DataArray at the first occurrence of separator `sep`. @@ -1816,15 +2082,17 @@ def partition( If the separator is not found, return 3 elements containing the string itself, followed by two empty strings. - This is equivalent to :meth:`str.partion`. + If `sep` is array-like, it is broadcast against the array and applied + elementwise. Parameters ---------- - dim : Hashable or `None` + dim : hashable or None Name for the dimension to place the 3 elements in. - If `None`, place the results as list elements in an object DataArray - sep : str, default `" "` + If `None`, place the results as list elements in an object DataArray. + sep : str, default: " " String to split on. + If array-like, it is broadcast. Returns ------- @@ -1841,7 +2109,7 @@ def partition( def rpartition( self, dim: Optional[Hashable], - sep: Union[str, bytes] = " ", + sep: Union[str, bytes, Any] = " ", ) -> Any: """ Split the strings in the DataArray at the last occurrence of separator `sep`. @@ -1852,15 +2120,17 @@ def rpartition( If the separator is not found, return 3 elements containing two empty strings, followed by the string itself. - This is equivalent to :meth:`str.rpartion`. + If `sep` is array-like, it is broadcast against the array and applied + elementwise. Parameters ---------- - dim : Hashable or `None` + dim : hashable or None Name for the dimension to place the 3 elements in. - If `None`, place the results as list elements in an object DataArray - sep : str, default `" "` + If `None`, place the results as list elements in an object DataArray. + sep : str, default: " " String to split on. + If array-like, it is broadcast. 
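# Usage sketch (illustrative, not part of this diff): ``partition`` and
# ``rpartition`` place the (before, separator, after) triple along a new
# dimension, or return list elements in an object array when ``dim`` is None;
# ``sep`` may also be array-like. The example data here is assumed.
import xarray as xr

values = xr.DataArray(["abc def", "spam eggs ham"], dims="X")

values.str.partition(dim="part")           # dims ("X", "part"), "part" has size 3
values.str.rpartition(dim=None, sep=" ")   # object array of 3-element lists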
Returns ------- @@ -1876,10 +2146,11 @@ def rpartition( def _splitter( self, + *, func: Callable, pre: bool, dim: Hashable, - sep: Optional[Union[str, bytes]], + sep: Optional[Union[str, bytes, Any]], maxsplit: int, ) -> Any: """ @@ -1889,17 +2160,20 @@ def _splitter( sep = self._stringify(sep) if dim is None: - f = lambda x: func(x, sep, maxsplit) - return self._apply(f, dtype=np.object_) + f_none = lambda x, isep: func(x, isep, maxsplit) + return self._apply(func=f_none, func_args=(sep,), dtype=np.object_) # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, -1) + return self._obj.copy().expand_dims({dim: 0}, axis=-1) - f_count = lambda x: max(len(func(x, sep, maxsplit)), 1) - maxsplit = self._apply(f_count, dtype=np.int_).max().data.tolist() - 1 + f_count = lambda x, isep: max(len(func(x, isep, maxsplit)), 1) + maxsplit = ( + self._apply(func=f_count, func_args=(sep,), dtype=np.int_).max().data.item() + - 1 + ) - def _dosplit(mystr, sep=sep, maxsplit=maxsplit, dtype=self._obj.dtype): + def _dosplit(mystr, sep, maxsplit=maxsplit, dtype=self._obj.dtype): res = func(mystr, sep, maxsplit) if len(res) < maxsplit + 1: pad = [""] * (maxsplit + 1 - len(res)) @@ -1912,7 +2186,8 @@ def _dosplit(mystr, sep=sep, maxsplit=maxsplit, dtype=self._obj.dtype): # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 return self._apply( - _dosplit, + func=_dosplit, + func_args=(sep,), dtype=np.object_, output_core_dims=[[dim]], output_sizes={dim: maxsplit}, @@ -1921,7 +2196,7 @@ def _dosplit(mystr, sep=sep, maxsplit=maxsplit, dtype=self._obj.dtype): def split( self, dim: Optional[Hashable], - sep: Union[str, bytes] = None, + sep: Union[str, bytes, Any] = None, maxsplit: int = -1, ) -> Any: """ @@ -1930,18 +2205,20 @@ def split( Splits the string in the DataArray from the beginning, at the specified delimiter string. - This is equivalent to :meth:`str.split`. + If `sep` is array-like, it is broadcast against the array and applied + elementwise. Parameters ---------- - dim : Hashable or `None` + dim : hashable or None Name for the dimension to place the results in. - If `None`, place the results as list elements in an object DataArray - sep : str, default is split on any whitespace. - String to split on. - maxsplit : int, default -1 (all) + If `None`, place the results as list elements in an object DataArray. + sep : str, default: None + String to split on. If ``None`` (the default), split on any whitespace. + If array-like, it is broadcast. + maxsplit : int, default: -1 Limit number of splits in output, starting from the beginning. - -1 will return all splits. + If -1 (the default), return all splits. Returns ------- @@ -2035,8 +2312,8 @@ def split( def rsplit( self, dim: Optional[Hashable], - sep: Union[str, bytes] = None, - maxsplit: int = -1, + sep: Union[str, bytes, Any] = None, + maxsplit: Union[int, Any] = -1, ) -> Any: """ Split strings in a DataArray around the given separator/delimiter `sep`. @@ -2044,18 +2321,20 @@ def rsplit( Splits the string in the DataArray from the end, at the specified delimiter string. - This is equivalent to :meth:`str.rsplit`. + If `sep` is array-like, it is broadcast against the array and applied + elementwise. Parameters ---------- - dim : Hashable or `None` + dim : hashable or None Name for the dimension to place the results in. If `None`, place the results as list elements in an object DataArray - sep : str, default is split on any whitespace. 
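# Usage sketch (illustrative, not part of this diff): with a named ``dim`` the
# splitter above pads every element to the longest split count, while
# ``dim=None`` keeps ragged results as lists in an object array. The example
# data here is assumed.
import xarray as xr

values = xr.DataArray(["a b c", "d e", "f"], dims="X")

values.str.split(dim="token")   # dims ("X", "token"); shorter rows padded with ""
values.str.split(dim=None)      # object array: [["a", "b", "c"], ["d", "e"], ["f"]]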
- String to split on. - maxsplit : int, default -1 (all) + sep : str, default: None + String to split on. If ``None`` (the default), split on any whitespace. + If array-like, it is broadcast. + maxsplit : int, default: -1 Limit number of splits in output, starting from the end. - -1 will return all splits. + If -1 (the default), return all splits. The final number of split values may be less than this if there are no DataArray elements with that many values. @@ -2151,7 +2430,7 @@ def rsplit( def get_dummies( self, dim: Hashable, - sep: Union[str, bytes] = "|", + sep: Union[str, bytes, Any] = "|", ) -> Any: """ Return DataArray of dummy/indicator variables. @@ -2161,12 +2440,16 @@ def get_dummies( and the corresponding element of that dimension is `True` if that result is present and `False` if not. + If `sep` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- - dim : Hashable + dim : hashable Name for the dimension to place the results in. - sep : str, default `"|"`. + sep : str, default: "|". String to split on. + If array-like, it is broadcast. Returns ------- @@ -2205,16 +2488,16 @@ def get_dummies( """ # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, -1) + return self._obj.copy().expand_dims({dim: 0}, axis=-1) sep = self._stringify(sep) - f_set = lambda x: set(x.split(sep)) - {self._stringify("")} - setarr = self._apply(f_set, dtype=np.object_) + f_set = lambda x, isep: set(x.split(isep)) - {self._stringify("")} + setarr = self._apply(func=f_set, func_args=(sep,), dtype=np.object_) vals = sorted(reduce(set_union, setarr.data.ravel())) - f = lambda x: np.array([val in x for val in vals], dtype=np.bool_) - res = self._apply( - f, + func = lambda x: np.array([val in x for val in vals], dtype=np.bool_) + res = _apply_str_ufunc( + func=func, obj=setarr, output_core_dims=[[dim]], output_sizes={dim: len(vals)}, @@ -2234,18 +2517,27 @@ def decode( Parameters ---------- encoding : str + The encoding to use. + Please see the Python `codecs `_ documentation for a list + of encodings handlers errors : str, optional + The handler for encoding errors. + Please see the Python `codecs `_ documentation for a list + of error handlers Returns ------- decoded : same type as values + + .. _encodings: https://docs.python.org/3/library/codecs.html#standard-encodings + .. _handlers: https://docs.python.org/3/library/codecs.html#error-handlers """ if encoding in _cpython_optimized_decoders: - f = lambda x: x.decode(encoding, errors) + func = lambda x: x.decode(encoding, errors) else: decoder = codecs.getdecoder(encoding) - f = lambda x: decoder(x, errors)[0] - return self._apply(f, dtype=np.str_) + func = lambda x: decoder(x, errors)[0] + return self._apply(func=func, dtype=np.str_) def encode( self, @@ -2258,15 +2550,24 @@ def encode( Parameters ---------- encoding : str + The encoding to use. + Please see the Python `codecs `_ documentation for a list + of encodings handlers errors : str, optional + The handler for encoding errors. + Please see the Python `codecs `_ documentation for a list + of error handlers Returns ------- encoded : same type as values + + .. _encodings: https://docs.python.org/3/library/codecs.html#standard-encodings + .. 
_handlers: https://docs.python.org/3/library/codecs.html#error-handlers """ if encoding in _cpython_optimized_encoders: - f = lambda x: x.encode(encoding, errors) + func = lambda x: x.encode(encoding, errors) else: encoder = codecs.getencoder(encoding) - f = lambda x: encoder(x, errors)[0] - return self._apply(f, dtype=np.bytes_) + func = lambda x: encoder(x, errors)[0] + return self._apply(func=func, dtype=np.bytes_) diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index bbc9659668c..9bf33893241 100644 --- a/xarray/tests/test_accessor_str.py +++ b/xarray/tests/test_accessor_str.py @@ -69,20 +69,58 @@ def test_dask(): def test_count(dtype): values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) - result = values.str.count("f[o]+") + pat_str = dtype(r"f[o]+") + pat_re = re.compile(pat_str) + + result_str = values.str.count(pat_str) + result_re = values.str.count(pat_re) + expected = xr.DataArray([1, 2, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + + assert result_str.dtype == expected.dtype + assert result_re.dtype == expected.dtype + assert_equal(result_str, expected) + assert_equal(result_re, expected) + + +def test_count_broadcast(dtype): + values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) + pat_str = np.array([r"f[o]+", r"o", r"m"]).astype(dtype) + pat_re = np.array([re.compile(x) for x in pat_str]) + + result_str = values.str.count(pat_str) + result_re = values.str.count(pat_re) + + expected = xr.DataArray([1, 4, 3]) + + assert result_str.dtype == expected.dtype + assert result_re.dtype == expected.dtype + assert_equal(result_str, expected) + assert_equal(result_re, expected) def test_contains(dtype): values = xr.DataArray(["Foo", "xYz", "fOOomMm__fOo", "MMM_"]).astype(dtype) # case insensitive using regex - result = values.str.contains("FOO|mmm", case=False) + pat = values.dtype.type("FOO|mmm") + result = values.str.contains(pat, case=False) expected = xr.DataArray([True, False, True, True]) assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.contains(re.compile(pat, flags=re.IGNORECASE)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case sensitive using regex + pat = values.dtype.type("Foo|mMm") + result = values.str.contains(pat) + expected = xr.DataArray([True, False, True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.contains(re.compile(pat)) + assert result.dtype == expected.dtype + assert_equal(result, expected) # case insensitive without regex result = values.str.contains("foo", regex=False, case=False) @@ -90,6 +128,87 @@ def test_contains(dtype): assert result.dtype == expected.dtype assert_equal(result, expected) + # case sensitive without regex + result = values.str.contains("fO", regex=False, case=True) + expected = xr.DataArray([False, False, True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # regex regex=False + pat_re = re.compile("(/w+)") + with pytest.raises( + ValueError, + match="Must use regular expression matching for regular expression object.", + ): + values.str.contains(pat_re, regex=False) + + +def test_contains_broadcast(dtype): + values = xr.DataArray(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dims="X").astype( + dtype + ) + pat_str = xr.DataArray(["FOO|mmm", "Foo", "MMM"], dims="Y").astype(dtype) + pat_re = xr.DataArray([re.compile(x) for x in pat_str.data], dims="Y") + + # case 
insensitive using regex + result = values.str.contains(pat_str, case=False) + expected = xr.DataArray( + [ + [True, True, False], + [False, False, False], + [True, True, True], + [True, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case sensitive using regex + result = values.str.contains(pat_str) + expected = xr.DataArray( + [ + [False, True, False], + [False, False, False], + [False, False, False], + [False, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.contains(pat_re) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case insensitive without regex + result = values.str.contains(pat_str, regex=False, case=False) + expected = xr.DataArray( + [ + [False, True, False], + [False, False, False], + [False, True, True], + [False, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case insensitive with regex + result = values.str.contains(pat_str, regex=False, case=True) + expected = xr.DataArray( + [ + [False, True, False], + [False, False, False], + [False, False, False], + [False, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + def test_starts_ends_with(dtype): values = xr.DataArray(["om", "foo_nom", "nom", "bar_foo", "foo"]).astype(dtype) @@ -105,15 +224,37 @@ def test_starts_ends_with(dtype): assert_equal(result, expected) -def test_case_bytes(dtype): - dtype = np.bytes_ - value = xr.DataArray(["SOme wOrd"]).astype(dtype) +def test_starts_ends_with_broadcast(dtype): + values = xr.DataArray( + ["om", "foo_nom", "nom", "bar_foo", "foo_bar"], dims="X" + ).astype(dtype) + pat = xr.DataArray(["foo", "bar"], dims="Y").astype(dtype) - exp_capitalized = xr.DataArray(["Some word"]).astype(dtype) - exp_lowered = xr.DataArray(["some word"]).astype(dtype) - exp_swapped = xr.DataArray(["soME WoRD"]).astype(dtype) - exp_titled = xr.DataArray(["Some Word"]).astype(dtype) - exp_uppered = xr.DataArray(["SOME WORD"]).astype(dtype) + result = values.str.startswith(pat) + expected = xr.DataArray( + [[False, False], [True, False], [False, False], [False, True], [True, False]], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.endswith(pat) + expected = xr.DataArray( + [[False, False], [False, False], [False, False], [True, False], [False, True]], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_case_bytes(): + value = xr.DataArray(["SOme wOrd"]).astype(np.bytes_) + + exp_capitalized = xr.DataArray(["Some word"]).astype(np.bytes_) + exp_lowered = xr.DataArray(["some word"]).astype(np.bytes_) + exp_swapped = xr.DataArray(["soME WoRD"]).astype(np.bytes_) + exp_titled = xr.DataArray(["Some Word"]).astype(np.bytes_) + exp_uppered = xr.DataArray(["SOME WORD"]).astype(np.bytes_) res_capitalized = value.str.capitalize() res_lowered = value.str.lower() @@ -127,31 +268,33 @@ def test_case_bytes(dtype): assert res_titled.dtype == exp_titled.dtype assert res_uppered.dtype == exp_uppered.dtype - assert_equal(value.str.capitalize(), exp_capitalized) - assert_equal(value.str.lower(), exp_lowered) - assert_equal(value.str.swapcase(), exp_swapped) - assert_equal(value.str.title(), exp_titled) - assert_equal(value.str.upper(), exp_uppered) + assert_equal(res_capitalized, 
exp_capitalized) + assert_equal(res_lowered, exp_lowered) + assert_equal(res_swapped, exp_swapped) + assert_equal(res_titled, exp_titled) + assert_equal(res_uppered, exp_uppered) def test_case_str(): - dtype = np.str_ - # This string includes some unicode characters # that are common case management corner cases - value = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(dtype) - - exp_capitalized = xr.DataArray(["Some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(dtype) - exp_lowered = xr.DataArray(["some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(dtype) - exp_swapped = xr.DataArray(["soME WoRD dž SS ᾛ σς FFI⁵å ç ⅰ"]).astype(dtype) - exp_titled = xr.DataArray(["Some Word Dž Ss ᾛ Σς Ffi⁵Å Ç Ⅰ"]).astype(dtype) - exp_uppered = xr.DataArray(["SOME WORD DŽ SS ἫΙ ΣΣ FFI⁵Å Ç Ⅰ"]).astype(dtype) - exp_casefolded = xr.DataArray(["some word dž ss ἣι σσ ffi⁵å ç ⅰ"]).astype(dtype) + value = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + + exp_capitalized = xr.DataArray(["Some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.unicode_) + exp_lowered = xr.DataArray(["some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.unicode_) + exp_swapped = xr.DataArray(["soME WoRD dž SS ᾛ σς FFI⁵å ç ⅰ"]).astype(np.unicode_) + exp_titled = xr.DataArray(["Some Word Dž Ss ᾛ Σς Ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + exp_uppered = xr.DataArray(["SOME WORD DŽ SS ἫΙ ΣΣ FFI⁵Å Ç Ⅰ"]).astype(np.unicode_) + exp_casefolded = xr.DataArray(["some word dž ss ἣι σσ ffi⁵å ç ⅰ"]).astype( + np.unicode_ + ) - exp_norm_nfc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(dtype) - exp_norm_nfkc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(dtype) - exp_norm_nfd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(dtype) - exp_norm_nfkd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(dtype) + exp_norm_nfc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + exp_norm_nfkc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(np.unicode_) + exp_norm_nfd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + exp_norm_nfkd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype( + np.unicode_ + ) res_capitalized = value.str.capitalize() res_casefolded = value.str.casefold() @@ -191,35 +334,52 @@ def test_case_str(): def test_replace(dtype): - values = xr.DataArray(["fooBAD__barBAD"]).astype(dtype) + values = xr.DataArray(["fooBAD__barBAD"], dims=["x"]).astype(dtype) result = values.str.replace("BAD[_]*", "") - expected = xr.DataArray(["foobar"]).astype(dtype) + expected = xr.DataArray(["foobar"], dims=["x"]).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.replace("BAD[_]*", "", n=1) - expected = xr.DataArray(["foobarBAD"]).astype(dtype) + expected = xr.DataArray(["foobarBAD"], dims=["x"]).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) - s = xr.DataArray(["A", "B", "C", "Aaba", "Baca", "", "CABA", "dog", "cat"]).astype( + pat = xr.DataArray(["BAD[_]*", "AD[_]*"], dims=["y"]).astype(dtype) + result = values.str.replace(pat, "") + expected = xr.DataArray([["foobar", "fooBbarB"]], dims=["x", "y"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + repl = xr.DataArray(["", "spam"], dims=["y"]).astype(dtype) + result = values.str.replace(pat, repl, n=1) + expected = xr.DataArray([["foobarBAD", "fooBspambarBAD"]], dims=["x", "y"]).astype( dtype ) - result = s.str.replace("A", "YYY") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + values 
= xr.DataArray( + ["A", "B", "C", "Aaba", "Baca", "", "CABA", "dog", "cat"] + ).astype(dtype) expected = xr.DataArray( ["YYY", "B", "C", "YYYaba", "Baca", "", "CYYYBYYY", "dog", "cat"] ).astype(dtype) + result = values.str.replace("A", "YYY") + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.replace("A", "YYY", regex=False) assert result.dtype == expected.dtype assert_equal(result, expected) - result = s.str.replace("A", "YYY", case=False) + result = values.str.replace("A", "YYY", case=False) expected = xr.DataArray( ["YYY", "B", "C", "YYYYYYbYYY", "BYYYcYYY", "", "CYYYBYYY", "dog", "cYYYt"] ).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) - result = s.str.replace("^.a|dog", "XX-XX ", case=False) + result = values.str.replace("^.a|dog", "XX-XX ", case=False) expected = xr.DataArray( ["A", "B", "C", "XX-XX ba", "XX-XX ca", "", "XX-XX BA", "XX-XX ", "XX-XX t"] ).astype(dtype) @@ -246,6 +406,22 @@ def test_replace_callable(): assert result.dtype == exp.dtype assert_equal(result, exp) + # test broadcast + values = xr.DataArray(["Foo Bar Baz"], dims=["x"]) + pat = r"(?P\w+) (?P\w+) (?P\w+)" + repl = xr.DataArray( + [ + lambda m: m.group("first").swapcase(), + lambda m: m.group("middle").swapcase(), + lambda m: m.group("last").swapcase(), + ], + dims=["Y"], + ) + result = values.str.replace(pat, repl) + exp = xr.DataArray([["fOO", "bAR", "bAZ"]], dims=["x", "Y"]) + assert result.dtype == exp.dtype + assert_equal(result, exp) + def test_replace_unicode(): # flags + unicode @@ -256,18 +432,50 @@ def test_replace_unicode(): assert result.dtype == expected.dtype assert_equal(result, expected) + # broadcast version + values = xr.DataArray([b"abcd,\xc3\xa0".decode("utf-8")], dims=["X"]) + expected = xr.DataArray( + [[b"abcd, \xc3\xa0".decode("utf-8"), b"BAcd,\xc3\xa0".decode("utf-8")]], + dims=["X", "Y"], + ) + pat = xr.DataArray( + [re.compile(r"(?<=\w),(?=\w)", flags=re.UNICODE), r"ab"], dims=["Y"] + ) + repl = xr.DataArray([", ", "BA"], dims=["Y"]) + result = values.str.replace(pat, repl) + assert result.dtype == expected.dtype + assert_equal(result, expected) + def test_replace_compiled_regex(dtype): - values = xr.DataArray(["fooBAD__barBAD"]).astype(dtype) + values = xr.DataArray(["fooBAD__barBAD"], dims=["x"]).astype(dtype) + # test with compiled regex pat = re.compile(dtype("BAD[_]*")) result = values.str.replace(pat, "") - expected = xr.DataArray(["foobar"]).astype(dtype) + expected = xr.DataArray(["foobar"], dims=["x"]).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.replace(pat, "", n=1) - expected = xr.DataArray(["foobarBAD"]).astype(dtype) + expected = xr.DataArray(["foobarBAD"], dims=["x"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # broadcast + pat = xr.DataArray( + [re.compile(dtype("BAD[_]*")), re.compile(dtype("AD[_]*"))], dims=["y"] + ) + result = values.str.replace(pat, "") + expected = xr.DataArray([["foobar", "fooBbarB"]], dims=["x", "y"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + repl = xr.DataArray(["", "spam"], dims=["y"]).astype(dtype) + result = values.str.replace(pat, repl, n=1) + expected = xr.DataArray([["foobarBAD", "fooBspambarBAD"]], dims=["x", "y"]).astype( + dtype + ) assert result.dtype == expected.dtype assert_equal(result, expected) @@ -276,13 +484,19 @@ def test_replace_compiled_regex(dtype): values = 
xr.DataArray(["fooBAD__barBAD__bad"]).astype(dtype) pat = re.compile(dtype("BAD[_]*")) - with pytest.raises(ValueError, match="flags cannot be set"): + with pytest.raises( + ValueError, match="Flags cannot be set when pat is a compiled regex." + ): result = values.str.replace(pat, "", flags=re.IGNORECASE) - with pytest.raises(ValueError, match="case cannot be set"): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): result = values.str.replace(pat, "", case=False) - with pytest.raises(ValueError, match="case cannot be set"): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): result = values.str.replace(pat, "", case=True) # test with callable @@ -297,17 +511,33 @@ def test_replace_compiled_regex(dtype): def test_replace_literal(dtype): # GH16808 literal replace (regex=False vs regex=True) - values = xr.DataArray(["f.o", "foo"]).astype(dtype) - expected = xr.DataArray(["bao", "bao"]).astype(dtype) + values = xr.DataArray(["f.o", "foo"], dims=["X"]).astype(dtype) + expected = xr.DataArray(["bao", "bao"], dims=["X"]).astype(dtype) result = values.str.replace("f.", "ba") assert result.dtype == expected.dtype assert_equal(result, expected) - expected = xr.DataArray(["bao", "foo"]).astype(dtype) + expected = xr.DataArray(["bao", "foo"], dims=["X"]).astype(dtype) result = values.str.replace("f.", "ba", regex=False) assert result.dtype == expected.dtype assert_equal(result, expected) + # Broadcast + pat = xr.DataArray(["f.", ".o"], dims=["yy"]).astype(dtype) + expected = xr.DataArray([["bao", "fba"], ["bao", "bao"]], dims=["X", "yy"]).astype( + dtype + ) + result = values.str.replace(pat, "ba") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = xr.DataArray([["bao", "fba"], ["foo", "foo"]], dims=["X", "yy"]).astype( + dtype + ) + result = values.str.replace(pat, "ba", regex=False) + assert result.dtype == expected.dtype + assert_equal(result, expected) + # Cannot do a literal replace if given a callable repl or compiled # pattern callable_repl = lambda m: m.group(0).swapcase() @@ -323,143 +553,133 @@ def test_replace_literal(dtype): def test_extract_extractall_findall_empty_raises(dtype): - pat_str = r"a_\w+_b_\d+_c_.*" + pat_str = dtype(r".*") pat_re = re.compile(pat_str) - value = xr.DataArray( - [ - ["a_first_b_1_c_de", "a_second_b_22_c_efh", "a_third_b_333_c_hijk"], - [ - "a_fourth_b_4444_c_klmno", - "a_fifth_b_5555_c_opqr", - "a_sixth_b_66666_c_rst", - ], - ], - dims=["X", "Y"], - ).astype(dtype) + value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No capture groups found in pattern."): value.str.extract(pat=pat_str, dim="ZZ") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No capture groups found in pattern."): value.str.extract(pat=pat_re, dim="ZZ") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No capture groups found in pattern."): value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No capture groups found in pattern."): value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No capture groups found in pattern."): value.str.findall(pat=pat_str) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No capture groups found in pattern."): 
value.str.findall(pat=pat_re) def test_extract_multi_None_raises(dtype): - pat_str = r"a_(\w+)_b_(\d+)_c_.*" + pat_str = r"(\w+)_(\d+)" pat_re = re.compile(pat_str) - value = xr.DataArray( - [ - ["a_first_b_1_c_de", "a_second_b_22_c_efh", "a_third_b_333_c_hijk"], - [ - "a_fourth_b_4444_c_klmno", - "a_fifth_b_5555_c_opqr", - "a_sixth_b_66666_c_rst", - ], - ], - dims=["X", "Y"], - ).astype(dtype) + value = xr.DataArray([["a_b"]], dims=["X", "Y"]).astype(dtype) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Dimension must be specified if more than one capture group is given.", + ): value.str.extract(pat=pat_str, dim=None) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Dimension must be specified if more than one capture group is given.", + ): value.str.extract(pat=pat_re, dim=None) def test_extract_extractall_findall_case_re_raises(dtype): - pat_str = r"a_\w+_b_\d+_c_.*" + pat_str = r".*" pat_re = re.compile(pat_str) - value = xr.DataArray( - [ - ["a_first_b_1_c_de", "a_second_b_22_c_efh", "a_third_b_333_c_hijk"], - [ - "a_fourth_b_4444_c_klmno", - "a_fifth_b_5555_c_opqr", - "a_sixth_b_66666_c_rst", - ], - ], - dims=["X", "Y"], - ).astype(dtype) + value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): value.str.extract(pat=pat_re, case=True, dim="ZZ") - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): value.str.extract(pat=pat_re, case=False, dim="ZZ") - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): value.str.extractall(pat=pat_re, case=True, group_dim="XX", match_dim="YY") - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): value.str.extractall(pat=pat_re, case=False, group_dim="XX", match_dim="YY") - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): value.str.findall(pat=pat_re, case=True) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): value.str.findall(pat=pat_re, case=False) def test_extract_extractall_name_collision_raises(dtype): - pat_str = r"a_(\w+)_b_\d+_c_.*" + pat_str = r"(\w+)" pat_re = re.compile(pat_str) - value = xr.DataArray( - [ - ["a_first_b_1_c_de", "a_second_b_22_c_efh", "a_third_b_333_c_hijk"], - [ - "a_fourth_b_4444_c_klmno", - "a_fifth_b_5555_c_opqr", - "a_sixth_b_66666_c_rst", - ], - ], - dims=["X", "Y"], - ).astype(dtype) + value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype) - with pytest.raises(KeyError): + with pytest.raises(KeyError, match="Dimension X already present in DataArray."): value.str.extract(pat=pat_str, dim="X") - with pytest.raises(KeyError): + with pytest.raises(KeyError, match="Dimension X already present in DataArray."): value.str.extract(pat=pat_re, dim="X") - with pytest.raises(KeyError): + with pytest.raises( + KeyError, match="Group dimension X already present in DataArray." + ): value.str.extractall(pat=pat_str, group_dim="X", match_dim="ZZ") - with pytest.raises(KeyError): + with pytest.raises( + KeyError, match="Group dimension X already present in DataArray." 
+ ): value.str.extractall(pat=pat_re, group_dim="X", match_dim="YY") - with pytest.raises(KeyError): + with pytest.raises( + KeyError, match="Match dimension Y already present in DataArray." + ): value.str.extractall(pat=pat_str, group_dim="XX", match_dim="Y") - with pytest.raises(KeyError): + with pytest.raises( + KeyError, match="Match dimension Y already present in DataArray." + ): value.str.extractall(pat=pat_re, group_dim="XX", match_dim="Y") - with pytest.raises(KeyError): + with pytest.raises( + KeyError, match="Group dimension ZZ is the same as match dimension ZZ." + ): value.str.extractall(pat=pat_str, group_dim="ZZ", match_dim="ZZ") - with pytest.raises(KeyError): + with pytest.raises( + KeyError, match="Group dimension ZZ is the same as match dimension ZZ." + ): value.str.extractall(pat=pat_re, group_dim="ZZ", match_dim="ZZ") def test_extract_single_case(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) value = xr.DataArray( [ @@ -503,16 +723,16 @@ def test_extract_single_case(dtype): def test_extract_single_nocase(dtype): - pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.IGNORECASE) + pat_str = r"(\w+)?_Xy_\d*" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.IGNORECASE) value = xr.DataArray( [ ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], [ "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", - "", + "_Xy_1", "abcdef_Xy_101-fef_Xy_5543210", ], ], @@ -544,8 +764,8 @@ def test_extract_single_nocase(dtype): def test_extract_multi_case(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) value = xr.DataArray( [ @@ -559,7 +779,7 @@ def test_extract_multi_case(dtype): dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [["a", "0"], ["bab", "110"], ["abc", "01"]], [["abcd", ""], ["", ""], ["abcdef", "101"]], @@ -571,19 +791,19 @@ def test_extract_multi_case(dtype): res_re = value.str.extract(pat=pat_re, dim="XX") res_str_case = value.str.extract(pat=pat_str, dim="XX", case=True) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_extract_multi_nocase(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.IGNORECASE) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.IGNORECASE) value = xr.DataArray( [ @@ -597,7 +817,7 @@ def test_extract_multi_nocase(dtype): dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [["a", "0"], ["ab", "10"], 
["abc", "01"]], [["abcd", ""], ["", ""], ["abcdef", "101"]], @@ -608,24 +828,53 @@ def test_extract_multi_nocase(dtype): res_str = value.str.extract(pat=pat_str, dim="XX", case=False) res_re = value.str.extract(pat=pat_re, dim="XX") - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + - assert_equal(res_str, targ) - assert_equal(res_re, targ) +def test_extract_broadcast(dtype): + value = xr.DataArray( + ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], + dims=["X"], + ).astype(dtype) + + pat_str = xr.DataArray( + [r"(\w+)_Xy_(\d*)", r"(\w+)_xY_(\d*)"], + dims=["Y"], + ).astype(dtype) + pat_re = value.str._re_compile(pat=pat_str) + + expected = [ + [["a", "0"], ["", ""]], + [["", ""], ["ab", "10"]], + [["abc", "01"], ["", ""]], + ] + expected = xr.DataArray(expected, dims=["X", "Y", "Zz"]).astype(dtype) + + res_str = value.str.extract(pat=pat_str, dim="Zz") + res_re = value.str.extract(pat=pat_re, dim="Zz") + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_extractall_single_single_case(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [[[["a"]], [[""]], [["abc"]]], [[["abcd"]], [[""]], [["abcdef"]]]], dims=["X", "Y", "XX", "YY"], ).astype(dtype) @@ -636,26 +885,26 @@ def test_extractall_single_single_case(dtype): pat=pat_str, group_dim="XX", match_dim="YY", case=True ) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_extractall_single_single_nocase(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.I) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [[[["a"]], [["ab"]], [["abc"]]], [[["abcd"]], [[""]], [["abcdef"]]]], dims=["X", "Y", "XX", "YY"], ).astype(dtype) @@ -665,17 +914,17 @@ def test_extractall_single_single_nocase(dtype): ) res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_extractall_single_multi_case(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, 
np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) value = xr.DataArray( [ @@ -689,7 +938,7 @@ def test_extractall_single_multi_case(dtype): dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [[["a"], [""], [""]], [["bab"], ["baab"], [""]], [["abc"], ["cbc"], [""]]], [ @@ -707,19 +956,19 @@ def test_extractall_single_multi_case(dtype): pat=pat_str, group_dim="XX", match_dim="YY", case=True ) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_extractall_single_multi_nocase(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.I) value = xr.DataArray( [ @@ -733,7 +982,7 @@ def test_extractall_single_multi_nocase(dtype): dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [ [["a"], [""], [""]], @@ -754,24 +1003,24 @@ def test_extractall_single_multi_nocase(dtype): ) res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_extractall_multi_single_case(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [[["a", "0"]], [["", ""]], [["abc", "01"]]], [[["abcd", ""]], [["", ""]], [["abcdef", "101"]]], @@ -785,26 +1034,26 @@ def test_extractall_multi_single_case(dtype): pat=pat_str, group_dim="XX", match_dim="YY", case=True ) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_extractall_multi_single_nocase(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.I) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", 
"abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [[["a", "0"]], [["ab", "10"]], [["abc", "01"]]], [[["abcd", ""]], [["", ""]], [["abcdef", "101"]]], @@ -817,17 +1066,17 @@ def test_extractall_multi_single_nocase(dtype): ) res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_extractall_multi_multi_case(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) value = xr.DataArray( [ @@ -841,7 +1090,7 @@ def test_extractall_multi_multi_case(dtype): dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [ [["a", "0"], ["", ""], ["", ""]], @@ -863,19 +1112,19 @@ def test_extractall_multi_multi_case(dtype): pat=pat_str, group_dim="XX", match_dim="YY", case=True ) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_extractall_multi_multi_nocase(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.I) value = xr.DataArray( [ @@ -889,7 +1138,7 @@ def test_extractall_multi_multi_nocase(dtype): dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [ [["a", "0"], ["", ""], ["", ""]], @@ -910,71 +1159,96 @@ def test_extractall_multi_multi_nocase(dtype): ) res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_extractall_broadcast(dtype): + value = xr.DataArray( + ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], + dims=["X"], + ).astype(dtype) - assert_equal(res_str, targ) - assert_equal(res_re, targ) + pat_str = xr.DataArray( + [r"(\w+)_Xy_(\d*)", r"(\w+)_xY_(\d*)"], + dims=["Y"], + ).astype(dtype) + pat_re = value.str._re_compile(pat=pat_str) + + expected = [ + [[["a", "0"]], [["", ""]]], + [[["", ""]], [["ab", "10"]]], + [[["abc", "01"]], [["", ""]]], + ] + expected = xr.DataArray(expected, dims=["X", "Y", "ZX", "ZY"]).astype(dtype) + + res_str = value.str.extractall(pat=pat_str, group_dim="ZX", match_dim="ZY") + res_re = value.str.extractall(pat=pat_re, group_dim="ZX", match_dim="ZY") + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_findall_single_single_case(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = 
{np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = re.compile(dtype(pat_str)) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = [[["a"], [], ["abc"]], [["abcd"], [], ["abcdef"]]] - targ = [[[conv(x) for x in y] for y in z] for z in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[["a"], [], ["abc"]], [["abcd"], [], ["abcdef"]]] + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str) res_re = value.str.findall(pat=pat_re) res_str_case = value.str.findall(pat=pat_str, case=True) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_findall_single_single_nocase(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = re.compile(dtype(pat_str), flags=re.I) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = [[["a"], ["ab"], ["abc"]], [["abcd"], [], ["abcdef"]]] - targ = [[[conv(x) for x in y] for y in z] for z in targ] - targ = np.array(targ, dtype=np.object_) - print(targ) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[["a"], ["ab"], ["abc"]], [["abcd"], [], ["abcdef"]]] + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str, case=False) res_re = value.str.findall(pat=pat_re) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_findall_single_multi_case(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = re.compile(dtype(pat_str)) value = xr.DataArray( [ @@ -988,7 +1262,7 @@ def test_findall_single_multi_case(dtype): dims=["X", "Y"], ).astype(dtype) - targ = [ + expected = [ [["a"], ["bab", "baab"], ["abc", "cbc"]], [ ["abcd", "dcd", "dccd"], @@ -996,27 +1270,26 @@ def test_findall_single_multi_case(dtype): ["abcdef", "fef"], ], ] - targ = [[[conv(x) for x in y] for y in z] for z in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str) res_re = value.str.findall(pat=pat_re) res_str_case = value.str.findall(pat=pat_str, case=True) - assert res_str.dtype == targ.dtype - assert 
res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_findall_single_multi_nocase(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = re.compile(dtype(pat_str), flags=re.I) value = xr.DataArray( [ @@ -1030,7 +1303,7 @@ def test_findall_single_multi_nocase(dtype): dims=["X", "Y"], ).astype(dtype) - targ = [ + expected = [ [ ["a"], ["ab", "bab", "baab"], @@ -1042,83 +1315,80 @@ def test_findall_single_multi_nocase(dtype): ["abcdef", "fef"], ], ] - targ = [[[conv(x) for x in y] for y in z] for z in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str, case=False) res_re = value.str.findall(pat=pat_re) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_findall_multi_single_case(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = re.compile(dtype(pat_str)) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = [ + expected = [ [[["a", "0"]], [], [["abc", "01"]]], [[["abcd", ""]], [], [["abcdef", "101"]]], ] - targ = [[[tuple(conv(x) for x in y) for y in z] for z in w] for w in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str) res_re = value.str.findall(pat=pat_re) res_str_case = value.str.findall(pat=pat_str, case=True) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_findall_multi_single_nocase(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = re.compile(dtype(pat_str), flags=re.I) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = [ + expected = [ [[["a", "0"]], [["ab", "10"]], [["abc", "01"]]], [[["abcd", ""]], [], [["abcdef", "101"]]], ] - targ = [[[tuple(conv(x) for x in 
y) for y in z] for z in w] for w in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str, case=False) res_re = value.str.findall(pat=pat_re) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_findall_multi_multi_case(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = re.compile(dtype(pat_str)) value = xr.DataArray( [ @@ -1132,7 +1402,7 @@ def test_findall_multi_multi_case(dtype): dims=["X", "Y"], ).astype(dtype) - targ = [ + expected = [ [ [["a", "0"]], [["bab", "110"], ["baab", "1100"]], @@ -1144,27 +1414,26 @@ def test_findall_multi_multi_case(dtype): [["abcdef", "101"], ["fef", "5543210"]], ], ] - targ = [[[tuple(conv(x) for x in y) for y in z] for z in w] for w in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str) res_re = value.str.findall(pat=pat_re) res_str_case = value.str.findall(pat=pat_str, case=True) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_findall_multi_multi_nocase(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = re.compile(dtype(pat_str), flags=re.I) value = xr.DataArray( [ @@ -1178,7 +1447,7 @@ def test_findall_multi_multi_nocase(dtype): dims=["X", "Y"], ).astype(dtype) - targ = [ + expected = [ [ [["a", "0"]], [["ab", "10"], ["bab", "110"], ["baab", "1100"]], @@ -1190,18 +1459,45 @@ def test_findall_multi_multi_nocase(dtype): [["abcdef", "101"], ["fef", "5543210"]], ], ] - targ = [[[tuple(conv(x) for x in y) for y in z] for z in w] for w in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str, case=False) res_re = value.str.findall(pat=pat_re) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) - assert_equal(res_str, targ) - assert_equal(res_re, targ) + +def test_findall_broadcast(dtype): + value = xr.DataArray( + ["a_Xy_0", 
"ab_xY_10", "abc_Xy_01"], + dims=["X"], + ).astype(dtype) + + pat_str = xr.DataArray( + [r"(\w+)_Xy_\d*", r"\w+_Xy_(\d*)"], + dims=["Y"], + ).astype(dtype) + pat_re = value.str._re_compile(pat=pat_str) + + expected = [[["a"], ["0"]], [[], []], [["abc"], ["01"]]] + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_repeat(dtype): @@ -1219,19 +1515,57 @@ def test_repeat(dtype): assert_equal(result, expected) +def test_repeat_broadcast(dtype): + values = xr.DataArray(["a", "b", "c", "d"], dims=["X"]).astype(dtype) + reps = xr.DataArray([3, 4], dims=["Y"]) + + result = values.str.repeat(reps) + result_mul = values.str * reps + + expected = xr.DataArray( + [["aaa", "aaaa"], ["bbb", "bbbb"], ["ccc", "cccc"], ["ddd", "dddd"]], + dims=["X", "Y"], + ).astype(dtype) + + assert result.dtype == expected.dtype + assert result_mul.dtype == expected.dtype + + assert_equal(result_mul, expected) + assert_equal(result, expected) + + def test_match(dtype): - # New match behavior introduced in 0.13 values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) - result = values.str.match(".*(BAD[_]+).*(BAD)") + + # New match behavior introduced in 0.13 + pat = values.dtype.type(".*(BAD[_]+).*(BAD)") + result = values.str.match(pat) expected = xr.DataArray([True, False]) assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.match(re.compile(pat)) + assert result.dtype == expected.dtype + assert_equal(result, expected) - values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) - result = values.str.match(".*BAD[_]+.*BAD") + # Case-sensitive + pat = values.dtype.type(".*BAD[_]+.*BAD") + result = values.str.match(pat) expected = xr.DataArray([True, False]) assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.match(re.compile(pat)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # Case-insensitive + pat = values.dtype.type(".*bAd[_]+.*bad") + result = values.str.match(pat, case=False) + expected = xr.DataArray([True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.match(re.compile(pat, flags=re.IGNORECASE)) + assert result.dtype == expected.dtype + assert_equal(result, expected) def test_empty_str_methods(): @@ -1400,132 +1734,221 @@ def test_len(dtype): def test_find(dtype): values = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"]) values = values.astype(dtype) - result = values.str.find("EF") - expected = xr.DataArray([4, 3, 1, 0, -1]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - expected = xr.DataArray([v.find(dtype("EF")) for v in values.values]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - result = values.str.rfind("EF") - expected = xr.DataArray([4, 5, 7, 4, -1]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - expected = xr.DataArray([v.rfind(dtype("EF")) for v in values.values]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + result_0 = values.str.find("EF") + result_1 = values.str.find("EF", side="left") + expected_0 = xr.DataArray([4, 3, 1, 0, -1]) + 
expected_1 = xr.DataArray([v.find(dtype("EF")) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.rfind("EF") + result_1 = values.str.find("EF", side="right") + expected_0 = xr.DataArray([4, 5, 7, 4, -1]) + expected_1 = xr.DataArray([v.rfind(dtype("EF")) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.find("EF", 3) + result_1 = values.str.find("EF", 3, side="left") + expected_0 = xr.DataArray([4, 3, 7, 4, -1]) + expected_1 = xr.DataArray([v.find(dtype("EF"), 3) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.rfind("EF", 3) + result_1 = values.str.find("EF", 3, side="right") + expected_0 = xr.DataArray([4, 5, 7, 4, -1]) + expected_1 = xr.DataArray([v.rfind(dtype("EF"), 3) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.find("EF", 3, 6) + result_1 = values.str.find("EF", 3, 6, side="left") + expected_0 = xr.DataArray([4, 3, -1, 4, -1]) + expected_1 = xr.DataArray([v.find(dtype("EF"), 3, 6) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.rfind("EF", 3, 6) + result_1 = values.str.find("EF", 3, 6, side="right") + expected_0 = xr.DataArray([4, 3, -1, 4, -1]) + expected_1 = xr.DataArray([v.rfind(dtype("EF"), 3, 6) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + +def test_find_broadcast(dtype): + values = xr.DataArray( + ["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"], dims=["X"] + ) + values = values.astype(dtype) + sub = xr.DataArray(["EF", "BC", "XX"], dims=["Y"]).astype(dtype) + start = xr.DataArray([0, 7], dims=["Z"]) + end = xr.DataArray([6, 9], dims=["Z"]) - result = values.str.find("EF", 3) - expected = xr.DataArray([4, 3, 7, 4, -1]) - assert result.dtype == expected.dtype - 
assert_equal(result, expected) - expected = xr.DataArray([v.find(dtype("EF"), 3) for v in values.values]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + result_0 = values.str.find(sub, start, end) + result_1 = values.str.find(sub, start, end, side="left") + expected = xr.DataArray( + [ + [[4, -1], [1, -1], [-1, -1]], + [[3, -1], [0, -1], [-1, -1]], + [[1, 7], [-1, -1], [-1, -1]], + [[0, -1], [-1, -1], [-1, -1]], + [[-1, -1], [-1, -1], [0, -1]], + ], + dims=["X", "Y", "Z"], + ) - result = values.str.rfind("EF", 3) - expected = xr.DataArray([4, 5, 7, 4, -1]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - expected = xr.DataArray([v.rfind(dtype("EF"), 3) for v in values.values]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = values.str.find("EF", 3, 6) - expected = xr.DataArray([4, 3, -1, 4, -1]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - expected = xr.DataArray([v.find(dtype("EF"), 3, 6) for v in values.values]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + result_0 = values.str.rfind(sub, start, end) + result_1 = values.str.find(sub, start, end, side="right") + expected = xr.DataArray( + [ + [[4, -1], [1, -1], [-1, -1]], + [[3, -1], [0, -1], [-1, -1]], + [[1, 7], [-1, -1], [-1, -1]], + [[4, -1], [-1, -1], [-1, -1]], + [[-1, -1], [-1, -1], [1, -1]], + ], + dims=["X", "Y", "Z"], + ) - result = values.str.rfind("EF", 3, 6) - expected = xr.DataArray([4, 3, -1, 4, -1]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - xp = xr.DataArray([v.rfind(dtype("EF"), 3, 6) for v in values.values]) - assert result.dtype == xp.dtype - assert_equal(result, xp) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) def test_index(dtype): s = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"]).astype(dtype) - result = s.str.index("EF") + result_0 = s.str.index("EF") + result_1 = s.str.index("EF", side="left") expected = xr.DataArray([4, 3, 1, 0]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = s.str.rindex("EF") + result_0 = s.str.rindex("EF") + result_1 = s.str.index("EF", side="right") expected = xr.DataArray([4, 5, 7, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = s.str.index("EF", 3) + result_0 = s.str.index("EF", 3) + result_1 = s.str.index("EF", 3, side="left") expected = xr.DataArray([4, 3, 7, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = s.str.rindex("EF", 3) + result_0 = s.str.rindex("EF", 3) + result_1 = s.str.index("EF", 3, side="right") expected = xr.DataArray([4, 5, 7, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert 
result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = s.str.index("E", 4, 8) + result_0 = s.str.index("E", 4, 8) + result_1 = s.str.index("E", 4, 8, side="left") expected = xr.DataArray([4, 5, 7, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = s.str.rindex("E", 0, 5) + result_0 = s.str.rindex("E", 0, 5) + result_1 = s.str.index("E", 0, 5, side="right") expected = xr.DataArray([4, 3, 1, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - - with pytest.raises(ValueError): - result = s.str.index("DE") - + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) -def test_pad(dtype): - values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) + matchtype = "subsection" if dtype == np.bytes_ else "substring" + with pytest.raises(ValueError, match=f"{matchtype} not found"): + s.str.index("DE") - result = values.str.pad(5, side="left") - expected = xr.DataArray([" a", " b", " c", "eeeee"]).astype(dtype) - assert result.dtype == expected.dtype - assert_equal(result, expected) - result = values.str.pad(5, side="right") - expected = xr.DataArray(["a ", "b ", "c ", "eeeee"]).astype(dtype) - assert result.dtype == expected.dtype - assert_equal(result, expected) - - result = values.str.pad(5, side="both") - expected = xr.DataArray([" a ", " b ", " c ", "eeeee"]).astype(dtype) - assert result.dtype == expected.dtype - assert_equal(result, expected) - - -def test_pad_fillchar(dtype): - values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) +def test_index_broadcast(dtype): + values = xr.DataArray( + ["ABCDEFGEFDBCA", "BCDEFEFEFDBC", "DEFBCGHIEFBC", "EFGHBCEFBCBCBCEF"], + dims=["X"], + ) + values = values.astype(dtype) + sub = xr.DataArray(["EF", "BC"], dims=["Y"]).astype(dtype) + start = xr.DataArray([0, 6], dims=["Z"]) + end = xr.DataArray([6, 12], dims=["Z"]) - result = values.str.pad(5, side="left", fillchar="X") - expected = xr.DataArray(["XXXXa", "XXXXb", "XXXXc", "eeeee"]).astype(dtype) - assert result.dtype == expected.dtype - assert_equal(result, expected) + result_0 = values.str.index(sub, start, end) + result_1 = values.str.index(sub, start, end, side="left") + expected = xr.DataArray( + [[[4, 7], [1, 10]], [[3, 7], [0, 10]], [[1, 8], [3, 10]], [[0, 6], [4, 8]]], + dims=["X", "Y", "Z"], + ) - result = values.str.pad(5, side="right", fillchar="X") - expected = xr.DataArray(["aXXXX", "bXXXX", "cXXXX", "eeeee"]).astype(dtype) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = values.str.pad(5, side="both", fillchar="X") - expected = xr.DataArray(["XXaXX", "XXbXX", "XXcXX", "eeeee"]).astype(dtype) - assert result.dtype == expected.dtype - assert_equal(result, expected) + result_0 = values.str.rindex(sub, start, end) + result_1 = values.str.index(sub, start, end, side="right") + expected = xr.DataArray( + [[[4, 7], [1, 10]], [[3, 7], [0, 10]], [[1, 8], [3, 10]], [[0, 6], [4, 10]]], + dims=["X", "Y", "Z"], + ) - msg = "fillchar must be a character, not str" - with pytest.raises(TypeError, match=msg): - result = values.str.pad(5, 
fillchar="XY") + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) def test_translate(): @@ -1537,41 +1960,66 @@ def test_translate(): assert_equal(result, expected) -def test_center_ljust_rjust(dtype): +def test_pad_center_ljust_rjust(dtype): values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) result = values.str.center(5) expected = xr.DataArray([" a ", " b ", " c ", "eeeee"]).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.pad(5, side="both") + assert result.dtype == expected.dtype + assert_equal(result, expected) result = values.str.ljust(5) expected = xr.DataArray(["a ", "b ", "c ", "eeeee"]).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.pad(5, side="right") + assert result.dtype == expected.dtype + assert_equal(result, expected) result = values.str.rjust(5) expected = xr.DataArray([" a", " b", " c", "eeeee"]).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.pad(5, side="left") + assert result.dtype == expected.dtype + assert_equal(result, expected) -def test_center_ljust_rjust_fillchar(dtype): +def test_pad_center_ljust_rjust_fillchar(dtype): values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"]).astype(dtype) + result = values.str.center(5, fillchar="X") - expected = xr.DataArray(["XXaXX", "XXbbX", "Xcccc", "ddddd", "eeeeee"]) - assert result.dtype == expected.astype(dtype).dtype - assert_equal(result, expected.astype(dtype)) + expected = xr.DataArray(["XXaXX", "XXbbX", "Xcccc", "ddddd", "eeeeee"]).astype( + dtype + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(5, side="both", fillchar="X") + assert result.dtype == expected.dtype + assert_equal(result, expected) result = values.str.ljust(5, fillchar="X") - expected = xr.DataArray(["aXXXX", "bbXXX", "ccccX", "ddddd", "eeeeee"]) - assert result.dtype == expected.astype(dtype).dtype + expected = xr.DataArray(["aXXXX", "bbXXX", "ccccX", "ddddd", "eeeeee"]).astype( + dtype + ) + assert result.dtype == expected.dtype assert_equal(result, expected.astype(dtype)) + result = values.str.pad(5, side="right", fillchar="X") + assert result.dtype == expected.dtype + assert_equal(result, expected) result = values.str.rjust(5, fillchar="X") - expected = xr.DataArray(["XXXXa", "XXXbb", "Xcccc", "ddddd", "eeeeee"]) - assert result.dtype == expected.astype(dtype).dtype + expected = xr.DataArray(["XXXXa", "XXXbb", "Xcccc", "ddddd", "eeeeee"]).astype( + dtype + ) + assert result.dtype == expected.dtype assert_equal(result, expected.astype(dtype)) + result = values.str.pad(5, side="left", fillchar="X") + assert result.dtype == expected.dtype + assert_equal(result, expected) # If fillchar is not a charatter, normal str raises TypeError # 'aaa'.ljust(5, 'XY') @@ -1587,19 +2035,91 @@ def test_center_ljust_rjust_fillchar(dtype): with pytest.raises(TypeError, match=template.format(dtype="str")): values.str.rjust(5, fillchar="XY") + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.pad(5, fillchar="XY") + + +def test_pad_center_ljust_rjust_broadcast(dtype): + values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"], dims="X").astype( + dtype + ) + width = xr.DataArray([5, 4], dims="Y") + fillchar = xr.DataArray(["X", "#"], dims="Y").astype(dtype) + + result = values.str.center(width, 
fillchar=fillchar) + expected = xr.DataArray( + [ + ["XXaXX", "#a##"], + ["XXbbX", "#bb#"], + ["Xcccc", "cccc"], + ["ddddd", "ddddd"], + ["eeeeee", "eeeeee"], + ], + dims=["X", "Y"], + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(width, side="both", fillchar=fillchar) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.ljust(width, fillchar=fillchar) + expected = xr.DataArray( + [ + ["aXXXX", "a###"], + ["bbXXX", "bb##"], + ["ccccX", "cccc"], + ["ddddd", "ddddd"], + ["eeeeee", "eeeeee"], + ], + dims=["X", "Y"], + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected.astype(dtype)) + result = values.str.pad(width, side="right", fillchar=fillchar) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rjust(width, fillchar=fillchar) + expected = xr.DataArray( + [ + ["XXXXa", "###a"], + ["XXXbb", "##bb"], + ["Xcccc", "cccc"], + ["ddddd", "ddddd"], + ["eeeeee", "eeeeee"], + ], + dims=["X", "Y"], + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected.astype(dtype)) + result = values.str.pad(width, side="left", fillchar=fillchar) + assert result.dtype == expected.dtype + assert_equal(result, expected) + def test_zfill(dtype): values = xr.DataArray(["1", "22", "aaa", "333", "45678"]).astype(dtype) result = values.str.zfill(5) - expected = xr.DataArray(["00001", "00022", "00aaa", "00333", "45678"]) - assert result.dtype == expected.astype(dtype).dtype - assert_equal(result, expected.astype(dtype)) + expected = xr.DataArray(["00001", "00022", "00aaa", "00333", "45678"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) result = values.str.zfill(3) - expected = xr.DataArray(["001", "022", "aaa", "333", "45678"]) - assert result.dtype == expected.astype(dtype).dtype - assert_equal(result, expected.astype(dtype)) + expected = xr.DataArray(["001", "022", "aaa", "333", "45678"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_zfill_broadcast(dtype): + values = xr.DataArray(["1", "22", "aaa", "333", "45678"]).astype(dtype) + width = np.array([4, 5, 0, 3, 8]) + + result = values.str.zfill(width) + expected = xr.DataArray(["0001", "00022", "aaa", "333", "00045678"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) def test_slice(dtype): @@ -1620,6 +2140,17 @@ def test_slice(dtype): raise +def test_slice_broadcast(dtype): + arr = xr.DataArray(["aafootwo", "aabartwo", "aabazqux"]).astype(dtype) + start = xr.DataArray([1, 2, 3]) + stop = 5 + + result = arr.str.slice(start=start, stop=stop) + exp = xr.DataArray(["afoo", "bar", "az"]).astype(dtype) + assert result.dtype == exp.dtype + assert_equal(result, exp) + + def test_slice_replace(dtype): da = lambda x: xr.DataArray(x).astype(dtype) values = da(["short", "a bit longer", "evenlongerthanthat", ""]) @@ -1665,6 +2196,22 @@ def test_slice_replace(dtype): assert_equal(result, expected) +def test_slice_replace_broadcast(dtype): + values = xr.DataArray(["short", "a bit longer", "evenlongerthanthat", ""]).astype( + dtype + ) + start = 2 + stop = np.array([4, 5, None, 7]) + repl = "test" + + expected = xr.DataArray(["shtestt", "a test longer", "evtest", "test"]).astype( + dtype + ) + result = values.str.slice_replace(start, stop, repl) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + def 
test_strip_lstrip_rstrip(dtype): values = xr.DataArray([" aa ", " bb \n", "cc "]).astype(dtype) @@ -1687,20 +2234,40 @@ def test_strip_lstrip_rstrip(dtype): def test_strip_lstrip_rstrip_args(dtype): values = xr.DataArray(["xxABCxx", "xx BNSD", "LDFJH xx"]).astype(dtype) - rs = values.str.strip("x") - xp = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) - assert rs.dtype == xp.dtype - assert_equal(rs, xp) + result = values.str.strip("x") + expected = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.lstrip("x") + expected = xr.DataArray(["ABCxx", " BNSD", "LDFJH xx"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rstrip("x") + expected = xr.DataArray(["xxABC", "xx BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_strip_lstrip_rstrip_broadcast(dtype): + values = xr.DataArray(["xxABCxx", "yy BNSD", "LDFJH zz"]).astype(dtype) + to_strip = xr.DataArray(["x", "y", "z"]).astype(dtype) + + result = values.str.strip(to_strip) + expected = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) - rs = values.str.lstrip("x") - xp = xr.DataArray(["ABCxx", " BNSD", "LDFJH xx"]).astype(dtype) - assert rs.dtype == xp.dtype - assert_equal(rs, xp) + result = values.str.lstrip(to_strip) + expected = xr.DataArray(["ABCxx", " BNSD", "LDFJH zz"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) - rs = values.str.rstrip("x") - xp = xr.DataArray(["xxABC", "xx BNSD", "LDFJH "]).astype(dtype) - assert rs.dtype == xp.dtype - assert_equal(rs, xp) + result = values.str.rstrip(to_strip) + expected = xr.DataArray(["xxABC", "yy BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) def test_wrap(): @@ -1800,6 +2367,18 @@ def test_get_default(dtype): assert_equal(result, expected) +def test_get_broadcast(dtype): + values = xr.DataArray(["a_b_c", "c_d_e", "f_g_h"], dims=["X"]).astype(dtype) + inds = xr.DataArray([0, 2], dims=["Y"]) + + result = values.str.get(inds) + expected = xr.DataArray( + [["a", "b"], ["c", "d"], ["f", "g"]], dims=["X", "Y"] + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + def test_encode_decode(): data = xr.DataArray(["a", "b", "a\xe4"]) encoded = data.str.encode("utf-8") @@ -1938,6 +2517,16 @@ def test_partition_comma(dtype): assert_equal(res_rpart_dim, exp_rpart_dim) +def test_partition_empty(dtype): + values = xr.DataArray([], dims=["X"]).astype(dtype) + expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) + + res = values.str.partition(sep=", ", dim="ZZ") + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + def test_split_whitespace(dtype): values = xr.DataArray( [ @@ -2003,16 +2592,14 @@ def test_split_whitespace(dtype): [["test0\ntest1\ntest2", "test3"], [], ["abra ka\nda", "bra"]], ] - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - exp_split_none_full = [ - [[conv(x) for x in y] for y in z] for z in exp_split_none_full + [[dtype(x) for x in y] for y in z] for z in exp_split_none_full ] exp_rsplit_none_full = [ - [[conv(x) for x in y] for y in z] for z in exp_rsplit_none_full + [[dtype(x) for x in y] for y in z] for z in exp_rsplit_none_full ] - exp_split_none_1 = [[[conv(x) 
for x in y] for y in z] for z in exp_split_none_1] - exp_rsplit_none_1 = [[[conv(x) for x in y] for y in z] for z in exp_rsplit_none_1] + exp_split_none_1 = [[[dtype(x) for x in y] for y in z] for z in exp_split_none_1] + exp_rsplit_none_1 = [[[dtype(x) for x in y] for y in z] for z in exp_rsplit_none_1] exp_split_none_full = np.array(exp_split_none_full, dtype=np.object_) exp_rsplit_none_full = np.array(exp_rsplit_none_full, dtype=np.object_) @@ -2143,16 +2730,14 @@ def test_split_comma(dtype): [["test0,test1,test2", "test3"], [""], ["abra,ka,da", "bra"]], ] - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - exp_split_none_full = [ - [[conv(x) for x in y] for y in z] for z in exp_split_none_full + [[dtype(x) for x in y] for y in z] for z in exp_split_none_full ] exp_rsplit_none_full = [ - [[conv(x) for x in y] for y in z] for z in exp_rsplit_none_full + [[dtype(x) for x in y] for y in z] for z in exp_rsplit_none_full ] - exp_split_none_1 = [[[conv(x) for x in y] for y in z] for z in exp_split_none_1] - exp_rsplit_none_1 = [[[conv(x) for x in y] for y in z] for z in exp_rsplit_none_1] + exp_split_none_1 = [[[dtype(x) for x in y] for y in z] for z in exp_split_none_1] + exp_rsplit_none_1 = [[[dtype(x) for x in y] for y in z] for z in exp_rsplit_none_1] exp_split_none_full = np.array(exp_split_none_full, dtype=np.object_) exp_rsplit_none_full = np.array(exp_rsplit_none_full, dtype=np.object_) @@ -2218,6 +2803,80 @@ def test_split_comma(dtype): assert_equal(res_rsplit_none_10, exp_rsplit_none_full) +def test_splitters_broadcast(dtype): + values = xr.DataArray( + ["ab cd,de fg", "spam, ,eggs swallow", "red_blue"], + dims=["X"], + ).astype(dtype) + + sep = xr.DataArray( + [" ", ","], + dims=["Y"], + ).astype(dtype) + + expected_left = xr.DataArray( + [ + [["ab", "cd,de fg"], ["ab cd", "de fg"]], + [["spam,", ",eggs swallow"], ["spam", " ,eggs swallow"]], + [["red_blue", ""], ["red_blue", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + expected_right = xr.DataArray( + [ + [["ab cd,de", "fg"], ["ab cd", "de fg"]], + [["spam, ,eggs", "swallow"], ["spam, ", "eggs swallow"]], + [["", "red_blue"], ["", "red_blue"]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + + res_left = values.str.split(dim="ZZ", sep=sep, maxsplit=1) + res_right = values.str.rsplit(dim="ZZ", sep=sep, maxsplit=1) + + # assert res_left.dtype == expected_left.dtype + # assert res_right.dtype == expected_right.dtype + + assert_equal(res_left, expected_left) + assert_equal(res_right, expected_right) + + expected_left = xr.DataArray( + [ + [["ab", " ", "cd,de fg"], ["ab cd", ",", "de fg"]], + [["spam,", " ", ",eggs swallow"], ["spam", ",", " ,eggs swallow"]], + [["red_blue", "", ""], ["red_blue", "", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + expected_right = xr.DataArray( + [ + [["ab", " ", "cd,de fg"], ["ab cd", ",", "de fg"]], + [["spam,", " ", ",eggs swallow"], ["spam", ",", " ,eggs swallow"]], + [["red_blue", "", ""], ["red_blue", "", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + + res_left = values.str.partition(dim="ZZ", sep=sep) + res_right = values.str.partition(dim="ZZ", sep=sep) + + # assert res_left.dtype == expected_left.dtype + # assert res_right.dtype == expected_right.dtype + + assert_equal(res_left, expected_left) + assert_equal(res_right, expected_right) + + +def test_split_empty(dtype): + values = xr.DataArray([], dims=["X"]).astype(dtype) + expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) + + res = values.str.split(sep=", 
", dim="ZZ") + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + def test_get_dummies(dtype): values_line = xr.DataArray( [["a|ab~abc|abc", "ab", "a||abc|abcd"], ["abcd|ab|a", "abc|ab~abc", "|a"]], @@ -2230,7 +2889,7 @@ def test_get_dummies(dtype): vals_line = np.array(["a", "ab", "abc", "abcd", "ab~abc"]).astype(dtype) vals_comma = np.array(["a", "ab", "abc", "abcd", "ab|abc"]).astype(dtype) - targ = [ + expected = [ [ [True, False, True, False, True], [False, True, False, False, False], @@ -2242,10 +2901,10 @@ def test_get_dummies(dtype): [True, False, False, False, False], ], ] - targ = np.array(targ) - targ = xr.DataArray(targ, dims=["X", "Y", "ZZ"]) - targ_line = targ.copy() - targ_comma = targ.copy() + expected = np.array(expected) + expected = xr.DataArray(expected, dims=["X", "Y", "ZZ"]) + targ_line = expected.copy() + targ_comma = expected.copy() targ_line.coords["ZZ"] = vals_line targ_comma.coords["ZZ"] = vals_comma @@ -2262,14 +2921,50 @@ def test_get_dummies(dtype): assert_equal(res_comma, targ_comma) +def test_get_dummies_broadcast(dtype): + values = xr.DataArray( + ["x~x|x~x", "x", "x|x~x", "x~x"], + dims=["X"], + ).astype(dtype) + + sep = xr.DataArray( + ["|", "~"], + dims=["Y"], + ).astype(dtype) + + expected = [ + [[False, False, True], [True, True, False]], + [[True, False, False], [True, False, False]], + [[True, False, True], [True, True, False]], + [[False, False, True], [True, False, False]], + ] + expected = np.array(expected) + expected = xr.DataArray(expected, dims=["X", "Y", "ZZ"]) + expected.coords["ZZ"] = np.array(["x", "x|x", "x~x"]).astype(dtype) + + res = values.str.get_dummies(dim="ZZ", sep=sep) + + assert res.dtype == expected.dtype + + assert_equal(res, expected) + + +def test_get_dummies_empty(dtype): + values = xr.DataArray([], dims=["X"]).astype(dtype) + expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) + + res = values.str.get_dummies(dim="ZZ") + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + def test_splitters_empty_str(dtype): values = xr.DataArray( [["", "", ""], ["", "", ""]], dims=["X", "Y"], ).astype(dtype) - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - targ_partition_dim = xr.DataArray( [ [["", "", ""], ["", "", ""], ["", "", ""]], @@ -2283,7 +2978,7 @@ def test_splitters_empty_str(dtype): [["", "", ""], ["", "", ""], ["", "", "", ""]], ] targ_partition_none = [ - [[conv(x) for x in y] for y in z] for z in targ_partition_none + [[dtype(x) for x in y] for y in z] for z in targ_partition_none ] targ_partition_none = np.array(targ_partition_none, dtype=np.object_) del targ_partition_none[-1, -1][-1] @@ -2339,58 +3034,6 @@ def test_splitters_empty_str(dtype): assert_equal(res_dummies, targ_split_dim) -def test_splitters_empty_array(dtype): - values = xr.DataArray( - [[], []], - dims=["X", "Y"], - ).astype(dtype) - - targ_dim = xr.DataArray( - np.empty([2, 0, 0]), - dims=["X", "Y", "ZZ"], - ).astype(dtype) - targ_none = xr.DataArray( - np.empty([2, 0]), - dims=["X", "Y"], - ).astype(np.object_) - - res_part_dim = values.str.partition(dim="ZZ") - res_rpart_dim = values.str.rpartition(dim="ZZ") - res_part_none = values.str.partition(dim=None) - res_rpart_none = values.str.rpartition(dim=None) - - res_split_dim = values.str.split(dim="ZZ") - res_rsplit_dim = values.str.rsplit(dim="ZZ") - res_split_none = values.str.split(dim=None) - res_rsplit_none = values.str.rsplit(dim=None) - - res_dummies = values.str.get_dummies(dim="ZZ") - - assert 
res_part_dim.dtype == targ_dim.dtype - assert res_rpart_dim.dtype == targ_dim.dtype - assert res_part_none.dtype == targ_none.dtype - assert res_rpart_none.dtype == targ_none.dtype - - assert res_split_dim.dtype == targ_dim.dtype - assert res_rsplit_dim.dtype == targ_dim.dtype - assert res_split_none.dtype == targ_none.dtype - assert res_rsplit_none.dtype == targ_none.dtype - - assert res_dummies.dtype == targ_dim.dtype - - assert_equal(res_part_dim, targ_dim) - assert_equal(res_rpart_dim, targ_dim) - assert_equal(res_part_none, targ_none) - assert_equal(res_rpart_none, targ_none) - - assert_equal(res_split_dim, targ_dim) - assert_equal(res_rsplit_dim, targ_dim) - assert_equal(res_split_none, targ_none) - assert_equal(res_rsplit_none, targ_none) - - assert_equal(res_dummies, targ_dim) - - def test_cat_str(dtype): values_1 = xr.DataArray( [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], @@ -2666,7 +3309,6 @@ def test_cat_broadcast_both(dtype): def test_cat_multi(): - dtype = np.unicode_ values_1 = xr.DataArray( ["11111", "4"], dims=["X"], @@ -2686,9 +3328,9 @@ def test_cat_multi(): sep = xr.DataArray( [" ", ", "], dims=["ZZ"], - ).astype(dtype) + ).astype(np.unicode_) - targ = xr.DataArray( + expected = xr.DataArray( [ [ ["11111 a 3.4 ", "11111, a, 3.4, , "], @@ -2702,12 +3344,27 @@ def test_cat_multi(): ], ], dims=["X", "Y", "ZZ"], - ).astype(dtype) + ).astype(np.unicode_) res = values_1.str.cat(values_2, values_3, values_4, values_5, sep=sep) - assert res.dtype == targ.dtype - assert_equal(res, targ) + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_join_scalar(dtype): + values = xr.DataArray("aaa").astype(dtype) + + targ = xr.DataArray("aaa").astype(dtype) + + res_blank = values.str.join() + res_space = values.str.join(sep=" ") + + assert res_blank.dtype == targ.dtype + assert res_space.dtype == targ.dtype + + assert_identical(res_blank, targ) + assert_identical(res_space, targ) def test_join_vector(dtype): @@ -2723,7 +3380,7 @@ def test_join_vector(dtype): res_blank_y = values.str.join(dim="Y") res_space_none = values.str.join(sep=" ") - res_space_y = values.str.join(sep=" ", dim="Y") + res_space_y = values.str.join(dim="Y", sep=" ") assert res_blank_none.dtype == targ_blank.dtype assert res_blank_y.dtype == targ_blank.dtype @@ -2764,7 +3421,7 @@ def test_join_2d(dtype): res_blank_y = values.str.join(dim="Y") res_space_x = values.str.join(dim="X", sep=" ") - res_space_y = values.str.join(sep=" ", dim="Y") + res_space_y = values.str.join(dim="Y", sep=" ") assert res_blank_x.dtype == targ_blank_x.dtype assert res_blank_y.dtype == targ_blank_y.dtype @@ -2776,7 +3433,9 @@ def test_join_2d(dtype): assert_identical(res_space_x, targ_space_x) assert_identical(res_space_y, targ_space_y) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Dimension must be specified for multidimensional arrays." + ): values.str.join() @@ -2791,23 +3450,22 @@ def test_join_broadcast(dtype): dims=["ZZ"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( ["a bb cccc", "a, bb, cccc"], dims=["ZZ"], ).astype(dtype) res = values.str.join(sep=sep) - assert res.dtype == targ.dtype - assert_identical(res, targ) + assert res.dtype == expected.dtype + assert_identical(res, expected) def test_format_scalar(): - dtype = np.unicode_ values = xr.DataArray( ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], dims=["X"], - ).astype(dtype) + ).astype(np.unicode_) pos0 = 1 pos1 = 1.2 @@ -2817,23 +3475,22 @@ def test_format_scalar(): ZZ = None W = "NO!" 
- targ = xr.DataArray( + expected = xr.DataArray( ["1.X.None", "1,1.2,'test','test'", "'test'-X-None"], dims=["X"], - ).astype(dtype) + ).astype(np.unicode_) res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) - assert res.dtype == targ.dtype - assert_equal(res, targ) + assert res.dtype == expected.dtype + assert_equal(res, expected) -def test_format_broadcast(dtype): - dtype = np.unicode_ +def test_format_broadcast(): values = xr.DataArray( ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], dims=["X"], - ).astype(dtype) + ).astype(np.unicode_) pos0 = 1 pos1 = 1.2 @@ -2848,16 +3505,109 @@ def test_format_broadcast(dtype): ZZ = None W = "NO!" - targ = xr.DataArray( + expected = xr.DataArray( [ ["1.X.None", "1.X.None"], ["1,1.2,'test','test'", "1,1.2,'test','test'"], ["'test'-X-None", "'test'-X-None"], ], dims=["X", "YY"], - ).astype(dtype) + ).astype(np.unicode_) res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) - assert res.dtype == targ.dtype - assert_equal(res, targ) + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_scalar(): + values = xr.DataArray( + ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], + dims=["X"], + ).astype(np.unicode_) + + pos0 = 1 + pos1 = 1.2 + pos2 = "2.3" + + expected = xr.DataArray( + ["1.1.2.2.3", "1,1.2,2.3", "1-1.2-2.3"], + dims=["X"], + ).astype(np.unicode_) + + res = values.str % (pos0, pos1, pos2) + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_dict(): + values = xr.DataArray( + ["%(a)s.%(a)s.%(b)s", "%(b)s,%(c)s,%(b)s", "%(c)s-%(b)s-%(a)s"], + dims=["X"], + ).astype(np.unicode_) + + a = 1 + b = 1.2 + c = "2.3" + + expected = xr.DataArray( + ["1.1.1.2", "1.2,2.3,1.2", "2.3-1.2-1"], + dims=["X"], + ).astype(np.unicode_) + + res = values.str % {"a": a, "b": b, "c": c} + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_broadcast_single(): + values = xr.DataArray( + ["%s_1", "%s_2", "%s_3"], + dims=["X"], + ).astype(np.unicode_) + + pos = xr.DataArray( + ["2.3", "3.44444"], + dims=["YY"], + ) + + expected = xr.DataArray( + [["2.3_1", "3.44444_1"], ["2.3_2", "3.44444_2"], ["2.3_3", "3.44444_3"]], + dims=["X", "YY"], + ).astype(np.unicode_) + + res = values.str % pos + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_broadcast_multi(): + values = xr.DataArray( + ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], + dims=["X"], + ).astype(np.unicode_) + + pos0 = 1 + pos1 = 1.2 + + pos2 = xr.DataArray( + ["2.3", "3.44444"], + dims=["YY"], + ) + + expected = xr.DataArray( + [ + ["1.1.2.2.3", "1.1.2.3.44444"], + ["1,1.2,2.3", "1,1.2,3.44444"], + ["1-1.2-2.3", "1-1.2-3.44444"], + ], + dims=["X", "YY"], + ).astype(np.unicode_) + + res = values.str % (pos0, pos1, pos2) + + assert res.dtype == expected.dtype + assert_equal(res, expected) From 1ffc79efd3fed3efd5bd4b2991fa5bd1382559fb Mon Sep 17 00:00:00 2001 From: Todd Jennings Date: Sat, 6 Mar 2021 00:02:24 -0500 Subject: [PATCH 07/14] update whats-new.rst, api.rst, and api-hidden.rst --- doc/api-hidden.rst | 13 +++++++++++++ doc/api.rst | 20 ++++++++++++++++++++ doc/whats-new.rst | 17 +++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 14d79039a3a..a5280c21b7e 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -324,14 +324,21 @@ core.accessor_dt.TimedeltaAccessor.seconds core.accessor_str.StringAccessor.capitalize + core.accessor_str.StringAccessor.casefold + core.accessor_str.StringAccessor.cat 
core.accessor_str.StringAccessor.center core.accessor_str.StringAccessor.contains core.accessor_str.StringAccessor.count core.accessor_str.StringAccessor.decode core.accessor_str.StringAccessor.encode core.accessor_str.StringAccessor.endswith + core.accessor_str.StringAccessor.extract + core.accessor_str.StringAccessor.extractall core.accessor_str.StringAccessor.find + core.accessor_str.StringAccessor.findall + core.accessor_str.StringAccessor.format core.accessor_str.StringAccessor.get + core.accessor_str.StringAccessor.get_dummies core.accessor_str.StringAccessor.index core.accessor_str.StringAccessor.isalnum core.accessor_str.StringAccessor.isalpha @@ -342,20 +349,26 @@ core.accessor_str.StringAccessor.isspace core.accessor_str.StringAccessor.istitle core.accessor_str.StringAccessor.isupper + core.accessor_str.StringAccessor.join core.accessor_str.StringAccessor.len core.accessor_str.StringAccessor.ljust core.accessor_str.StringAccessor.lower core.accessor_str.StringAccessor.lstrip core.accessor_str.StringAccessor.match + core.accessor_str.StringAccessor.normalize core.accessor_str.StringAccessor.pad + core.accessor_str.StringAccessor.partition core.accessor_str.StringAccessor.repeat core.accessor_str.StringAccessor.replace core.accessor_str.StringAccessor.rfind core.accessor_str.StringAccessor.rindex core.accessor_str.StringAccessor.rjust + core.accessor_str.StringAccessor.rpartition + core.accessor_str.StringAccessor.rsplit core.accessor_str.StringAccessor.rstrip core.accessor_str.StringAccessor.slice core.accessor_str.StringAccessor.slice_replace + core.accessor_str.StringAccessor.split core.accessor_str.StringAccessor.startswith core.accessor_str.StringAccessor.strip core.accessor_str.StringAccessor.swapcase diff --git a/doc/api.rst b/doc/api.rst index 9add7a96109..f58b89b0766 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -420,38 +420,58 @@ String manipulation :toctree: generated/ :template: autosummary/accessor_method.rst + DataArray.str._apply + DataArray.str._padder + DataArray.str._partitioner + DataArray.str._re_compile + DataArray.str._splitter + DataArray.str._stringify DataArray.str.capitalize + DataArray.str.casefold + DataArray.str.cat DataArray.str.center DataArray.str.contains DataArray.str.count DataArray.str.decode DataArray.str.encode DataArray.str.endswith + DataArray.str.extract + DataArray.str.extractall DataArray.str.find + DataArray.str.findall + DataArray.str.format DataArray.str.get + DataArray.str.get_dummies DataArray.str.index DataArray.str.isalnum DataArray.str.isalpha DataArray.str.isdecimal DataArray.str.isdigit + DataArray.str.islower DataArray.str.isnumeric DataArray.str.isspace DataArray.str.istitle DataArray.str.isupper + DataArray.str.join DataArray.str.len DataArray.str.ljust DataArray.str.lower DataArray.str.lstrip DataArray.str.match + DataArray.str.normalize DataArray.str.pad + DataArray.str.partition DataArray.str.repeat DataArray.str.replace DataArray.str.rfind DataArray.str.rindex DataArray.str.rjust + DataArray.str.rpartition + DataArray.str.rsplit DataArray.str.rstrip DataArray.str.slice DataArray.str.slice_replace + DataArray.str.split DataArray.str.startswith DataArray.str.strip DataArray.str.swapcase diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9e59fdc5b35..20521d32b4f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,23 @@ New Features - Support for `dask.graph_manipulation `_ (requires dask >=2021.3) By `Guido Imperiale `_ +- Many of the arguments for the :py:attr:`DataArray.str` methods now support + 
providing an array-like input. In this case, the input is broadcast against
+  the original array and the operation is applied elementwise.
+- :py:attr:`DataArray.str` now supports `+`, `*`, and `%` operators. These
+  behave the same as they do for :py:class:`str`, except that they follow
+  array broadcasting rules.
+- A large number of new :py:attr:`DataArray.str` methods were implemented:
+  :py:meth:`DataArray.str.casefold`, :py:meth:`DataArray.str.cat`,
+  :py:meth:`DataArray.str.extract`, :py:meth:`DataArray.str.extractall`,
+  :py:meth:`DataArray.str.findall`, :py:meth:`DataArray.str.format`,
+  :py:meth:`DataArray.str.get_dummies`, :py:meth:`DataArray.str.islower`,
+  :py:meth:`DataArray.str.join`, :py:meth:`DataArray.str.normalize`,
+  :py:meth:`DataArray.str.partition`, :py:meth:`DataArray.str.rpartition`,
+  :py:meth:`DataArray.str.rsplit`, and :py:meth:`DataArray.str.split`.
+  A number of these methods allow splitting or joining the strings in an
+  array. (:issue:`4622`)
+
 Breaking changes
 ~~~~~~~~~~~~~~~~

From 19c8a910f0989ace54ce7a571d814d4638ada0eb Mon Sep 17 00:00:00 2001
From: Todd Jennings
Date: Sat, 6 Mar 2021 09:01:55 -0500
Subject: [PATCH 08/14] test fixes

---
 xarray/tests/test_accessor_str.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py
index 9bf33893241..a8987944700 100644
--- a/xarray/tests/test_accessor_str.py
+++ b/xarray/tests/test_accessor_str.py
@@ -639,39 +639,39 @@ def test_extract_extractall_name_collision_raises(dtype):

     value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype)

-    with pytest.raises(KeyError, match="Dimension X already present in DataArray."):
+    with pytest.raises(KeyError, match="Dimension 'X' already present in DataArray."):
         value.str.extract(pat=pat_str, dim="X")

-    with pytest.raises(KeyError, match="Dimension X already present in DataArray."):
+    with pytest.raises(KeyError, match="Dimension 'X' already present in DataArray."):
         value.str.extract(pat=pat_re, dim="X")

     with pytest.raises(
-        KeyError, match="Group dimension X already present in DataArray."
+        KeyError, match="Group dimension 'X' already present in DataArray."
     ):
         value.str.extractall(pat=pat_str, group_dim="X", match_dim="ZZ")

     with pytest.raises(
-        KeyError, match="Group dimension X already present in DataArray."
+        KeyError, match="Group dimension 'X' already present in DataArray."
     ):
         value.str.extractall(pat=pat_re, group_dim="X", match_dim="YY")

     with pytest.raises(
-        KeyError, match="Match dimension Y already present in DataArray."
+        KeyError, match="Match dimension 'Y' already present in DataArray."
     ):
         value.str.extractall(pat=pat_str, group_dim="XX", match_dim="Y")

     with pytest.raises(
-        KeyError, match="Match dimension Y already present in DataArray."
+        KeyError, match="Match dimension 'Y' already present in DataArray."
     ):
         value.str.extractall(pat=pat_re, group_dim="XX", match_dim="Y")

     with pytest.raises(
-        KeyError, match="Group dimension ZZ is the same as match dimension ZZ."
+        KeyError, match="Group dimension 'ZZ' is the same as match dimension 'ZZ'."
     ):
         value.str.extractall(pat=pat_str, group_dim="ZZ", match_dim="ZZ")

     with pytest.raises(
-        KeyError, match="Group dimension ZZ is the same as match dimension ZZ."
+        KeyError, match="Group dimension 'ZZ' is the same as match dimension 'ZZ'."
): value.str.extractall(pat=pat_re, group_dim="ZZ", match_dim="ZZ") From 48d67c76bc185a7e966e53162f85341601fc8c5f Mon Sep 17 00:00:00 2001 From: Todd Jennings Date: Sat, 6 Mar 2021 09:08:49 -0500 Subject: [PATCH 09/14] implement requested fixes --- xarray/core/accessor_str.py | 64 ++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 81ba87b12ca..f0857b60252 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -130,7 +130,7 @@ def _apply_str_ufunc( class StringAccessor: - """Vectorized string functions for string-like arrays. + r"""Vectorized string functions for string-like arrays. Similar to pandas, fields can be accessed through the `.str` attribute for applicable DataArrays. @@ -187,7 +187,7 @@ class StringAccessor: >>> da2 = xr.DataArray([1, 2, 3], dims=["Y"]) >>> da1 % {"a": da2} - array(['\\narray([1, 2, 3])\\nDimensions without coordinates: Y'], + array(['\narray([1, 2, 3])\nDimensions without coordinates: Y'], dtype=object) Dimensions without coordinates: X """ @@ -1631,7 +1631,7 @@ def extract( case: bool = None, flags: int = 0, ) -> Any: - """ + r""" Extract the first match of capture groups in the regex pat as a new dimension in a DataArray. @@ -1697,7 +1697,7 @@ def extract( Extract matches - >>> value.str.extract(r"(\\w+)_Xy_(\\d*)", dim="match") + >>> value.str.extract(r"(\w+)_Xy_(\d*)", dim="match") array([[['a', '0'], ['bab', '110'], @@ -1776,7 +1776,7 @@ def extractall( case: bool = None, flags: int = 0, ) -> Any: - """ + r""" Extract all matches of capture groups in the regex pat as new dimensions in a DataArray. @@ -1847,7 +1847,7 @@ def extractall( Extract matches >>> value.str.extractall( - ... r"(\\w+)_Xy_(\\d*)", group_dim="group", match_dim="match" + ... r"(\w+)_Xy_(\d*)", group_dim="group", match_dim="match" ... ) array([[[['a', '0'], @@ -1949,7 +1949,7 @@ def findall( case: bool = None, flags: int = 0, ) -> Any: - """ + r""" Find all occurrences of pattern or regular expression in the DataArray. Equivalent to applying re.findall() to all the elements in the DataArray. @@ -2009,7 +2009,7 @@ def findall( Extract matches - >>> value.str.findall(r"(\\w+)_Xy_(\\d*)") + >>> value.str.findall(r"(\w+)_Xy_(\d*)") array([[list([('a', '0')]), list([('bab', '110'), ('baab', '1100')]), list([('abc', '01'), ('cbc', '2210')])], @@ -2199,7 +2199,7 @@ def split( sep: Union[str, bytes, Any] = None, maxsplit: int = -1, ) -> Any: - """ + r""" Split strings in a DataArray around the given separator/delimiter `sep`. Splits the string in the DataArray from the beginning, @@ -2230,8 +2230,8 @@ def split( >>> values = xr.DataArray( ... [ - ... ["abc def", "spam\\t\\teggs\\tswallow", "red_blue"], - ... ["test0\\ntest1\\ntest2\\n\\ntest3", "", "abra ka\\nda\\tbra"], + ... ["abc def", "spam\t\teggs\tswallow", "red_blue"], + ... ["test0\ntest1\ntest2\n\ntest3", "", "abra ka\nda\tbra"], ... ], ... dims=["X", "Y"], ... 
) @@ -2241,12 +2241,12 @@ def split( >>> values.str.split(dim="splitted", maxsplit=1) array([[['abc', 'def'], - ['spam', 'eggs\\tswallow'], + ['spam', 'eggs\tswallow'], ['red_blue', '']], - [['test0', 'test1\\ntest2\\n\\ntest3'], + [['test0', 'test1\ntest2\n\ntest3'], ['', ''], - ['abra', 'ka\\nda\\tbra']]], dtype='>> values.str.split(dim=None, maxsplit=1) - array([[list(['abc', 'def']), list(['spam', 'eggs\\tswallow']), + array([[list(['abc', 'def']), list(['spam', 'eggs\tswallow']), list(['red_blue'])], - [list(['test0', 'test1\\ntest2\\n\\ntest3']), list([]), - list(['abra', 'ka\\nda\\tbra'])]], dtype=object) + [list(['test0', 'test1\ntest2\n\ntest3']), list([]), + list(['abra', 'ka\nda\tbra'])]], dtype=object) Dimensions without coordinates: X, Y Split as many times as needed and put the results in a list @@ -2287,12 +2287,12 @@ def split( >>> values.str.split(dim="splitted", sep=" ") array([[['abc', 'def', ''], - ['spam\\t\\teggs\\tswallow', '', ''], + ['spam\t\teggs\tswallow', '', ''], ['red_blue', '', '']], - [['test0\\ntest1\\ntest2\\n\\ntest3', '', ''], + [['test0\ntest1\ntest2\n\ntest3', '', ''], ['', '', ''], - ['abra', '', 'ka\\nda\\tbra']]], dtype=' Any: - """ + r""" Split strings in a DataArray around the given separator/delimiter `sep`. Splits the string in the DataArray from the end, @@ -2348,8 +2348,8 @@ def rsplit( >>> values = xr.DataArray( ... [ - ... ["abc def", "spam\\t\\teggs\\tswallow", "red_blue"], - ... ["test0\\ntest1\\ntest2\\n\\ntest3", "", "abra ka\\nda\\tbra"], + ... ["abc def", "spam\t\teggs\tswallow", "red_blue"], + ... ["test0\ntest1\ntest2\n\ntest3", "", "abra ka\nda\tbra"], ... ], ... dims=["X", "Y"], ... ) @@ -2359,12 +2359,12 @@ def rsplit( >>> values.str.rsplit(dim="splitted", maxsplit=1) array([[['abc', 'def'], - ['spam\\t\\teggs', 'swallow'], + ['spam\t\teggs', 'swallow'], ['', 'red_blue']], - [['test0\\ntest1\\ntest2', 'test3'], + [['test0\ntest1\ntest2', 'test3'], ['', ''], - ['abra ka\\nda', 'bra']]], dtype='>> values.str.rsplit(dim=None, maxsplit=1) - array([[list(['abc', 'def']), list(['spam\\t\\teggs', 'swallow']), + array([[list(['abc', 'def']), list(['spam\t\teggs', 'swallow']), list(['red_blue'])], - [list(['test0\\ntest1\\ntest2', 'test3']), list([]), - list(['abra ka\\nda', 'bra'])]], dtype=object) + [list(['test0\ntest1\ntest2', 'test3']), list([]), + list(['abra ka\nda', 'bra'])]], dtype=object) Dimensions without coordinates: X, Y Split as many times as needed and put the results in a list @@ -2405,12 +2405,12 @@ def rsplit( >>> values.str.rsplit(dim="splitted", sep=" ") array([[['', 'abc', 'def'], - ['', '', 'spam\\t\\teggs\\tswallow'], + ['', '', 'spam\t\teggs\tswallow'], ['', '', 'red_blue']], - [['', '', 'test0\\ntest1\\ntest2\\n\\ntest3'], + [['', '', 'test0\ntest1\ntest2\n\ntest3'], ['', '', ''], - ['abra', '', 'ka\\nda\\tbra']]], dtype=' Date: Sat, 6 Mar 2021 09:45:46 -0500 Subject: [PATCH 10/14] more fixes --- xarray/core/accessor_str.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index f0857b60252..4954a65f0ee 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -42,7 +42,17 @@ import textwrap from functools import reduce from operator import or_ as set_union -from typing import Any, Callable, Hashable, Mapping, Optional, Pattern, Tuple, Union +from typing import ( + Any, + Callable, + Hashable, + Mapping, + Optional, + Pattern, + Tuple, + Type, + Union, +) from unicodedata import normalize import numpy as 
np
@@ -215,7 +225,7 @@ def _apply(
         self,
         *,
         func: Callable,
-        dtype: Union[str, np.dtype] = None,
+        dtype: Union[str, np.dtype, Type] = None,
         output_core_dims: Union[list, tuple] = ((),),
         output_sizes: Mapping[Hashable, int] = None,
         func_args: Tuple = (),

From 408a58bf7808467441046913ce4b96c71b28b025 Mon Sep 17 00:00:00 2001
From: Todd Jennings
Date: Sat, 6 Mar 2021 10:07:37 -0500
Subject: [PATCH 11/14] typing fixes

---
 xarray/core/accessor_str.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py
index 4954a65f0ee..e2f8133dc78 100644
--- a/xarray/core/accessor_str.py
+++ b/xarray/core/accessor_str.py
@@ -112,7 +112,7 @@ def _apply_str_ufunc(
     *,
     func: Callable,
     obj: Any,
-    dtype: Union[str, np.dtype] = None,
+    dtype: Union[str, np.dtype, Type] = None,
     output_core_dims: Union[list, tuple] = ((),),
     output_sizes: Mapping[Hashable, int] = None,
     func_args: Tuple = (),

From adfd09d37fe1ab511e414e22501121280b9f34c9 Mon Sep 17 00:00:00 2001
From: Todd Jennings
Date: Sun, 7 Mar 2021 00:11:02 -0500
Subject: [PATCH 12/14] fix docstring

---
 xarray/core/accessor_str.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py
index e2f8133dc78..b4080330e04 100644
--- a/xarray/core/accessor_str.py
+++ b/xarray/core/accessor_str.py
@@ -2528,19 +2528,16 @@ def decode(
         ----------
         encoding : str
             The encoding to use.
-            Please see the Python `codecs `_ documentation for a list
-            of encodings handlers
+            Please see the Python documentation `codecs standard encoders <https://docs.python.org/3/library/codecs.html#standard-encodings>`_
+            section for a list of encodings handlers.
         errors : str, optional
             The handler for encoding errors.
-            Please see the Python `codecs `_ documentation for a list
-            of error handlers
+            Please see the Python documentation `codecs error handlers <https://docs.python.org/3/library/codecs.html#error-handlers>`_
+            for a list of error handlers.

         Returns
         -------
         decoded : same type as values
-
-        .. _encodings: https://docs.python.org/3/library/codecs.html#standard-encodings
-        .. _handlers: https://docs.python.org/3/library/codecs.html#error-handlers
         """
         if encoding in _cpython_optimized_decoders:
             func = lambda x: x.decode(encoding, errors)

From 736c9940f4275ca09b785f339f960663b53f3567 Mon Sep 17 00:00:00 2001
From: Todd Jennings
Date: Sun, 7 Mar 2021 00:33:58 -0500
Subject: [PATCH 13/14] fix more docstring

---
 xarray/core/accessor_str.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py
index b4080330e04..f0e416b52e6 100644
--- a/xarray/core/accessor_str.py
+++ b/xarray/core/accessor_str.py
@@ -2558,19 +2558,16 @@ def encode(
         ----------
         encoding : str
             The encoding to use.
-            Please see the Python `codecs `_ documentation for a list
-            of encodings handlers
+            Please see the Python documentation `codecs standard encoders <https://docs.python.org/3/library/codecs.html#standard-encodings>`_
+            section for a list of encodings handlers.
         errors : str, optional
             The handler for encoding errors.
-            Please see the Python `codecs `_ documentation for a list
-            of error handlers
+            Please see the Python documentation `codecs error handlers <https://docs.python.org/3/library/codecs.html#error-handlers>`_
+            for a list of error handlers.

         Returns
         -------
         encoded : same type as values
-
-        .. _encodings: https://docs.python.org/3/library/codecs.html#standard-encodings
-        .. _handlers: https://docs.python.org/3/library/codecs.html#error-handlers
         """
         if encoding in _cpython_optimized_encoders:
             func = lambda x: x.encode(encoding, errors)

From 3473ac3815d29d73ae1b713ff2ef5ed76f19b1ee Mon Sep 17 00:00:00 2001
From: Todd Jennings
Date: Sun, 7 Mar 2021 23:03:16 -0500
Subject: [PATCH 14/14] remove encoding header

---
 xarray/tests/test_accessor_str.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py
index a8987944700..519ca762c41 100644
--- a/xarray/tests/test_accessor_str.py
+++ b/xarray/tests/test_accessor_str.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
 # Tests for the `str` accessor are derived from the original
 # pandas string accessor tests.
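
As a quick illustration of the accessor features these patches describe, the short sketch below exercises the operators and methods as they appear in the diffs above (the broadcasting `+` and `*` operators, `split`, and `extract`). It is an illustrative sketch only, not part of any patch: the sample strings, the dimension names "parts" and "match", and the final print line are invented here, and the exact output reprs and dtypes are deliberately not shown.

    import xarray as xr

    # A small 2-D array of strings, shaped like the doctest data in the diffs.
    values = xr.DataArray(
        [["a_Xy_0", "bab_Xy_110"], ["abcd_Xy_", "abc_Xy_01"]],
        dims=["X", "Y"],
    )

    # The new operators follow array broadcasting rules and act elementwise.
    combined = values.str + "_suffix"  # concatenate a suffix onto every element
    repeated = values.str * 2          # repeat every element twice

    # The new splitting/extraction methods place their results along a new dimension.
    parts = values.str.split(dim="parts", sep="_")               # pieces around "_"
    groups = values.str.extract(r"(\w+)_Xy_(\d*)", dim="match")  # groups of the first match

    print(combined.shape, repeated.shape, parts.sizes["parts"], groups.sizes["match"])

The doctests embedded in the diffs above show the exact array reprs these methods produce.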