Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transition from zip-list to entry core #454

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions snakebids/core/_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Mapping

import attrs
import more_itertools as itx

from snakebids.io.printing import format_zip_lists
from snakebids.types import ZipList, ZipListLike
from snakebids.utils.containers import ContainerBag, MultiSelectDict, RegexContainer

if TYPE_CHECKING:
    # Precisely typed stand-ins for the builtins bound in the ``else`` branch.
    # ``TYPE_CHECKING`` is False at runtime, so these bodies never execute;
    # they only give static type checkers concrete signatures.

    def wcard_tuple(x: Iterable[str]) -> tuple[str, ...]:
        """Coerce an iterable of wildcard names into a tuple."""
        return tuple(x)

    def entries_list(x: Iterable[tuple[str, ...]]) -> list[tuple[str, ...]]:
        """Coerce an iterable of entry tuples into a list."""
        return list(x)

    def liststr() -> list[str]:
        """Return a new, empty list of strings."""
        return []
else:
    # At runtime the plain builtins do the same job.
    wcard_tuple, entries_list, liststr = tuple, list, list


@attrs.frozen(kw_only=True)
class BidsTable:
    """Container holding the entries of a BidsComponent.

    The table is stored row-wise: ``wildcards`` names the columns and every
    element of ``entries`` is one row, holding one value per wildcard in the
    same order as ``wildcards``.
    """

    # Column names; coerced to a tuple by the converter.
    wildcards: tuple[str, ...] = attrs.field(converter=wcard_tuple)
    # Rows of the table; the outer iterable is coerced to a list by the converter.
    entries: list[tuple[str, ...]] = attrs.field(converter=entries_list)

    def __bool__(self) -> bool:
        """Return True if one or more entries, otherwise False."""
        return bool(self.entries)

    def __eq__(self, other: BidsTable | object) -> bool:
        """Compare two tables, ignoring row order and column order.

        Tables are equal when they have the same set of wildcards and the same
        multiset of entries once the columns of ``other`` are realigned to the
        column order of ``self``. Objects that are not instances of this class
        always compare unequal.
        """
        if not isinstance(other, self.__class__):
            return False
        if set(self.wildcards) != set(other.wildcards):
            return False
        if len(self.entries) != len(other.entries):
            return False
        if self.wildcards == other.wildcards:
            return sorted(self.entries) == sorted(other.entries)
        # Columns are the same but in a different order: realign each row of
        # ``other`` to our column order, then compare as sorted multisets. This
        # avoids the quadratic cost of removing rows from a list one at a time.
        ixs = [other.wildcards.index(w) for w in self.wildcards]
        realigned = (tuple(entry[i] for i in ixs) for entry in other.entries)
        return sorted(self.entries) == sorted(realigned)

    @classmethod
    def from_dict(cls, d: ZipListLike):
        """Construct BidsTable from a mapping of entities to value lists.

        Parameters
        ----------
        d
            Mapping whose keys become the wildcards and whose values are
            zipped together to form the entries.

        Raises
        ------
        ValueError
            If the values of ``d`` do not all have the same length.
        """
        lengths = {len(val) for val in d.values()}
        if len(lengths) > 1:
            msg = "each entity must have the same number of values"
            raise ValueError(msg)
        return cls(wildcards=d.keys(), entries=zip(*d.values()))

    def to_dict(self) -> ZipList:
        """Convert into a zip_list.

        Returns a ``MultiSelectDict`` mapping each wildcard to the list of its
        values across all entries.
        """
        if not self.entries:
            # ``zip(*[])`` would yield no columns at all, so give every
            # wildcard its own fresh empty list instead.
            return MultiSelectDict(zip(self.wildcards, itx.repeatfunc(liststr)))
        return MultiSelectDict(zip(self.wildcards, map(list, zip(*self.entries))))

    def pformat(self, max_width: int | float | None = None, tabstop: int = 4) -> str:
        """Pretty-format the table via ``format_zip_lists`` on its dict form."""
        return format_zip_lists(self.to_dict(), max_width=max_width, tabstop=tabstop)

    def get(self, wildcard: str) -> list[str]:
        """Get values for a single wildcard.

        Returns one value per entry, in entry order.

        Raises
        ------
        ValueError
            If ``wildcard`` is not one of the table's wildcards.
        """
        index = self.wildcards.index(wildcard)
        return [entry[index] for entry in self.entries]

    def pick(self, wildcards: Iterable[str]):
        """Select a subset of wildcards (columns) as a new table.

        Repeated names in ``wildcards`` are collapsed to their first
        occurrence, preserving the requested order. Entries (rows) are *not*
        deduplicated, so the result has as many entries as this table.

        Raises
        ------
        ValueError
            If a requested wildcard is not one of the table's wildcards.
        """
        # dict.fromkeys drops repeated names while preserving order
        unique_keys = list(dict.fromkeys(wildcards))
        indices = [self.wildcards.index(w) for w in unique_keys]

        entries = [tuple(entry[i] for i in indices) for entry in self.entries]

        return self.__class__(wildcards=unique_keys, entries=entries)

    def filter(
        self,
        filters: Mapping[str, Iterable[str] | str],
        regex_search: bool = False,
    ):
        """Return a new table keeping only the entries that pass every filter.

        Parameters
        ----------
        filters
            Mapping of wildcard name to one allowed value or an iterable of
            allowed values. Keys that are not wildcards of this table are
            ignored.
        regex_search
            If False, a value passes when it is a member of the allowed set.
            If True, each filter value is wrapped in a ``RegexContainer`` and
            membership is tested against the resulting ``ContainerBag``
            (presumably regex matching; see those classes for the semantics).
        """
        valid_filters = set(self.wildcards)
        # Map column index -> container of accepted values for that column.
        if regex_search:
            filter_sets = {
                self.wildcards.index(key): ContainerBag(
                    *(RegexContainer(r) for r in itx.always_iterable(vals))
                )
                for key, vals in filters.items()
                if key in valid_filters
            }
        else:
            filter_sets = {
                self.wildcards.index(key): set(itx.always_iterable(vals))
                for key, vals in filters.items()
                if key in valid_filters
            }

        # Columns without a filter accept any value.
        keep = [
            entry
            for entry in self.entries
            if all(
                i not in filter_sets or val in filter_sets[i]
                for i, val in enumerate(entry)
            )
        ]

        return self.__class__(wildcards=self.wildcards, entries=keep)
