add get_name patch/spool method #471

Merged · 9 commits · Dec 30, 2024
Changes from 4 commits
3 changes: 2 additions & 1 deletion dascore/core/patch.py
@@ -19,7 +19,7 @@
from dascore.core.coords import BaseCoord
from dascore.utils.display import array_to_text, attrs_to_text, get_dascore_text
from dascore.utils.models import ArrayLike
from dascore.utils.patch import check_patch_attrs, check_patch_coords
from dascore.utils.patch import check_patch_attrs, check_patch_coords, get_patch_names
from dascore.utils.time import to_float
from dascore.viz import VizPatchNameSpace

@@ -260,6 +260,7 @@ def channel_count(self) -> int:
coords_from_df = dascore.proc.coords_from_df
make_broadcastable_to = dascore.proc.make_broadcastable_to
apply_ufunc = dascore.proc.apply_ufunc
get_patch_names = get_patch_names

def assign_coords(self, *args, **kwargs):
"""Deprecated method for update_coords."""
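
The single added line above binds the module-level `get_patch_names` function as a `Patch` method. A minimal sketch of why plain class-attribute assignment is enough (the bodies here are stand-ins, not DASCore's actual implementations):

```python
def get_patch_names(patch_data, prefix="DAS"):
    """Stand-in for dascore.utils.patch.get_patch_names."""
    return f"{prefix}_example_name"


class Patch:
    # Assigning the function to a class attribute makes it a normal method:
    # the instance is passed as the first argument (patch_data), so
    # patch.get_patch_names() is equivalent to get_patch_names(patch).
    get_patch_names = get_patch_names


print(Patch().get_patch_names())  # DAS_example_name
```
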
11 changes: 9 additions & 2 deletions dascore/core/spool.py
@@ -33,6 +33,7 @@
_force_patch_merge,
_spool_up,
concatenate_patches,
get_patch_names,
patches_to_df,
stack_patches,
)
@@ -601,6 +602,8 @@ def get_contents(self) -> pd.DataFrame:
"""{doc}."""
return self._df[filter_df(self._df, **self._select_kwargs)]

get_patch_names = get_patch_names


class MemorySpool(DataFrameSpool):
"""A Spool for storing patches in memory."""
@@ -617,10 +620,14 @@ def __rich__(self):
base = super().__rich__()
df = self._df
if len(df):
t1, t2 = df["time_min"].min(), df["time_max"].max()
t1 = df["time_min"].min() if "time_min" in df.columns else ""
t2 = df["time_max"].max() if "time_max" in df.columns else ""
tmin = get_nice_text(t1)
tmax = get_nice_text(t2)
duration = get_nice_text(t2 - t1)
if t1 != "" and t2 != "":
duration = get_nice_text(t2 - t1)
else:
duration = ""
base += Text(f"\n Time Span: <{duration}> {tmin} to {tmax}")
return base

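
The column guards above let `__rich__` render contents that lack time columns (e.g., a spool made from an empty patch) instead of raising a `KeyError`. A quick check, mirroring the test added later in this PR:

```python
import dascore as dc

# An empty patch has no time_min/time_max in its contents dataframe;
# with the guards, the time span is simply rendered blank.
spool = dc.spool(dc.Patch())
assert "Spool" in str(spool)
```
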
5 changes: 3 additions & 2 deletions dascore/examples.py
@@ -18,7 +18,7 @@
from dascore.utils.docs import compose_docstring
from dascore.utils.downloader import fetch
from dascore.utils.misc import register_func
from dascore.utils.patch import get_default_patch_name
from dascore.utils.patch import get_patch_names
from dascore.utils.time import to_timedelta64

EXAMPLE_PATCHES = {}
@@ -506,7 +506,8 @@ def spool_to_directory(spool, path=None, file_format="DASDAE", extention="hdf5")
path = Path(tempfile.mkdtemp())
assert path.exists()
for patch in spool:
out_path = path / (f"{get_default_patch_name(patch)}.{extention}")
name = get_patch_names(patch).iloc[0]
out_path = path / (f"{name}.{extention}")
patch.io.write(out_path, file_format=file_format)
return path

9 changes: 7 additions & 2 deletions dascore/io/core.py
@@ -30,6 +30,7 @@
timeable_types,
)
from dascore.core.attrs import str_validator
from dascore.core.spool import DataFrameSpool
from dascore.exceptions import (
InvalidFiberIOError,
MissingOptionalDependencyError,
@@ -648,7 +649,7 @@ def read(


def scan_to_df(
path: Path | str | PatchType | SpoolType | IOResourceManager,
path: Path | str | PatchType | SpoolType | IOResourceManager | pd.DataFrame,
file_format: str | None = None,
file_version: str | None = None,
ext: str | None = None,
@@ -665,7 +666,7 @@
Parameters
----------
path
The path the to file to scan
The path to the file to scan
file_format
Format of the file. If not provided DASCore will try to determine it.
file_version
@@ -682,6 +683,10 @@
>>>
>>> df = dc.scan_to_df(file_path)
"""
if isinstance(path, pd.DataFrame):
return path
if isinstance(path, DataFrameSpool):
return path.get_contents()
info = scan(
path=path,
file_format=file_format,
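
The two early returns above make `scan_to_df` a cheap pass-through when the input already carries a contents dataframe. A short sketch, mirroring the `TestScanToDF` tests added at the end of this PR:

```python
import dascore as dc

spool = dc.get_example_spool("random_das")
df = spool.get_contents()

# A dataframe input is returned as-is (identity, not a copy).
assert dc.scan_to_df(df) is df

# A DataFrameSpool input short-circuits to its contents dataframe,
# avoiding a re-scan of the underlying resources.
assert dc.scan_to_df(spool).equals(spool.get_contents())
```
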
9 changes: 5 additions & 4 deletions dascore/io/dasdae/core.py
@@ -17,7 +17,7 @@
PyTablesWriter,
)
from dascore.utils.misc import unbyte
from dascore.utils.patch import get_default_patch_name
from dascore.utils.patch import get_patch_names

from .utils import (
_get_contents_from_patch_groups,
@@ -78,8 +78,9 @@ def write(self, spool: SpoolType, resource: PyTablesWriter, index=False, **kwargs):
resource.create_group(resource.root, "waveforms")
waveforms = resource.get_node("/waveforms")
# write new patches to file
for patch in patches:
_save_patch(patch, waveforms, resource)
patch_names = get_patch_names(patches).values
for patch, name in zip(patches, patch_names):
_save_patch(patch, waveforms, resource, name)
indexer = HDFPatchIndexManager(resource)
if index or indexer.has_index:
df = self._get_patch_summary(patches)
@@ -90,7 +91,7 @@ def _get_patch_summary(self, patches) -> pd.DataFrame:
df = (
dc.scan_to_df(patches)
.assign(
path=[f"waveforms/{get_default_patch_name(x)}" for x in patches],
path=lambda x: get_patch_names(x),
file_format=self.name,
file_version=self.version,
)
4 changes: 1 addition & 3 deletions dascore/io/dasdae/utils.py
@@ -10,7 +10,6 @@
from dascore.core.coordmanager import get_coord_manager
from dascore.core.coords import get_coord
from dascore.utils.misc import suppress_warnings
from dascore.utils.patch import get_default_patch_name
from dascore.utils.time import to_int

# --- Functions for writing DASDAE format
@@ -80,9 +79,8 @@ def _save_coords(patch, patch_group, h5):
patch_group._v_attrs[save_name] = ",".join(dims)


def _save_patch(patch, wave_group, h5):
def _save_patch(patch, wave_group, h5, name):
"""Save the patch to disk."""
name = get_default_patch_name(patch)
patch_group = _create_or_get_group(h5, wave_group, name)
_save_attrs_and_dims(patch, patch_group)
_save_coords(patch, patch_group, h5)
120 changes: 87 additions & 33 deletions dascore/utils/patch.py
@@ -46,7 +46,7 @@
warn_or_raise,
yield_sub_sequences,
)
from dascore.utils.time import to_datetime64, to_float
from dascore.utils.time import to_float

attr_type = dict[str, Any] | str | Sequence[str] | None

@@ -306,7 +306,10 @@ def patches_to_df(
elif isinstance(patches, pd.DataFrame):
df = patches
else:
df = pd.DataFrame([x.flat_dump() for x in scan_patches(patches)])
df = dc.scan_to_df(
patches,
exclude=(),
)
if df.empty: # create empty df with appropriate columns
cols = list(dc.PatchAttrs().model_dump())
df = pd.DataFrame(columns=cols).assign(patch=None, history=None)
@@ -418,24 +421,6 @@ def _get_new_coord(df, merge_dim, coords):
return [new_dict]


def scan_patches(patches: PatchType | Sequence[PatchType]) -> list[dc.PatchAttrs]:
"""
Scan a sequence of patches and return a list of summaries.

The summary dicts have the following fields:
{fields}

Parameters
----------
patches
A single patch or a sequence of patches.
"""
if isinstance(patches, dc.Patch):
patches = [patches] # make sure we have an iterable
out = [pa.attrs for pa in patches]
return out


def get_start_stop_step(patch: PatchType, dim):
"""Convenience method for getting start, stop, step for a given coord."""
assert dim in patch.dims, f"{dim} is not in Patch dimensions of {patch.dims}"
@@ -446,21 +431,90 @@ def get_start_stop_step(patch: PatchType, dim):
return start, stop, step


def get_default_patch_name(patch):
"""Generates the name of the node."""
def get_patch_names(
patch_data: pd.DataFrame | dc.Patch | dc.BaseSpool,
prefix="DAS",
attrs=("network", "station", "tag"),
coords=("time",),
sep="__",
Comment on lines +436 to +440 (@ahmadtourei, Collaborator, Dec 24, 2024):
Need test(s) for optional inputs. For example, get_patch_names(patch, patch) does not raise an error.

) -> pd.Series:
"""
Generates the default name of patch data.

Parameters
----------
patch_data
A container with patch data.
prefix
A string to prefix the names.
attrs
The patch attributes to include in the names.
coords
The coordinate ranges to use for names.
sep
The separator for the strings.

Notes
-----
There are two special cases where the default logic is overridden.
The first is when a column called "name" already exists; it is
simply returned.

The second is when a column called "path" exists. In this case, the
output is the file name with the extension removed. The path must
use / as a delimiter.

def _format_datetime64(dt):
Examples
--------
>>> import dascore as dc
>>> from dascore.utils.patch import get_patch_names
>>> patch = dc.get_example_patch()
>>> name = get_patch_names(patch)
(d-chambers marked this conversation as resolved.)

Comment (@ahmadtourei, Collaborator, Dec 24, 2024):
name is:

0    DAS_______random__2017_09_18__2017_09_18T00_00_07
Name: network, dtype: object

We get DAS_______random__2017_09_18__2017_09_18T00_00_07 using name[0] because the function returns a <class 'pandas.core.series.Series'>. Is there a specific reason that it returns a pandas Series with Name: network, dtype: object (a little confusing) instead of the names as a string?

Reply (@d-chambers, Contributor Author):
Yes, for the function to work with a spool it needs to return multiple names. I avoided returning multiple types (based on whether the input is a Patch or not) from get_patch_names, as that is a bit of an anti-pattern.

I understand the confusion though. I will add get_patch_name as a method of Patch that will do this. I think with the different name (get_patch_name vs get_patch_names) it will be clear.
"""

def _format_time_column(ser):
"""Format the time column."""
ser = ser.astype(str).str.split(".", expand=True)[0]
chars_to_replace = (":", "-")
for char in chars_to_replace:
ser = ser.str.replace(char, "_")
ser = ser.str.replace(" ", "T")
return ser

def _format_time_columns(df):
"""Format the datetime string in a sensible way."""
out = str(to_datetime64(dt))
return out.replace(":", "_").replace("-", "_").replace(".", "_")

attrs = patch.attrs
start = _format_datetime64(attrs.get("time_min", ""))
end = _format_datetime64(attrs.get("time_max", ""))
net = attrs.get("network", "")
sta = attrs.get("station", "")
tag = attrs.get("tag", "")
return f"DAS__{net}__{sta}__{tag}__{start}__{end}"
sub = df.select_dtypes(include=["datetime64", "timedelta64"])
out = {}
for col in sub.columns:
out[col] = _format_time_column(df[col])
return df.assign(**out)

def _get_filename(path_ser):
"""Get the file name from a path series."""
ser = path_ser.astype(str)
file_names = [x[-1].split(".")[0] for x in ser.str.split("/")]
Comment (@ahmadtourei, Collaborator, Dec 24, 2024):
How about we introduce an optional argument (such as keep_extension) so the user can decide whether to get the patch's name with or without the file extension? I can work on this, as I already implemented it in my version, #461. What do you think?

return pd.Series(file_names)

# Ensure we are working with a dataframe.
df = dc.scan_to_df(
patch_data,
exclude=(),
)
if df.empty:
return pd.Series(dtype=str)
col_set = set(df.columns)
# Handle special cases.
if "name" in col_set:
return df["name"].astype(str)
if "path" in col_set:
return _get_filename(df["path"])
# Determine the requested fields and get the ones that are there.
coord_fields = zip([f"{x}_min" for x in coords], [f"{x}_max" for x in coords])
requested_fields = list(attrs) + list(*coord_fields)
current = set(df.columns)
fields = [x for x in requested_fields if x in current]
# Get a sub dataframe and convert any datetime things to strings.
sub = df[fields].pipe(_format_time_columns).fillna("").astype(str)
out = f"{prefix}_{sep}" + sub[fields[0]].str.cat(sub[fields[1:]], sep=sep)
return out


def get_dim_axis_value(
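
A short usage sketch of `get_patch_names`, covering the default naming and the two special cases from the Notes section (the literal default name below comes from the review comment earlier in this thread):

```python
import pandas as pd

import dascore as dc
from dascore.utils.patch import get_patch_names

# Default naming: a pandas Series of names built from attrs and time coords,
# e.g. DAS_______random__2017_09_18__2017_09_18T00_00_07 for the example patch.
patch = dc.get_example_patch()
print(get_patch_names(patch).iloc[0])

# Special case 1: an existing "name" column is returned as-is.
df = pd.DataFrame({"name": ["a", "b"]})
assert list(get_patch_names(df)) == ["a", "b"]

# Special case 2: a "path" column yields the file name, extension removed.
df = pd.DataFrame({"path": ["some_dir/file_1.h5", "some_dir/file_2.h5"]})
assert list(get_patch_names(df)) == ["file_1", "file_2"]
```
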
9 changes: 9 additions & 0 deletions tests/conftest.py
@@ -385,6 +385,15 @@ def one_file_dir(tmp_path_factory, random_patch):
return ex.spool_to_directory(spool, path=out)


@pytest.fixture(scope="session")
def random_spool_directory(tmp_path_factory):
"""A directory with a few patch files."""
out = Path(tmp_path_factory.mktemp("one_file_file_spool"))
spool = ex.get_example_spool("random_das")
out_path = ex.spool_to_directory(spool, path=out)
return dc.spool(out_path).update()
Comment (Collaborator):
It would be nice to have this as an example in dascore.examples, as all of the example spools are memory spools.

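Picking up the suggestion above, a hedged sketch of what a directory-backed example could look like; the helper name is hypothetical and not part of this PR:

```python
import dascore as dc
from dascore.examples import spool_to_directory

# Hypothetical helper, named here for illustration only.
def random_directory_spool(path=None):
    """Write the random example spool to disk and return a directory spool."""
    spool = dc.get_example_spool("random_das")
    out_path = spool_to_directory(spool, path=path)  # temp dir if path is None
    return dc.spool(out_path).update()
```
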


@pytest.fixture(scope="class")
def two_patch_directory(tmp_path_factory, terra15_das_example_path, random_patch):
"""Create a directory of DAS files for testing."""
Expand Down
8 changes: 7 additions & 1 deletion tests/test_core/test_spool.py
@@ -56,7 +56,7 @@ def test_spool_from_emtpy_sequence(self):
assert len(out) == 0

def test_updated_spool_eq(self, random_spool):
"""Ensure updating the spool doesnt change equality."""
"""Ensure updating the spool doesn't change equality."""
assert random_spool == random_spool.update()

def test_empty_spool_str(self):
@@ -65,6 +65,12 @@ def test_empty_spool_str(self):
spool_str = str(spool)
assert "Spool" in spool_str

def test_spool_with_empty_patch_str(self):
"""A spool with an empty patch should have a str."""
spool = dc.spool(dc.Patch())
spool_str = str(spool)
assert "Spool" in spool_str

def test_base_concat_raises(self, random_spool):
"""Ensure BaseSpool.concatenate raises NotImplementedError."""
msg = "has no concatenate implementation"
4 changes: 2 additions & 2 deletions tests/test_io/test_indexer.py
@@ -15,7 +15,7 @@
import dascore as dc
from dascore.examples import spool_to_directory
from dascore.io.indexer import DirectoryIndexer
from dascore.utils.patch import get_default_patch_name
from dascore.utils.patch import get_patch_names


@pytest.fixture(scope="class")
@@ -224,7 +224,7 @@ def index_new_version(self, monkeypatch, tmp_path_factory):

def test_add_one_patch(self, empty_index, random_patch):
"""Ensure a new patch added to the directory shows up."""
path = empty_index.path / get_default_patch_name(random_patch)
path = empty_index.path / get_patch_names(random_patch).iloc[0]
random_patch.io.write(path, file_format="dasdae")
new_index = empty_index.update()
contents = new_index()
16 changes: 16 additions & 0 deletions tests/test_io/test_io_core.py
@@ -361,6 +361,22 @@ def raise_os_error(*args, **kwargs):
assert not len(scan)


class TestScanToDF:
"""Tests for scanning to dataframes."""

def test_input_dataframe(self, random_spool):
"""Ensure a dataframe returns a dataframe."""
df = random_spool.get_contents()
out = dc.scan_to_df(df)
assert out is df

def test_spool_dataframe(self, random_spool_directory):
"""Ensure scan_to_df just gets the dataframe from the spool."""
expected = random_spool_directory.get_contents()
out = dc.scan_to_df(random_spool_directory)
assert out.equals(expected)


class TestCastType:
"""Test suite to ensure types are intelligently cast to type hints."""