94 changes: 47 additions & 47 deletions snakebids/core/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from math import inf
from pathlib import Path
from string import Formatter
from typing import Any, Iterable, NoReturn, cast, overload
from typing import Any, Iterable, Mapping, NoReturn, cast, overload

import attr
import more_itertools as itx
Expand All @@ -16,14 +16,14 @@
from typing_extensions import Self, TypedDict

import snakebids.utils.sb_itertools as sb_it
from snakebids.core.filtering import filter_list
from snakebids.core._table import BidsTable
from snakebids.exceptions import DuplicateComponentError
from snakebids.io.console import get_console_size
from snakebids.io.printing import format_zip_lists, quote_wrap
from snakebids.io.printing import quote_wrap
from snakebids.snakemake_compat import expand as sn_expand
from snakebids.types import ZipList
from snakebids.types import ZipList, ZipListLike
from snakebids.utils.containers import ImmutableList, MultiSelectDict, UserDictPy38
from snakebids.utils.utils import get_wildcard_dict, property_alias, zip_list_eq
from snakebids.utils.utils import get_wildcard_dict, property_alias


class BidsDatasetDict(TypedDict):
Expand Down Expand Up @@ -176,12 +176,21 @@ def filter(
msg = "Both __spec and filters cannot be used simultaneously"
raise ValueError(msg)
filters = {self.entity: spec}
entity, data = itx.first(
filter_list(
{self.entity: self._data}, filters, regex_search=regex_search
).items()
data = it.chain.from_iterable(
BidsTable(wildcards=[self.entity], entries=[(el,) for el in self._data])
.filter(filters, regex_search=regex_search)
.entries
)
return self.__class__(data, entity=entity)
return self.__class__(data, entity=self.entity)


def _to_bids_table(tbl: BidsTable | ZipListLike) -> BidsTable:
    """Coerce ``tbl`` into a ``BidsTable``.

    An existing ``BidsTable`` is returned unchanged and a mapping is converted
    with ``BidsTable.from_dict``. Anything else raises ``TypeError``.
    """
    if isinstance(tbl, BidsTable):
        return tbl
    if not isinstance(tbl, Mapping):  # type: ignore
        msg = f"Cannot convert '{tbl}' to BidsTable"
        raise TypeError(msg)
    return BidsTable.from_dict(tbl)


@attr.define(kw_only=True)
Expand All @@ -208,17 +217,12 @@ class BidsPartialComponent:
``BidsPartialComponents`` are immutable: their values cannot be altered.
"""

_zip_lists: ZipList = attr.field(
on_setattr=attr.setters.frozen, converter=MultiSelectDict, alias="zip_lists"
_table: BidsTable = attr.field(
converter=_to_bids_table,
on_setattr=attr.setters.frozen,
alias="table",
)

@_zip_lists.validator # type: ignore
def _validate_zip_lists(self, __attr: str, value: dict[str, list[str]]) -> None:
lengths = {len(val) for val in value.values()}
if len(lengths) > 1:
msg = "zip_lists must all be of equal length"
raise ValueError(msg)

def __repr__(self) -> str:
return self.pformat()

Expand All @@ -232,11 +236,8 @@ def __getitem__(
self, key: str | tuple[str, ...], /
) -> BidsComponentRow | BidsPartialComponent:
if isinstance(key, tuple):
# Use dict.fromkeys for de-duplication
return BidsPartialComponent(
zip_lists={key: self.zip_lists[key] for key in dict.fromkeys(key)}
)
return BidsComponentRow(self.zip_lists[key], entity=key)
return BidsPartialComponent(table=self._table.pick(key))
return BidsComponentRow(self._table.get(key), entity=key)

def __bool__(self) -> bool:
"""Truth of a BidsComponent is based on whether it has values.
Expand All @@ -247,7 +248,7 @@ def __bool__(self) -> bool:
consistent with :class:`BidsComponentRow`, which always has an entity name
stored, but may or may not have values.
"""
return bool(itx.first(self.zip_lists))
return bool(self._table.entries)

def _pformat_body(self) -> None | str | list[str]:
"""Extra properties to be printed within pformat.
Expand All @@ -271,8 +272,7 @@ def pformat(self, max_width: int | float | None = None, tabstop: int = 4) -> str
body = it.chain(
itx.always_iterable(self._pformat_body() or []),
[
"zip_lists="
f"{format_zip_lists(self.zip_lists, width - tabstop, tabstop)},",
"table=" f"{self._table.pformat(width - tabstop, tabstop)},",
],
)
output = [
Expand All @@ -292,6 +292,9 @@ def pformat(self, max_width: int | float | None = None, tabstop: int = 4) -> str
_entities: list[str] | None = attr.field(
default=None, init=False, eq=False, repr=False
)
_zip_lists: ZipList | None = attr.field(
default=None, init=False, eq=False, repr=False
)

@property
def zip_lists(self):
Expand All @@ -302,6 +305,9 @@ def zip_lists(self):
of images matched for this modality, so they can be zipped together to get a
list of the wildcard values for each file.
"""
if self._zip_lists is not None:
return self._zip_lists
self._zip_lists = self._table.to_dict()
return self._zip_lists

@property
Expand All @@ -328,15 +334,8 @@ def wildcards(self) -> MultiSelectDict[str, str]:
self._input_wildcards = MultiSelectDict(get_wildcard_dict(self.zip_lists))
return self._input_wildcards

@property
@property_alias(zip_lists, "zip_lists", "snakebids.BidsPartialComponent.zip_lists")
def input_zip_lists(self) -> ZipList:
"""Alias of :attr:`zip_lists <snakebids.BidsComponent.zip_lists>`.

Dictionary where each key is a wildcard entity and each value is a list of the
values found for that entity. Each of these lists has length equal to the number
of images matched for this modality, so they can be zipped together to get a
list of the wildcard values for each file.
"""
return self.zip_lists

@property_alias(entities, "entities", "snakebids.BidsComponent.entities")
Expand All @@ -351,7 +350,7 @@ def __eq__(self, other: BidsComponent | object) -> bool:
if not isinstance(other, self.__class__):
return False

return zip_list_eq(self.zip_lists, other.zip_lists)
return self._table == other._table

def expand(
self,
Expand Down Expand Up @@ -458,7 +457,7 @@ def filter(
return self
return attr.evolve(
self,
zip_lists=filter_list(self.zip_lists, filters, regex_search=regex_search),
table=self._table.filter(filters, regex_search=regex_search),
)


Expand Down Expand Up @@ -497,27 +496,28 @@ class BidsComponent(BidsPartialComponent):
BidsComponents are immutable: their values cannot be altered.
"""

_table: BidsTable = attr.field(
converter=_to_bids_table,
on_setattr=attr.setters.frozen,
alias="table",
)

name: str = attr.field(on_setattr=attr.setters.frozen)
"""Name of the component"""

path: str = attr.field(on_setattr=attr.setters.frozen)
"""Wildcard-filled path that matches the files for this component."""

_zip_lists: ZipList = attr.field(
on_setattr=attr.setters.frozen, converter=MultiSelectDict, alias="zip_lists"
)

@_zip_lists.validator # type: ignore
def _validate_zip_lists(self, __attr: str, value: dict[str, list[str]]) -> None:
super()._validate_zip_lists(__attr, value)
@_table.validator # type: ignore
def _validate_zip_lists(self, __attr: str, value: BidsTable) -> None:
_, raw_fields, *_ = sb_it.unpack(
zip(*Formatter().parse(self.path)), [[], [], []]
)
raw_fields = cast("Iterable[str]", raw_fields)
if (fields := set(filter(None, raw_fields))) != set(value):
if (fields := set(filter(None, raw_fields))) != set(value.wildcards):
msg = (
"zip_lists entries must match the wildcards in input_path: "
f"{self.path}: {fields} != zip_lists: {set(value)}"
"entries have the same wildcards as the input path: "
f"{self.path}: {fields} != entries: {set(value.wildcards)}"
)
raise ValueError(msg)

Expand Down
Loading
Loading