From 1e6a2571e8455f7006ea7a226bf2ef157d7d81a4 Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Wed, 10 Jan 2024 12:12:21 -0600 Subject: [PATCH 01/14] all setup in pyproject.toml --- .pre-commit-config.yaml | 8 +++---- pyproject.toml | 47 +++++++++++++++++++++++++++++++---------- setup.cfg | 36 ------------------------------- setup.py | 4 ---- 4 files changed, 40 insertions(+), 55 deletions(-) delete mode 100644 setup.cfg delete mode 100644 setup.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 30a97c0..a021871 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-yaml - id: end-of-file-fixer @@ -14,18 +14,18 @@ repos: - id: nbstripout - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.274 + rev: v0.1.11 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort args: ["--profile", "black", "--filter-files"] - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.12.1 hooks: - id: black diff --git a/pyproject.toml b/pyproject.toml index dea5508..e3a390e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,31 +1,56 @@ [build-system] requires = [ - "setuptools>=42,<64", - "wheel", - "setuptools_scm[toml]>=7.0", + "setuptools>=69", + "setuptools_scm>=8", ] build-backend = "setuptools.build_meta" +[project] +name = "sharrow" +license = "BSD-3-Clause" +requires-python = ">=3.9" +dynamic = ["version"] +dependencies = [ + "numpy >= 1.19", + "pandas >= 1.2", + "pyarrow", + "xarray", + "numba >= 0.51.2", + "numexpr", + "filelock", + "dask", + "networkx", +] + [tool.setuptools_scm] fallback_version = "1999" write_to = "sharrow/_version.py" -[tool.isort] -profile = "black" -skip_gitignore = true -float_to_top = true -default_section = "THIRDPARTY" -known_first_party = "sharrow" - [tool.ruff] # Enable flake8-bugbear (`B`) and pyupgrade ('UP') rules. 
-select = ["E", "F", "B", "UP"] +select = [ "F", # Pyflakes + "E", # Pycodestyle Errors + "W", # Pycodestyle Warnings + "I", # isort + "UP", # pyupgrade + "D", # pydocstyle + "B", # flake8-bugbear +] fix = true ignore-init-module-imports = true line-length = 120 ignore = ["B905"] target-version = "py39" +[tool.ruff.lint.isort] +known-first-party = ["larch"] + +[tool.ruff.lint.pycodestyle] +max-line-length = 120 + +[tool.ruff.lint.pydocstyle] +convention = "numpy" + [tool.pytest.ini_options] minversion = "6.0" addopts = "-v --nbmake --disable-warnings" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 28d8c14..0000000 --- a/setup.cfg +++ /dev/null @@ -1,36 +0,0 @@ -[metadata] -name = sharrow -author = Cambridge Systematics -author_email = jeffnewman@camsys.com -license = BSD-3-Clause -url = https://github.com/ActivitySim/sharrow -description = numba for ActivitySim-style spec files -long_description = file: README.md -long_description_content_type = text/markdown - -[options] -packages = find: -zip_safe = False -include_package_data = True -python_requires = >=3.7 -install_requires = - numpy >= 1.19 - pandas >= 1.2 - pyarrow >= 3.0.0 - xarray >= 0.20.0 - numba >= 0.54 - sparse - numexpr - filelock - dask - networkx - astunparse;python_version<'3.9' - -[flake8] -exclude = - .git, - __pycache__, - docs/_build, - sharrow/__init__.py -max-line-length = 160 -extend-ignore = E203, E731 diff --git a/setup.py b/setup.py deleted file mode 100644 index 02aeac1..0000000 --- a/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env python -from setuptools import setup - -setup(use_scm_version={"fallback_version": "1999"}) From 3da22aecd76533c993765c78842bcaeb6589b456 Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Wed, 10 Jan 2024 13:32:12 -0600 Subject: [PATCH 02/14] ruff formatting --- pyproject.toml | 2 +- sharrow/accessors.py | 2 + sharrow/categorical.py | 4 +- sharrow/dataset.py | 94 +++---------- sharrow/datastore.py | 42 ++---- sharrow/digital_encoding.py | 32 ++--- sharrow/flows.py | 274 +++++++++--------------------------- sharrow/relationships.py | 168 +++++----------------- sharrow/shared_memory.py | 37 ++--- sharrow/sparse.py | 15 +- sharrow/tests/conftest.py | 12 +- sharrow/translate.py | 2 +- sharrow/utils/tar_zst.py | 4 +- sharrow/wrappers.py | 25 ++-- 14 files changed, 181 insertions(+), 532 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e3a390e..8ad452e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ select = [ "F", # Pyflakes fix = true ignore-init-module-imports = true line-length = 120 -ignore = ["B905"] +ignore = ["B905", "D1"] target-version = "py39" [tool.ruff.lint.isort] diff --git a/sharrow/accessors.py b/sharrow/accessors.py index f83de83..9962db2 100644 --- a/sharrow/accessors.py +++ b/sharrow/accessors.py @@ -1,3 +1,5 @@ +"""Convenience accessor wrappers for xarray objects.""" + import xarray as xr diff --git a/sharrow/categorical.py b/sharrow/categorical.py index abddae2..a39bf89 100644 --- a/sharrow/categorical.py +++ b/sharrow/categorical.py @@ -14,9 +14,7 @@ class ArrayIsNotCategoricalError(TypeError): @xr.register_dataarray_accessor("cat") class _Categorical: - """ - Accessor for pseudo-categorical arrays. 
- """ + """Accessor for pseudo-categorical arrays.""" __slots__ = ("dataarray",) diff --git a/sharrow/dataset.py b/sharrow/dataset.py index 6ffab94..5f6cc25 100755 --- a/sharrow/dataset.py +++ b/sharrow/dataset.py @@ -70,16 +70,14 @@ def clean(s): cleaned = re.sub(r"\W|^(?=\d)", "_", s) if cleaned != s or len(cleaned) > 120: # digest size 15 creates a 24 character base32 string - h = base64.b32encode( - hashlib.blake2b(s.encode(), digest_size=15).digest() - ).decode() + h = base64.b32encode(hashlib.blake2b(s.encode(), digest_size=15).digest()).decode() cleaned = f"{cleaned[:90]}_{h}" return cleaned def construct(source): """ - A generic constructor for creating Datasets from various similar objects. + Create Datasets from various similar objects. Parameters ---------- @@ -111,7 +109,7 @@ def dataset_from_dataframe_fast( sparse: bool = False, preserve_cat: bool = True, ) -> Dataset: - """Convert a pandas.DataFrame into an xarray.Dataset + """Convert a pandas.DataFrame into an xarray.Dataset. Each column will be converted into an independent variable in the Dataset. If the dataframe's index is a MultiIndex, it will be expanded @@ -146,7 +144,6 @@ def dataset_from_dataframe_fast( xarray.DataArray.from_series pandas.DataFrame.to_xarray """ - # this is much faster than the default xarray version when not # using a MultiIndex. @@ -170,9 +167,7 @@ def dataset_from_dataframe_fast( if cannot_fix: break dupe_column_names = [f"- {i}" for i in dupe_column_names] - logger.error( - "DataFrame has non-unique columns\n" + "\n".join(dupe_column_names) - ) + logger.error("DataFrame has non-unique columns\n" + "\n".join(dupe_column_names)) if cannot_fix: raise ValueError("cannot convert DataFrame with non-unique columns") else: @@ -215,7 +210,7 @@ def from_table( index=None, ): """ - Convert a pyarrow.Table into an xarray.Dataset + Convert a pyarrow.Table into an xarray.Dataset. 
Parameters ---------- @@ -238,13 +233,9 @@ def from_table( index = pd.RangeIndex(len(tbl), name=index_name) else: if len(index) != len(tbl): - raise ValueError( - f"length of index ({len(index)}) does not match length of table ({len(tbl)})" - ) + raise ValueError(f"length of index ({len(index)}) does not match length of table ({len(tbl)})") if isinstance(index, pd.MultiIndex) and not index.is_unique: - raise ValueError( - "cannot attach a non-unique MultiIndex and convert into xarray" - ) + raise ValueError("cannot attach a non-unique MultiIndex and convert into xarray") arrays = [] metadata = {} for n in range(len(tbl.column_names)): @@ -262,10 +253,7 @@ def from_table( arrays.append((tbl.column_names[n], np.asarray(c))) result = xr.Dataset() if isinstance(index, pd.MultiIndex): - dims = tuple( - name if name is not None else "level_%i" % n - for n, name in enumerate(index.names) - ) + dims = tuple(name if name is not None else "level_%i" % n for n, name in enumerate(index.names)) for dim, lev in zip(dims, index.levels): result[dim] = (dim, lev) else: @@ -320,7 +308,6 @@ def from_omx( ------- Dataset """ - # handle both larch.OMX and openmatrix.open_file versions if "lar" in type(omx).__module__: omx_data = omx.data @@ -380,10 +367,7 @@ def from_omx( raise KeyError(f"{i} not found in OMX lookups") indexes = indexes_ if indexes is not None: - d["coords"] = { - index_name: {"dims": index_name, "data": index} - for index_name, index in indexes.items() - } + d["coords"] = {index_name: {"dims": index_name, "data": index} for index_name, index in indexes.items()} return xr.Dataset.from_dict(d) @@ -474,9 +458,7 @@ def from_omx_3d( elif indexes in set(omx_lookup._v_children): ranger = None else: - raise NotImplementedError( - "only one-based, zero-based, and named indexes are implemented" - ) + raise NotImplementedError("only one-based, zero-based, and named indexes are implemented") if ranger is not None: r1 = ranger(n1) r2 = ranger(n2) @@ -496,9 +478,7 @@ def from_omx_3d( base_k, time_k = k.split(time_period_sep, 1) if base_k not in pending_3d: pending_3d[base_k] = [None] * len(time_periods) - pending_3d[base_k][time_periods_map[time_k]] = dask.array.from_array( - omx_data[omx_data_map[k]][k] - ) + pending_3d[base_k][time_periods_map[time_k]] = dask.array.from_array(omx_data[omx_data_map[k]][k]) else: content[k] = xr.DataArray( dask.array.from_array(omx_data[omx_data_map[k]][k]), @@ -517,9 +497,7 @@ def from_omx_3d( break if prototype is None: raise ValueError("no prototype") - darrs_ = [ - (i if i is not None else dask.array.zeros_like(prototype)) for i in darrs - ] + darrs_ = [(i if i is not None else dask.array.zeros_like(prototype)) for i in darrs] content[base_k] = xr.DataArray( dask.array.stack(darrs_, axis=-1), dims=index_names, @@ -569,10 +547,7 @@ def from_amx( elif indexes.get(i) == "0": indexes[i] = zero_based(amx.shape[n]) if indexes is not None: - d["coords"] = { - index_name: {"dims": index_name, "data": index} - for index_name, index in indexes.items() - } + d["coords"] = {index_name: {"dims": index_name, "data": index} for index_name, index in indexes.items()} return xr.Dataset.from_dict(d) @@ -695,9 +670,7 @@ def is_dict_like(value: Any) -> bool: @xr.register_dataset_accessor("single_dim") class _SingleDim: - """ - Convenience accessor for single-dimension datasets. 
- """ + """Convenience accessor for single-dimension datasets.""" __slots__ = ("dataset", "dim_name") @@ -724,10 +697,7 @@ def _to_pydict(self): data = [] for k in columns: a = self.dataset._variables[k] - if ( - "digital_encoding" in a.attrs - and "dictionary" in a.attrs["digital_encoding"] - ): + if "digital_encoding" in a.attrs and "dictionary" in a.attrs["digital_encoding"]: de = a.attrs["digital_encoding"] data.append( pd.Categorical.from_codes( @@ -745,10 +715,7 @@ def to_pyarrow(self) -> pa.Table: data = [] for k in columns: a = self.dataset._variables[k] - if ( - "digital_encoding" in a.attrs - and "dictionary" in a.attrs["digital_encoding"] - ): + if "digital_encoding" in a.attrs and "dictionary" in a.attrs["digital_encoding"]: de = a.attrs["digital_encoding"] data.append( pa.DictionaryArray.from_arrays( @@ -839,9 +806,7 @@ def eval( @xr.register_dataarray_accessor("single_dim") class _SingleDimArray: - """ - Convenience accessor for single-dimension datasets. - """ + """Convenience accessor for single-dimension datasets.""" __slots__ = ("dataarray", "dim_name") @@ -892,9 +857,7 @@ def to_pandas(self) -> pd.Series: def to_pyarrow(self): if self.dataarray.cat.is_categorical(): - return pa.DictionaryArray.from_arrays( - self.dataarray.data, self.dataarray.cat.categories - ) + return pa.DictionaryArray.from_arrays(self.dataarray.data, self.dataarray.cat.categories) else: return pa.array(self.dataarray.data) @@ -924,10 +887,7 @@ def __getitem__(self, key: Mapping[Hashable, Any]) -> Dataset: dim_name = self.dataset.dims.__iter__().__next__() key = {dim_name: key} else: - raise TypeError( - "can only lookup dictionaries from Dataset.iloc, " - "unless there is only one dimension" - ) + raise TypeError("can only lookup dictionaries from Dataset.iloc, " "unless there is only one dimension") return self.dataset.isel(key) @@ -939,9 +899,7 @@ def rename_or_ignore(self, dims_dict=None, **dims_kwargs): from xarray.core.utils import either_dict_or_kwargs dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "rename_dims_and_coords") - dims_dict = { - k: v for (k, v) in dims_dict.items() if (k in self.dims or k in self._variables) - } + dims_dict = {k: v for (k, v) in dims_dict.items() if (k in self.dims or k in self._variables)} return self.rename(dims_dict) @@ -1033,13 +991,7 @@ def to_zarr_zip(self, *args, **kwargs): def _to_ast_literal(x): if isinstance(x, dict): - return ( - "{" - + ", ".join( - f"{_to_ast_literal(k)}: {_to_ast_literal(v)}" for k, v in x.items() - ) - + "}" - ) + return "{" + ", ".join(f"{_to_ast_literal(k)}: {_to_ast_literal(v)}" for k, v in x.items()) + "}" elif isinstance(x, list): return "[" + ", ".join(_to_ast_literal(i) for i in x) + "]" elif isinstance(x, tuple): @@ -1194,7 +1146,7 @@ def to_table(self): from .relationships import sparse_array_type def to_numpy(var): - """Coerces wrapped data to numpy and returns a numpy.ndarray""" + """Coerces wrapped data to numpy and returns a numpy.ndarray.""" data = var.data if hasattr(data, "chunks"): data = data.compute() @@ -1218,7 +1170,7 @@ def to_numpy(var): @register_dataset_method def select_and_rename(self, name_dict=None, **names): """ - Select and rename variables from this Dataset + Select and rename variables from this Dataset. 
Parameters ---------- diff --git a/sharrow/datastore.py b/sharrow/datastore.py index 4be79df..bb3379b 100644 --- a/sharrow/datastore.py +++ b/sharrow/datastore.py @@ -19,7 +19,7 @@ def timestamp(): class ReadOnlyError(ValueError): - """This object is read-only.""" + """Object is read-only.""" def _read_parquet(filename, index_col=None) -> xr.Dataset: @@ -128,9 +128,7 @@ def _update_dataset( if k in data.coords: continue assert v.name == k - partial_update = self._update_dataarray( - name, v, last_checkpoint, partial_update=partial_update - ) + partial_update = self._update_dataarray(name, v, last_checkpoint, partial_update=partial_update) for k, v in data.coords.items(): assert v.name == k partial_update = self._update_dataarray( @@ -158,12 +156,8 @@ def _update_dataarray( {data.name: data.assign_attrs(last_checkpoint=last_checkpoint)} ) else: - updated_dataset = base_data.assign( - {data.name: data.assign_attrs(last_checkpoint=last_checkpoint)} - ) - self._tree = self._tree.replace_datasets( - {name: updated_dataset}, redigitize=self._keep_digitized - ) + updated_dataset = base_data.assign({data.name: data.assign_attrs(last_checkpoint=last_checkpoint)}) + self._tree = self._tree.replace_datasets({name: updated_dataset}, redigitize=self._keep_digitized) return updated_dataset else: raise TypeError(type(data)) @@ -266,9 +260,7 @@ def _zarr_subdir(self, table_name, checkpoint_name): return self.directory.joinpath(table_name, checkpoint_name).with_suffix(".zarr") def _parquet_name(self, table_name, checkpoint_name): - return self.directory.joinpath(table_name, checkpoint_name).with_suffix( - ".parquet" - ) + return self.directory.joinpath(table_name, checkpoint_name).with_suffix(".parquet") def make_checkpoint(self, checkpoint_name: str, overwrite: bool = True): """ @@ -302,9 +294,7 @@ def make_checkpoint(self, checkpoint_name: str, overwrite: bool = True): os.unlink(target) target.parent.mkdir(parents=True, exist_ok=True) table_data.single_dim.to_parquet(str(target)) - elif self._storage_format == "zarr" or ( - self._storage_format == "parquet" and len(table_data.dims) > 1 - ): + elif self._storage_format == "zarr" or (self._storage_format == "parquet" and len(table_data.dims) > 1): # zarr is used if ndim > 1 target = self._zarr_subdir(table_name, checkpoint_name) if overwrite and target.is_dir(): @@ -314,9 +304,7 @@ def make_checkpoint(self, checkpoint_name: str, overwrite: bool = True): elif self._storage_format == "hdf5": raise NotImplementedError else: - raise ValueError( - f"cannot write with storage format {self._storage_format!r}" - ) + raise ValueError(f"cannot write with storage format {self._storage_format!r}") self.update(table_name, table_data, last_checkpoint=checkpoint_name) for table_name, table_data in self._tree.subspaces_iter(): inventory = {"data_vars": {}, "coords": {}} @@ -344,9 +332,7 @@ def make_checkpoint(self, checkpoint_name: str, overwrite: bool = True): def _write_checkpoint(self, name, checkpoint): if self._mode == "r": raise ReadOnlyError - checkpoint_metadata_target = self.directory.joinpath( - self.checkpoint_subdir, f"{name}.yaml" - ) + checkpoint_metadata_target = self.directory.joinpath(self.checkpoint_subdir, f"{name}.yaml") if checkpoint_metadata_target.exists(): n = 1 while checkpoint_metadata_target.with_suffix(f".{n}.yaml").exists(): @@ -377,7 +363,7 @@ def _write_metadata(self): def read_metadata(self, checkpoints=None): """ - Read storage metadata + Read storage metadata. 
Parameters ---------- @@ -400,9 +386,7 @@ def read_metadata(self, checkpoints=None): else: checkpoints = [checkpoints] for c in checkpoints: - with open( - self.directory.joinpath(self.checkpoint_subdir, f"{c}.yaml") - ) as f: + with open(self.directory.joinpath(self.checkpoint_subdir, f"{c}.yaml")) as f: self._checkpoints[c] = yaml.safe_load(f) def restore_checkpoint(self, checkpoint_name: str): @@ -430,9 +414,7 @@ def restore_checkpoint(self, checkpoint_name: str): opened_targets[target] = from_zarr_with_attr(target) else: # zarr not found, try parquet - target2 = self._parquet_name( - table_name, coord_def["last_checkpoint"] - ) + target2 = self._parquet_name(table_name, coord_def["last_checkpoint"]) if target2.exists(): if target not in opened_targets: opened_targets[target] = _read_parquet(target2, index_name) @@ -478,5 +460,5 @@ def digitize_relationships(self, redigitize=True): @property def relationships_are_digitized(self) -> bool: - """bool : Whether all relationships are digital (by position).""" + """Bool : Whether all relationships are digital (by position).""" return self._tree.relationships_are_digitized diff --git a/sharrow/digital_encoding.py b/sharrow/digital_encoding.py index e260022..a62fdef 100644 --- a/sharrow/digital_encoding.py +++ b/sharrow/digital_encoding.py @@ -101,8 +101,7 @@ def array_decode(x, digital_encoding=None, aux_data=None): if offset_source: if aux_data is None: raise ValueDecodeError( - "cannot independently decode multivalue DataArray, " - "provide aux_data or decode from dataset" + "cannot independently decode multivalue DataArray, " "provide aux_data or decode from dataset" ) result = aux_data[offset_source].copy() result.data = x.to_numpy()[result.data] @@ -167,7 +166,7 @@ def digitize_by_dictionary(arr, bitwidth=8): bin_edges = (bins[1:] - bins[:-1]) / 2 + bins[:-1] except TypeError: # bins are not numeric - bin_map = {x:n for n,x in enumerate(bins)} + bin_map = {x: n for n, x in enumerate(bins)} u, inv = np.unique(arr.data, return_inverse=True) result.data = np.array([bin_map.get(x) for x in u])[inv].reshape(arr.shape) result.attrs["digital_encoding"] = { @@ -334,7 +333,7 @@ def multivalue_digitize_by_dictionary(ds, encode_vars=None, encoding_name=None): Returns ------- - + Dataset """ logger = logging.getLogger("sharrow") if not isinstance(encoding_name, str): @@ -349,35 +348,23 @@ def multivalue_digitize_by_dictionary(ds, encode_vars=None, encoding_name=None): # check each name in encode_vars to make sure it's not already encoded # if you want to re-encode first decode - encode_vars = [ - v - for v in encode_vars - if "offset_source" not in ds[v].attrs.get("digital_encoding", {}) - ] + encode_vars = [v for v in encode_vars if "offset_source" not in ds[v].attrs.get("digital_encoding", {})] if len(encode_vars) == 0: return ds encode_var_dims = ds[encode_vars[0]].dims for v in encode_vars[1:]: - assert ( - encode_var_dims == ds[v].dims - ), f"dims must match, {encode_var_dims} != {ds[v].dims}" + assert encode_var_dims == ds[v].dims, f"dims must match, {encode_var_dims} != {ds[v].dims}" logger.info("assembling data stack") - conjoined = np.stack( - [array_decode(ds[v].compute(), aux_data=ds) for v in encode_vars], axis=-1 - ) + conjoined = np.stack([array_decode(ds[v].compute(), aux_data=ds) for v in encode_vars], axis=-1) logger.info("constructing stack view") baseshape = conjoined.shape[:-1] conjoined = conjoined.reshape([-1, conjoined.shape[-1]]) - voidview = np.ascontiguousarray(conjoined).view( - np.dtype((np.void, conjoined.dtype.itemsize * 
conjoined.shape[1])) - ) + voidview = np.ascontiguousarray(conjoined).view(np.dtype((np.void, conjoined.dtype.itemsize * conjoined.shape[1]))) logger.info("finding unique value combinations") unique_values, pointers = np.unique(voidview, return_inverse=True) pointers = pointers.reshape(baseshape) - unique_values = unique_values.view(np.dtype(conjoined.dtype)).reshape( - [-1, len(encode_vars)] - ) + unique_values = unique_values.view(np.dtype(conjoined.dtype)).reshape([-1, len(encode_vars)]) logger.info("downsampling offsets") if unique_values.shape[0] < 1 << 8: pointers = pointers.astype(np.uint8) @@ -410,7 +397,6 @@ def multivalue_digitize_by_dictionary(ds, encode_vars=None, encoding_name=None): bytes_saved = original_footprint - encoded_footprint savings_ratio = bytes_saved / original_footprint logger.info( - f"multivalue_digitize_by_dictionary {encoding_name} " - f"saved {si_units(bytes_saved)} {savings_ratio:.1%}" + f"multivalue_digitize_by_dictionary {encoding_name} " f"saved {si_units(bytes_saved)} {savings_ratio:.1%}" ) return out diff --git a/sharrow/flows.py b/sharrow/flows.py index 6345b6d..e98c5e2 100644 --- a/sharrow/flows.py +++ b/sharrow/flows.py @@ -80,9 +80,7 @@ def clean(s): cleaned = re.sub(r"\W|^(?=\d)", "_", s) if cleaned != s or len(cleaned) > 120: # digest size 15 creates a 24 character base32 string - h = base64.b32encode( - hashlib.blake2b(s.encode(), digest_size=15).digest() - ).decode() + h = base64.b32encode(hashlib.blake2b(s.encode(), digest_size=15).digest()).decode() cleaned = f"{cleaned[:90]}_{h}" return cleaned @@ -155,29 +153,16 @@ def visit_Call(self, node): if len(node.args) == 1: if isinstance(node.args[0], ast.Constant): if len(node.keywords) == 0: - self.required_get_tokens.add( - (node.func.value.id, node.args[0].value) - ) - elif ( - len(node.keywords) == 1 - and node.keywords[0].arg == "default" - ): - self.optional_get_tokens.add( - (node.func.value.id, node.args[0].value) - ) + self.required_get_tokens.add((node.func.value.id, node.args[0].value)) + elif len(node.keywords) == 1 and node.keywords[0].arg == "default": + self.optional_get_tokens.add((node.func.value.id, node.args[0].value)) else: - raise ValueError( - f"{node.func.value.id}.get with unexpected keyword arguments" - ) + raise ValueError(f"{node.func.value.id}.get with unexpected keyword arguments") if len(node.args) == 2: if isinstance(node.args[0], ast.Constant): - self.optional_get_tokens.add( - (node.func.value.id, node.args[0].value) - ) + self.optional_get_tokens.add((node.func.value.id, node.args[0].value)) if len(node.args) > 2: - raise ValueError( - f"{node.func.value.id}.get with more than 2 positional arguments" - ) + raise ValueError(f"{node.func.value.id}.get with more than 2 positional arguments") self.generic_visit(node) def check(self, node): @@ -1022,7 +1007,7 @@ def __initialize_1( bool_wrapping=False, ): """ - Initialize up to the flow_hash + Initialize up to the flow_hash. See main docstring for arguments. 
""" @@ -1078,17 +1063,12 @@ def __initialize_1( subspace_names.add(k) for k in self.tree.subspace_fallbacks: subspace_names.add(k) - optional_get_tokens = ExtractOptionalGetTokens(from_names=subspace_names).check( - defs.values() - ) + optional_get_tokens = ExtractOptionalGetTokens(from_names=subspace_names).check(defs.values()) self._optional_get_tokens = [] if optional_get_tokens: for _spacename, _varname in optional_get_tokens: found = False - if ( - _spacename in self.tree.subspaces - and _varname in self.tree.subspaces[_spacename] - ): + if _spacename in self.tree.subspaces and _varname in self.tree.subspaces[_spacename]: self._optional_get_tokens.append(f"__{_spacename}__{_varname}:True") found = True elif _spacename in self.tree.subspace_fallbacks: @@ -1100,9 +1080,7 @@ def __initialize_1( found = True break if not found: - self._optional_get_tokens.append( - f"__{_spacename}__{_varname}:False" - ) + self._optional_get_tokens.append(f"__{_spacename}__{_varname}:False") self._hashing_level = hashing_level if self._hashing_level > 1: @@ -1171,9 +1149,7 @@ def _flow_hash_push(x): parts = k.split("__") if len(parts) > 2: try: - digital_encoding = self.tree.subspaces[parts[1]][ - "__".join(parts[2:]) - ].attrs["digital_encoding"] + digital_encoding = self.tree.subspaces[parts[1]]["__".join(parts[2:])].attrs["digital_encoding"] except (AttributeError, KeyError) as err: pass print(f"$$$$/ndigital_encoding=ERR\n{err}\n\n\n$$$") @@ -1199,12 +1175,7 @@ def _flow_hash_push(x): self.flow_hash_audit = "]\n# [".join(flow_hash_audit) def _index_slots(self): - return { - i: n - for n, i in enumerate( - presorted(self.tree.sizes, self.dim_order, self.dim_exclude) - ) - } + return {i: n for n, i in enumerate(presorted(self.tree.sizes, self.dim_order, self.dim_exclude))} def init_sub_funcs( self, @@ -1216,12 +1187,7 @@ def init_sub_funcs( ): func_code = "" all_name_tokens = set() - index_slots = { - i: n - for n, i in enumerate( - presorted(self.tree.sizes, self.dim_order, self.dim_exclude) - ) - } + index_slots = {i: n for n, i in enumerate(presorted(self.tree.sizes, self.dim_order, self.dim_exclude))} self.arg_name_positions = index_slots candidate_names = self.tree.namespace_names() candidate_names |= set(f"__aux_var__{i}" for i in self.tree.aux_vars.keys()) @@ -1314,12 +1280,8 @@ def init_sub_funcs( else: raise # check if we can resolve this name on any other subspace - for other_spacename in self.tree.subspace_fallbacks.get( - topkey, [] - ): - dim_slots, digital_encodings, blenders = meta_data[ - other_spacename - ] + for other_spacename in self.tree.subspace_fallbacks.get(topkey, []): + dim_slots, digital_encodings, blenders = meta_data[other_spacename] try: expr = expression_for_numba( expr, @@ -1359,12 +1321,8 @@ def init_sub_funcs( # at least one variable was found in a get break # check if we can resolve this "get" on any other subspace - for other_spacename in self.tree.subspace_fallbacks.get( - topkey, [] - ): - dim_slots, digital_encodings, blenders = meta_data[ - other_spacename - ] + for other_spacename in self.tree.subspace_fallbacks.get(topkey, []): + dim_slots, digital_encodings, blenders = meta_data[other_spacename] try: expr = expression_for_numba( expr, @@ -1403,9 +1361,7 @@ def init_sub_funcs( actual_spacenames, ) in self.tree.subspace_fallbacks.items(): for actual_spacename in actual_spacenames: - dim_slots, digital_encodings, blenders = meta_data[ - actual_spacename - ] + dim_slots, digital_encodings, blenders = meta_data[actual_spacename] try: expr = expression_for_numba( 
expr, @@ -1448,10 +1404,7 @@ def init_sub_funcs( bool_wrapping=self.bool_wrapping, ) - aux_tokens = { - k: ast.parse(f"__aux_var__{k}", mode="eval").body - for k in self.tree.aux_vars.keys() - } + aux_tokens = {k: ast.parse(f"__aux_var__{k}", mode="eval").body for k in self.tree.aux_vars.keys()} # now handle aux vars expr = expression_for_numba( @@ -1514,6 +1467,7 @@ def __initialize_2( with_root_node_name=None, ): """ + Second step in initialization, only used if the flow is not cached. Parameters ---------- @@ -1535,7 +1489,6 @@ def __initialize_2( be sure to avoid name conflicts with other flow's in the same directory. """ - if self._hashing_level <= 1: func_code, all_name_tokens = self.init_sub_funcs( defs, @@ -1552,9 +1505,7 @@ def __initialize_2( parts = k.split("__") if len(parts) > 2: try: - digital_encoding = self.tree.subspaces[parts[1]][ - "__".join(parts[2:]) - ].attrs["digital_encoding"] + digital_encoding = self.tree.subspaces[parts[1]]["__".join(parts[2:])].attrs["digital_encoding"] except (AttributeError, KeyError): pass else: @@ -1632,9 +1583,7 @@ def __initialize_2( if isinstance(x_var, (float, int, str)): buffer.write(f"{x_name} = {x_var!r}\n") else: - buffer.write( - f"{x_name} = pickle.loads({repr(pickle.dumps(x_var))})\n" - ) + buffer.write(f"{x_name} = pickle.loads({repr(pickle.dumps(x_var))})\n") dependencies.add("import pickle") with io.StringIO() as x_code: x_code.write("\n") @@ -1651,9 +1600,7 @@ def __initialize_2( import pickle buffer = io.StringIO() for x_name, x_dict in self.encoding_dictionaries.items(): - buffer.write( - f"__encoding_dict{x_name} = pickle.loads({repr(pickle.dumps(x_dict))})\n" - ) + buffer.write(f"__encoding_dict{x_name} = pickle.loads({repr(pickle.dumps(x_dict))})\n") with io.StringIO() as x_code: x_code.write("\n") x_code.write(buffer.getvalue()) @@ -1662,9 +1609,7 @@ def __initialize_2( # write the master module for this flow os.makedirs(os.path.join(self.cache_dir, self.name), exist_ok=True) - with rewrite( - os.path.join(self.cache_dir, self.name, "__init__.py"), "wt" - ) as f_code: + with rewrite(os.path.join(self.cache_dir, self.name, "__init__.py"), "wt") as f_code: f_code.write( textwrap.dedent( f""" @@ -1738,9 +1683,7 @@ def __initialize_2( elif n_root_dims == 2: js = "j0, j1" else: - raise NotImplementedError( - f"n_root_dims only supported up to 2, not {n_root_dims}" - ) + raise NotImplementedError(f"n_root_dims only supported up to 2, not {n_root_dims}") meta_code = [] meta_code_dot = [] @@ -1757,48 +1700,24 @@ def __initialize_2( meta_code_dot.append( f"intermediate[{n}] = ({clean(k)}({f_args_j}intermediate, {f_name_tokens})).item()" ) - meta_code_stack = textwrap.indent( - "\n".join(meta_code), " " * 12 - ).lstrip() - meta_code_stack_dot = textwrap.indent( - "\n".join(meta_code_dot), " " * 12 - ).lstrip() + meta_code_stack = textwrap.indent("\n".join(meta_code), " " * 12).lstrip() + meta_code_stack_dot = textwrap.indent("\n".join(meta_code_dot), " " * 12).lstrip() len_self_raw_functions = len(self._raw_functions) - joined_namespace_names = "\n ".join( - f"{nn}," for nn in self._namespace_names - ) + joined_namespace_names = "\n ".join(f"{nn}," for nn in self._namespace_names) linefeed = "\n " if not meta_code_stack_dot: meta_code_stack_dot = "pass" if n_root_dims == 1: - meta_template = IRUNNER_1D_TEMPLATE.format(**locals()).format( - **locals() - ) - meta_template_dot = IDOTTER_1D_TEMPLATE.format( - **locals() - ).format(**locals()) - line_template = ILINER_1D_TEMPLATE.format(**locals()).format( - **locals() - ) - 
mnl_template = MNL_1D_TEMPLATE.format(**locals()).format( - **locals() - ) - nl_template = NL_1D_TEMPLATE.format(**locals()).format( - **locals() - ) + meta_template = IRUNNER_1D_TEMPLATE.format(**locals()).format(**locals()) + meta_template_dot = IDOTTER_1D_TEMPLATE.format(**locals()).format(**locals()) + line_template = ILINER_1D_TEMPLATE.format(**locals()).format(**locals()) + mnl_template = MNL_1D_TEMPLATE.format(**locals()).format(**locals()) + nl_template = NL_1D_TEMPLATE.format(**locals()).format(**locals()) elif n_root_dims == 2: - meta_template = IRUNNER_2D_TEMPLATE.format(**locals()).format( - **locals() - ) - meta_template_dot = IDOTTER_2D_TEMPLATE.format( - **locals() - ).format(**locals()) - line_template = ILINER_2D_TEMPLATE.format(**locals()).format( - **locals() - ) - mnl_template = MNL_2D_TEMPLATE.format(**locals()).format( - **locals() - ) + meta_template = IRUNNER_2D_TEMPLATE.format(**locals()).format(**locals()) + meta_template_dot = IDOTTER_2D_TEMPLATE.format(**locals()).format(**locals()) + line_template = ILINER_2D_TEMPLATE.format(**locals()).format(**locals()) + mnl_template = MNL_2D_TEMPLATE.format(**locals()).format(**locals()) nl_template = "" else: raise ValueError(f"invalid n_root_dims {n_root_dims}") @@ -1871,15 +1790,11 @@ def __initialize_2( def load_raw(self, rg, args, runner=None, dtype=None, dot=None): assert isinstance(rg, DataTree) with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", category=nb.NumbaExperimentalFeatureWarning - ) + warnings.filterwarnings("ignore", category=nb.NumbaExperimentalFeatureWarning) assembled_args = [args.get(k) for k in self.arg_name_positions.keys()] for aa in assembled_args: if aa.dtype.kind != "i": - warnings.warn( - "position arguments are not all integers", stacklevel=2 - ) + warnings.warn("position arguments are not all integers", stacklevel=2) try: if runner is None: if dot is None: @@ -1938,7 +1853,7 @@ def load_raw(self, rg, args, runner=None, dtype=None, dot=None): # raise the inner key error which is more helpful context = getattr(err, "__context__", None) if context: - raise context + raise context from None else: raise err @@ -1957,9 +1872,7 @@ def _iload_raw( ): assert isinstance(rg, DataTree) with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", category=nb.NumbaExperimentalFeatureWarning - ) + warnings.filterwarnings("ignore", category=nb.NumbaExperimentalFeatureWarning) try: known_arg_names = { "dtype", @@ -1995,11 +1908,7 @@ def _iload_raw( elif dot is None: runner_ = self._irunner known_arg_names.update({"mask"}) - if ( - mask is not None - and dtype is not None - and not np.issubdtype(dtype, np.floating) - ): + if mask is not None and dtype is not None and not np.issubdtype(dtype, np.floating): raise TypeError("cannot use mask unless dtype is float") else: runner_ = self._idotter @@ -2050,13 +1959,8 @@ def _iload_raw( if self.with_root_node_name is None: tree_root_dims = rg.root_dataset.sizes else: - tree_root_dims = rg._graph.nodes[self.with_root_node_name][ - "dataset" - ].sizes - argshape = [ - tree_root_dims[i] - for i in presorted(tree_root_dims, self.dim_order, self.dim_exclude) - ] + tree_root_dims = rg._graph.nodes[self.with_root_node_name]["dataset"].sizes + argshape = [tree_root_dims[i] for i in presorted(tree_root_dims, self.dim_order, self.dim_exclude)] if mnl is not None: if nesting is not None: n_alts = nesting["n_alts"] @@ -2069,9 +1973,7 @@ def _iload_raw( elif n_alts < 32768: kwargs["choice_dtype"] = np.int16 if logger.isEnabledFor(logging.DEBUG): - 
logger.debug( - "========= PASSING ARGUMENT TO SHARROW LOAD ==========" - ) + logger.debug("========= PASSING ARGUMENT TO SHARROW LOAD ==========") logger.debug(f"{argshape=}") for _name, _info in zip(_arguments_names, arguments): try: @@ -2091,14 +1993,10 @@ def _iload_raw( logger.debug(f"KWARG {_name}: {alt_repr}") else: logger.debug(f"KWARG {_name}: type={type(_info)}") - logger.debug( - "========= ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ==========" - ) + logger.debug("========= ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ==========") result = runner_(np.asarray(argshape), *arguments, **kwargs) if compile_watch: - self.check_cache_misses( - runner_, log_details=compile_watch != "simple" - ) + self.check_cache_misses(runner_, log_details=compile_watch != "simple") return result except nb.TypingError as err: _raw_functions = getattr(self, "_raw_functions", {}) @@ -2155,18 +2053,11 @@ def check_cache_misses(self, *funcs, fresh=True, log_details=True): for k, v in cache_misses.items(): if v > known_cache_misses.get(k, 0): if log_details: - warning_text = "\n".join( - f" - {argname}: {sig}" - for (sig, argname) in zip(k, named_args) - ) + warning_text = "\n".join(f" - {argname}: {sig}" for (sig, argname) in zip(k, named_args)) warning_text = f"\n{runner_name}(\n{warning_text}\n)" else: warning_text = "" - timers = ( - f.overloads[k] - .metadata["timers"] - .get("compiler_lock", "N/A") - ) + timers = f.overloads[k].metadata["timers"].get("compiler_lock", "N/A") if isinstance(timers, float): if timers < 1e-3: timers = f"{timers/1e-6:.0f} µs" @@ -2174,13 +2065,8 @@ def check_cache_misses(self, *funcs, fresh=True, log_details=True): timers = f"{timers/1e-3:.1f} ms" else: timers = f"{timers:.2f} s" - logger.warning( - f"cache miss in {self.flow_hash}{warning_text}\n" - f"Compile Time: {timers}" - ) - warnings.warn( - f"{self.flow_hash}", CacheMissWarning, stacklevel=1 - ) + logger.warning(f"cache miss in {self.flow_hash}{warning_text}\n" f"Compile Time: {timers}") + warnings.warn(f"{self.flow_hash}", CacheMissWarning, stacklevel=1) self.compiled_recently = True self._known_cache_misses[runner_name][k] = v return self.compiled_recently @@ -2265,9 +2151,7 @@ def _load( logit_draws = np.zeros(source.shape + (0,), dtype=dtype) if self.with_root_node_name is None: - use_dims = list( - presorted(source.root_dataset.sizes, self.dim_order, self.dim_exclude) - ) + use_dims = list(presorted(source.root_dataset.sizes, self.dim_order, self.dim_exclude)) else: use_dims = list( presorted( @@ -2324,11 +2208,7 @@ def _load( ): result_dims = use_dims[:-1] result_squeeze = (-1,) - elif ( - dot.ndim == 2 - and dot.shape[1] == 1 - and logit_draws.ndim == len(use_dims) - ): + elif dot.ndim == 2 and dot.shape[1] == 1 and logit_draws.ndim == len(use_dims): result_dims = use_dims[:-1] + logit_draws_trailing_dim elif dot.ndim == 2 and logit_draws.ndim == len(use_dims): result_dims = use_dims[:-1] + dot_trailing_dim @@ -2351,11 +2231,7 @@ def _load( and self._logit_ndims == 1 ): result_dims = use_dims + logit_draws_trailing_dim - elif ( - dot.ndim == 2 - and logit_draws.ndim == len(use_dims) + 1 - and logit_draws.shape[-1] == 0 - ): + elif dot.ndim == 2 and logit_draws.ndim == len(use_dims) + 1 and logit_draws.shape[-1] == 0: # logsums only result_dims = use_dims result_squeeze = (-1,) @@ -2433,30 +2309,20 @@ def _load( raise RuntimeError("please digitize") if as_dataframe: index = getattr(source.root_dataset, "index", None) - result = pd.DataFrame( - result, index=index, columns=list(self._raw_functions.keys()) - ) + result = 
pd.DataFrame(result, index=index, columns=list(self._raw_functions.keys())) elif as_table: - result = Table( - {k: result[:, n] for n, k in enumerate(self._raw_functions.keys())} - ) + result = Table({k: result[:, n] for n, k in enumerate(self._raw_functions.keys())}) elif as_dataarray: if result_squeeze: result = squeeze(result, result_squeeze) result_p = squeeze(result_p, result_squeeze) pick_count = squeeze(pick_count, result_squeeze) if self.with_root_node_name is None: - result_coords = { - k: v - for k, v in source.root_dataset.coords.items() - if k in result_dims - } + result_coords = {k: v for k, v in source.root_dataset.coords.items() if k in result_dims} else: result_coords = { k: v - for k, v in source._graph.nodes[self.with_root_node_name][ - "dataset" - ].coords.items() + for k, v in source._graph.nodes[self.with_root_node_name]["dataset"].coords.items() if k in result_dims } if result is not None: @@ -2483,11 +2349,7 @@ def _load( out_logsum = xr.DataArray( out_logsum, dims=result_dims[: out_logsum.ndim], - coords={ - k: v - for k, v in source.root_dataset.coords.items() - if k in result_dims[: out_logsum.ndim] - }, + coords={k: v for k, v in source.root_dataset.coords.items() if k in result_dims[: out_logsum.ndim]}, ) else: @@ -2551,9 +2413,7 @@ def load(self, source=None, dtype=None, compile_watch=False, mask=None): ------- numpy.array """ - return self._load( - source=source, dtype=dtype, compile_watch=compile_watch, mask=mask - ) + return self._load(source=source, dtype=dtype, compile_watch=compile_watch, mask=mask) def load_dataframe(self, source=None, dtype=None, compile_watch=False, mask=None): """ @@ -2768,8 +2628,7 @@ def function_names(self, x): self._raw_functions[name] = (None, None, set(), []) def _spill(self, all_name_tokens=None): - cmds = [self.tree._spill(all_name_tokens)] - cmds.append("\n") + cmds = ["\n"] cmds.append(f"output_name_positions = {self.output_name_positions!r}") cmds.append(f"function_names = {self.function_names!r}") return "\n".join(cmds) @@ -2849,10 +2708,7 @@ def init_streamer(self, source=None, dtype=None): selected_args = tuple(general_mapping[k] for k in named_args) len_self_raw_functions = len(self._raw_functions) tree_root_dims = source.root_dataset.sizes - argshape = tuple( - tree_root_dims[i] - for i in presorted(tree_root_dims, self.dim_order, self.dim_exclude) - ) + argshape = tuple(tree_root_dims[i] for i in presorted(tree_root_dims, self.dim_order, self.dim_exclude)) if len(argshape) == 1: linemaker = self._linemaker @@ -2884,8 +2740,6 @@ def streamer(c, out=None): return result else: - raise NotImplementedError( - f"root tree with {len(argshape)} dims {argshape=}" - ) + raise NotImplementedError(f"root tree with {len(argshape)} dims {argshape=}") return streamer diff --git a/sharrow/relationships.py b/sharrow/relationships.py index ab025d6..e2989e5 100644 --- a/sharrow/relationships.py +++ b/sharrow/relationships.py @@ -154,7 +154,7 @@ def xgather(source, positions, indexes): def _dataarray_to_numpy(self) -> np.ndarray: - """Coerces wrapped data to numpy and returns a numpy.ndarray""" + """Coerces wrapped data to numpy and returns a numpy.ndarray.""" data = self.data if isinstance(data, dask_array_type): data = data.compute() @@ -165,9 +165,7 @@ def _dataarray_to_numpy(self) -> np.ndarray: class Relationship: - """ - Defines a linkage between datasets in a `DataTree`. 
- """ + """Defines a linkage between datasets in a `DataTree`.""" def __init__( self, @@ -391,17 +389,13 @@ def shape(self): from .flows import presorted dim_order = presorted(self.root_dataset.dims, self.dim_order) - return tuple( - self.root_dataset.dims[i] for i in dim_order if i not in self.dim_exclude - ) + return tuple(self.root_dataset.dims[i] for i in dim_order if i not in self.dim_exclude) @property def root_dims(self): from .flows import presorted - return tuple( - presorted(self.root_dataset.dims, self.dim_order, self.dim_exclude) - ) + return tuple(presorted(self.root_dataset.dims, self.dim_order, self.dim_exclude)) def __shallow_copy_extras(self): return dict( @@ -428,9 +422,7 @@ def __repr__(self): if len(self._graph.edges): s += "\n relationships:" for e in self._graph.edges: - s += f"\n - {self._get_relationship(e)!r}".replace( - "") + s += f"\n - {self._get_relationship(e)!r}".replace("") else: s += "\n relationships: none" return s @@ -453,9 +445,7 @@ def _hash_features(self): h.append("datasets:none") if len(self._graph.edges): for e in self._graph.edges: - r = f"relationship:{self._get_relationship(e)!r}".replace( - "") + r = f"relationship:{self._get_relationship(e)!r}".replace("") h.append(r) else: h.append("relationships:none") @@ -478,9 +468,7 @@ def root_node_name(self, name): self._root_node_name = name return if not isinstance(name, str): - raise TypeError( - f"root_node_name must be one of [str, None, False] not {type(name)}" - ) + raise TypeError(f"root_node_name must be one of [str, None, False] not {type(name)}") if name not in self._graph.nodes: raise KeyError(name) self._root_node_name = name @@ -543,7 +531,7 @@ def get_relationship(self, parent, child): return Relationship(parent_data=parent, child_data=child, **attrs) def list_relationships(self) -> list[Relationship]: - """list : List all relationships defined in this tree.""" + """List : List all relationships defined in this tree.""" result = [] for e in self._graph.edges: result.append(self._get_relationship(e)) @@ -618,9 +606,7 @@ def root_dataset(self, x): self._graph.nodes[self.root_node_name]["dataset"] = x def _get_relationship(self, edge): - return Relationship( - parent_data=edge[0], child_data=edge[1], **self._graph.edges[edge] - ) + return Relationship(parent_data=edge[0], child_data=edge[1], **self._graph.edges[edge]) def __getitem__(self, item): return self.get(item) @@ -649,19 +635,12 @@ def get(self, item, default=None, broadcast=True, coords=True): if isinstance(item, (list, tuple)): from .dataset import Dataset - return Dataset( - { - k: self.get(k, default=default, broadcast=broadcast, coords=coords) - for k in item - } - ) + return Dataset({k: self.get(k, default=default, broadcast=broadcast, coords=coords) for k in item}) try: result = self._getitem(item, dim_names_from_top=True) except KeyError: try: - result = self._getitem( - item, include_blank_dims=True, dim_names_from_top=True - ) + result = self._getitem(item, include_blank_dims=True, dim_names_from_top=True) except KeyError: if default is None: raise @@ -742,9 +721,7 @@ def _getitem( return current_node if current_node == start_from: if by_dims: - return xr.DataArray( - pd.RangeIndex(dataset.dims[item]), dims=item - ) + return xr.DataArray(pd.RangeIndex(dataset.dims[item]), dims=item) else: return dataset[item] else: @@ -764,22 +741,15 @@ def _getitem( result = dataset[item] dims_in_result = set(result.dims) top_dim_names = {} - for path in nx.algorithms.simple_paths.all_simple_edge_paths( - self._graph, start_from, 
current_node - ): + for path in nx.algorithms.simple_paths.all_simple_edge_paths(self._graph, start_from, current_node): if dim_names_from_top: e = path[0] top_dim_name = self._graph.edges[e].get("parent_name") start_dataset = self._graph.nodes[start_from]["dataset"] # deconvert digitized dim names back to native dims - if ( - top_dim_name not in start_dataset.dims - and top_dim_name in start_dataset.variables - ): + if top_dim_name not in start_dataset.dims and top_dim_name in start_dataset.variables: if start_dataset.variables[top_dim_name].ndim == 1: - top_dim_name = start_dataset.variables[ - top_dim_name - ].dims[0] + top_dim_name = start_dataset.variables[top_dim_name].dims[0] else: top_dim_name = None path_dim = self._graph.edges[path[-1]].get("child_name") @@ -793,36 +763,18 @@ def _getitem( r_next = self._get_relationship(e_next) if t1 is None: t1 = self._graph.nodes[r.parent_data].get("dataset") - t2 = self._graph.nodes[r.child_data].get("dataset")[ - [r_next.parent_name] - ] + t2 = self._graph.nodes[r.child_data].get("dataset")[[r_next.parent_name]] if r.indexing == "label": - t1 = t2.sel( - { - r.child_name: _dataarray_to_numpy( - t1[r.parent_name] - ) - } - ) + t1 = t2.sel({r.child_name: _dataarray_to_numpy(t1[r.parent_name])}) else: # by position - t1 = t2.isel( - { - r.child_name: _dataarray_to_numpy( - t1[r.parent_name] - ) - } - ) + t1 = t2.isel({r.child_name: _dataarray_to_numpy(t1[r.parent_name])}) # final node in path e = path[-1] - r = Relationship( - parent_data=e[0], child_data=e[1], **self._graph.edges[e] - ) + r = Relationship(parent_data=e[0], child_data=e[1], **self._graph.edges[e]) if t1 is None: t1 = self._graph.nodes[r.parent_data].get("dataset") if r.indexing == "label": - _labels[r.child_name] = _dataarray_to_numpy( - t1[r.parent_name] - ) + _labels[r.child_name] = _dataarray_to_numpy(t1[r.parent_name]) else: # by position _idx = _dataarray_to_numpy(t1[r.parent_name]) if not np.issubdtype(_idx.dtype, np.integer): @@ -871,11 +823,7 @@ def get_expr(self, expression, engine="sharrow", allow_native=True): raise KeyError except (KeyError, IndexError): if engine == "sharrow": - result = ( - self.setup_flow({expression: expression}) - .load_dataarray() - .isel(expressions=0) - ) + result = self.setup_flow({expression: expression}).load_dataarray().isel(expressions=0) elif engine == "numexpr": from xarray import DataArray @@ -904,7 +852,7 @@ def subspaces_iter(self): def contains_subspace(self, key) -> bool: """ - Is this named Dataset in this tree's subspaces + Is this named Dataset in this tree's subspaces. Parameters ---------- @@ -918,7 +866,7 @@ def contains_subspace(self, key) -> bool: def get_subspace(self, key, default_empty=False) -> xr.Dataset: """ - Access named Dataset from this tree's subspaces + Access named Dataset from this tree's subspaces. Parameters ---------- @@ -954,17 +902,13 @@ def namespace_names(self): @property def dims(self): - """ - Mapping from dimension names to lengths across all dataset nodes. 
- """ + """Mapping from dimension names to lengths across all dataset nodes.""" dims = {} for _k, v in self.subspaces_iter(): for name, length in v.dims.items(): if name in dims: if dims[name] != length: - raise ValueError( - "inconsistent dimensions\n" + self.dims_detail() - ) + raise ValueError("inconsistent dimensions\n" + self.dims_detail()) else: dims[name] = length return xr.core.utils.Frozen(dims) @@ -1005,7 +949,6 @@ def drop_dims(self, dims, inplace=False, ignore_missing_dims=True): Returns self if dropping inplace, otherwise returns a copy with dimensions dropped. """ - if isinstance(dims, str): dims = [dims] if inplace: @@ -1039,7 +982,7 @@ def drop_dims(self, dims, inplace=False, ignore_missing_dims=True): while boot_queue: b = boot_queue.pop() booted.add(b) - for (up, dn, _n) in obj._graph.edges.keys(): + for up, dn, _n in obj._graph.edges.keys(): if up == b: boot_queue.add(dn) @@ -1063,11 +1006,7 @@ def get_indexes( return self._cached_indexes[(position_only, as_dict)] if not position_only: raise NotImplementedError - dims = [ - d - for d in self.dims - if d[-1:] != "_" or (d[-1:] == "_" and d[:-1] not in self.dims) - ] + dims = [d for d in self.dims if d[-1:] != "_" or (d[-1:] == "_" and d[:-1] not in self.dims)] if replacements is not None: obj = self.replace_datasets(replacements) else: @@ -1257,21 +1196,6 @@ def setup_flow( with_root_node_name=with_root_node_name, ) - def _spill(self, all_name_tokens=()): - """ - Write backup code for sharrow-lite. - - Parameters - ---------- - all_name_tokens - - Returns - ------- - - """ - cmds = [] - return "\n".join(cmds) - def get_named_array(self, mangled_name): if mangled_name[:2] != "__": raise KeyError(mangled_name) @@ -1311,7 +1235,6 @@ def digitize_relationships(self, inplace=False, redigitize=True): DataTree or None Only returns a copy if not digitizing in-place. """ - if inplace: obj = self else: @@ -1349,9 +1272,7 @@ def mapper_get(x, mapper=mapper): ) # candidate name for write back - r_parent_name_new = ( - f"{self._BY_OFFSET}{r.parent_name}_{r.child_data}_{r.child_name}" - ) + r_parent_name_new = f"{self._BY_OFFSET}{r.parent_name}_{r.child_data}_{r.child_name}" # it is common to have mirrored offsets in various dimensions. # we'd like to retain only the same data in memory once, so we'll @@ -1362,16 +1283,14 @@ def mapper_get(x, mapper=mapper): if p_dataset[k].equals(offsets): # we found a match, so we'll assign this name to # the match's memory storage instead of replicating it. 
- obj._graph.nodes[r.parent_data][ - "dataset" - ] = p_dataset.assign({r_parent_name_new: p_dataset[k]}) + obj._graph.nodes[r.parent_data]["dataset"] = p_dataset.assign( + {r_parent_name_new: p_dataset[k]} + ) # r_parent_name_new = k break else: # no existing offset arrays match, make this new one - obj._graph.nodes[r.parent_data]["dataset"] = p_dataset.assign( - {r_parent_name_new: offsets} - ) + obj._graph.nodes[r.parent_data]["dataset"] = p_dataset.assign({r_parent_name_new: offsets}) obj._graph.edges[e].update( dict( parent_name=r_parent_name_new, @@ -1385,16 +1304,14 @@ def mapper_get(x, mapper=mapper): @property def relationships_are_digitized(self): - """bool : Whether all relationships are digital (by position).""" + """Bool : Whether all relationships are digital (by position).""" for e in self._graph.edges: r = self._get_relationship(e) if r.indexing != "position": return False return True - def _arg_tokenizer( - self, spacename, spacearray, spacearrayname, exclude_dims=None, blends=None - ): + def _arg_tokenizer(self, spacename, spacearray, spacearrayname, exclude_dims=None, blends=None): if blends is None: blends = {} @@ -1408,10 +1325,7 @@ def _arg_tokenizer( else: from_dims = spacearray.dims return ( - tuple( - ast.parse(f"_arg{root_dims.index(dim):02}", mode="eval").body - for dim in from_dims - ), + tuple(ast.parse(f"_arg{root_dims.index(dim):02}", mode="eval").body for dim in from_dims), blends, ) @@ -1421,9 +1335,7 @@ def _arg_tokenizer( spacearray_ = spacearray from_dims = spacearray_.dims - offset_source = spacearray_.attrs.get("digital_encoding", {}).get( - "offset_source", None - ) + offset_source = spacearray_.attrs.get("digital_encoding", {}).get("offset_source", None) if offset_source is not None: from_dims = self._graph.nodes[spacename]["dataset"][offset_source].dims @@ -1436,9 +1348,7 @@ def _arg_tokenizer( this_dim_name = self._graph.edges[e]["child_name"] retarget = None if dimname != this_dim_name: - retarget = self._graph.nodes[spacename][ - "dataset" - ].redirection.target(this_dim_name) + retarget = self._graph.nodes[spacename]["dataset"].redirection.target(this_dim_name) if dimname != retarget: continue parent_name = self._graph.edges[e]["parent_name"] @@ -1513,9 +1423,7 @@ def get_index(self, dim): return subspace.indexes[dim] def copy(self): - return type(self)( - self._graph.copy(), self.root_node_name, **self.__shallow_copy_extras() - ) + return type(self)(self._graph.copy(), self.root_node_name, **self.__shallow_copy_extras()) def all_var_names(self, uniquify=False, _duplicated_names=None): ordered_names = [] diff --git a/sharrow/shared_memory.py b/sharrow/shared_memory.py index 3bd273b..dfc5ed7 100644 --- a/sharrow/shared_memory.py +++ b/sharrow/shared_memory.py @@ -139,9 +139,7 @@ def open_shared_memory_array(key, mode="r+"): except FileNotFoundError: raise FileNotFoundError(f"sharrow_shared_memory_array:{key}") from None else: - logger.info( - f"shared memory array from ephemeral memory, {si_units(result.size)}" - ) + logger.info(f"shared memory array from ephemeral memory, {si_units(result.size)}") return result if backing.startswith("memmap:"): @@ -237,9 +235,7 @@ def __repr__(self): return r def release_shared_memory(self): - """ - Release shared memory allocated to this Dataset. 
- """ + """Release shared memory allocated to this Dataset.""" release_shared_memory(self._shared_memory_key_) @staticmethod @@ -364,17 +360,13 @@ def emit(k, a, is_coord): mem_arr_p = np.ndarray( shape=(_size_p // ad.indptr.dtype.itemsize,), dtype=ad.indptr.dtype, - buffer=buffer[ - _pos + _size_d + _size_i : _pos + _size_d + _size_i + _size_p - ], + buffer=buffer[_pos + _size_d + _size_i : _pos + _size_d + _size_i + _size_p], ) mem_arr_d[:] = ad.data[:] mem_arr_i[:] = ad.indices[:] mem_arr_p[:] = ad.indptr[:] else: - mem_arr = np.ndarray( - shape=a.shape, dtype=a.dtype, buffer=buffer[_pos : _pos + _size] - ) + mem_arr = np.ndarray(shape=a.shape, dtype=a.dtype, buffer=buffer[_pos : _pos + _size]) if isinstance(a, xr.DataArray) and isinstance(a.data, da.Array): tasks.append(da.store(a.data, mem_arr, lock=False, compute=False)) else: @@ -385,9 +377,7 @@ def emit(k, a, is_coord): if key.startswith("memmap:"): mem.flush() - create_shared_list( - [pickle.dumps(self._obj.attrs)] + [pickle.dumps(i) for i in wrappers], key - ) + create_shared_list([pickle.dumps(self._obj.attrs)] + [pickle.dumps(i) for i in wrappers], key) return type(self).from_shared_memory(key, own_data=mem, mode=mode) @property @@ -474,14 +464,7 @@ def from_shared_memory(cls, key, own_data=False, mode="r+"): mem_arr_p = np.ndarray( _size_p // _dtype_p.itemsize, dtype=_dtype_p, - buffer=buffer[ - position - + _size_d - + _size_i : position - + _size_d - + _size_i - + _size_p - ], + buffer=buffer[position + _size_d + _size_i : position + _size_d + _size_i + _size_p], ) mem_arr = sparse.GCXS( ( @@ -493,9 +476,7 @@ def from_shared_memory(cls, key, own_data=False, mode="r+"): compressed_axes=(0,), ) else: - mem_arr = np.ndarray( - shape, dtype=dtype, buffer=buffer[position : position + nbytes] - ) + mem_arr = np.ndarray(shape, dtype=dtype, buffer=buffer[position : position + nbytes]) content[name] = DataArray(mem_arr, **t) obj = cls._parent_class(content) @@ -507,7 +488,7 @@ def from_shared_memory(cls, key, own_data=False, mode="r+"): @property def shared_memory_size(self): - """int : Size (in bytes) in shared memory, raises ValueError if not shared.""" + """Int : Size (in bytes) in shared memory, raises ValueError if not shared.""" try: return sum(i.size for i in self._shared_memory_objs_) except AttributeError: @@ -515,7 +496,7 @@ def shared_memory_size(self): @property def is_shared_memory(self): - """bool : Whether this Dataset is in shared memory.""" + """Bool : Whether this Dataset is in shared memory.""" try: return sum(i.size for i in self._shared_memory_objs_) > 0 except AttributeError: diff --git a/sharrow/sparse.py b/sharrow/sparse.py index 77448ea..4fa7095 100644 --- a/sharrow/sparse.py +++ b/sharrow/sparse.py @@ -70,9 +70,7 @@ def __init__(self, i, j, data, shape=None): if isinstance(data, scipy.sparse.csr_matrix): self._sparse_data = data else: - self._sparse_data = scipy.sparse.coo_matrix( - (data, (i, j)), shape=shape - ).tocsr() + self._sparse_data = scipy.sparse.coo_matrix((data, (i, j)), shape=shape).tocsr() self._sparse_data.sort_indices() def __getitem__(self, item): @@ -95,12 +93,13 @@ def __init__(self, xarray_obj): def set(self, m2t, map_to, map_also=None, name=None): """ + Set the redirection of a dimension. 
+ Parameters ---------- m2t : pandas.Series Mapping maz's to tazs """ - if name is None: name = f"redirect_{map_to}" @@ -148,9 +147,7 @@ def apply_mapper(x): i_ = i j_ = j - sparse_data = sparse.GCXS( - sparse.COO((i_, j_), data, shape=shape), compressed_axes=(0,) - ) + sparse_data = sparse.GCXS(sparse.COO((i_, j_), data, shape=shape), compressed_axes=(0,)) self._obj[f"_s_{name}"] = xr.DataArray( sparse_data, dims=(i_dim, j_dim), @@ -240,9 +237,7 @@ def get_blended_2(backstop_value, indices, indptr, data, i, j, blend_limit=np.in @nb.njit -def get_blended_2_arr( - backstop_values_, indices, indptr, data, i_, j_, blend_limit=np.inf -): +def get_blended_2_arr(backstop_values_, indices, indptr, data, i_, j_, blend_limit=np.inf): out = np.zeros_like(backstop_values_) for z in range(backstop_values_.size): out[z] = get_blended_2( diff --git a/sharrow/tests/conftest.py b/sharrow/tests/conftest.py index e532de1..f3ba25f 100644 --- a/sharrow/tests/conftest.py +++ b/sharrow/tests/conftest.py @@ -7,9 +7,7 @@ @pytest.fixture def person_dataset() -> xr.Dataset: - """ - Sample persons dataset with dummy data. - """ + """Sample persons dataset with dummy data.""" df = pd.DataFrame( { "Income": [45, 88, 56, 15, 71], @@ -26,9 +24,7 @@ def person_dataset() -> xr.Dataset: @pytest.fixture def household_dataset() -> xr.Dataset: - """ - Sample household dataset with dummy data. - """ + """Sample household dataset with dummy data.""" df = pd.DataFrame( { "n_cars": [1, 2, 1], @@ -40,9 +36,7 @@ def household_dataset() -> xr.Dataset: @pytest.fixture def tours_dataset() -> xr.Dataset: - """ - Sample tours dataset with dummy data. - """ + """Sample tours dataset with dummy data.""" df = pd.DataFrame( { "TourMode": ["Car", "Bus", "Car", "Car", "Walk"], diff --git a/sharrow/translate.py b/sharrow/translate.py index 2e90611..a79c4dd 100644 --- a/sharrow/translate.py +++ b/sharrow/translate.py @@ -3,8 +3,8 @@ import numpy as np import pandas as pd import xarray as xr -from larch import OMX +from larch import OMX from sharrow.dataset import Dataset from .dataset import one_based, zero_based diff --git a/sharrow/utils/tar_zst.py b/sharrow/utils/tar_zst.py index a5e78ad..90ac1f8 100644 --- a/sharrow/utils/tar_zst.py +++ b/sharrow/utils/tar_zst.py @@ -14,7 +14,8 @@ def extract_zst(archive: Path, out_path: Path): """ - extract .zst file + Extract content of zst file to a target file system directory. + works on Windows, Linux, MacOS, etc. Parameters @@ -24,7 +25,6 @@ def extract_zst(archive: Path, out_path: Path): out_path: pathlib.Path or str directory to extract files and directories to """ - if zstandard is None: raise ImportError("pip install zstandard") diff --git a/sharrow/wrappers.py b/sharrow/wrappers.py index d4fa3b4..796b723 100644 --- a/sharrow/wrappers.py +++ b/sharrow/wrappers.py @@ -79,6 +79,7 @@ def igather(source, positions): class DatasetWrapper: def __init__(self, dataset, orig_key, dest_key, time_key=None): """ + Emulate ActivitySim's SkimWrapper. Parameters ---------- @@ -97,7 +98,7 @@ def __init__(self, dataset, orig_key, dest_key, time_key=None): def set_df(self, df): """ - Set the dataframe + Set the dataframe. 
Parameters ---------- @@ -108,16 +109,10 @@ def set_df(self, df): ------- self (to facilitate chaining) """ - assert ( - self.orig_key in df - ), f"orig_key '{self.orig_key}' not in df columns: {list(df.columns)}" - assert ( - self.dest_key in df - ), f"dest_key '{self.dest_key}' not in df columns: {list(df.columns)}" + assert self.orig_key in df, f"orig_key '{self.orig_key}' not in df columns: {list(df.columns)}" + assert self.dest_key in df, f"dest_key '{self.dest_key}' not in df columns: {list(df.columns)}" if self.time_key: - assert ( - self.time_key in df - ), f"time_key '{self.time_key}' not in df columns: {list(df.columns)}" + assert self.time_key in df, f"time_key '{self.time_key}' not in df columns: {list(df.columns)}" self.df = df # TODO allow non-1 offsets @@ -137,7 +132,7 @@ def set_df(self, df): def lookup(self, key, reverse=False): """ - Generally not called by the user - use __getitem__ instead + Generally not called by the user - use __getitem__ instead. Parameters ---------- @@ -154,7 +149,6 @@ def lookup(self, key, reverse=False): A Series of impedances which are elements of the Skim object and with the same index as df """ - assert self.df is not None, "Call set_df first" if reverse: x = self.positions.rename(columns={"otaz": "dtaz", "dtaz": "otaz"}) @@ -166,7 +160,9 @@ def lookup(self, key, reverse=False): def __getitem__(self, key): """ - Get the lookup for an available skim object (df and orig/dest and column names implicit) + Get the lookup for an available skim object. + + The `df` and orig/dest and column names are implicit. Parameters ---------- @@ -176,6 +172,7 @@ def __getitem__(self, key): Returns ------- impedances: pd.Series with the same index as df - A Series of impedances values from the single Skim with specified key, indexed byt orig/dest pair + A Series of impedances values from the single Skim with specified key, + indexed byt orig/dest pair """ return self.lookup(key) From 50d70ce870bc9fe50a9409a12372d744823e215f Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Wed, 10 Jan 2024 13:46:26 -0600 Subject: [PATCH 03/14] line len 88, like black --- .pre-commit-config.yaml | 21 ++- pyproject.toml | 2 +- sharrow/dataset.py | 74 +++++++--- sharrow/datastore.py | 36 +++-- sharrow/digital_encoding.py | 28 +++- sharrow/flows.py | 273 ++++++++++++++++++++++++++++-------- sharrow/relationships.py | 131 +++++++++++++---- sharrow/selectors.py | 1 - sharrow/shared_memory.py | 27 +++- sharrow/sparse.py | 12 +- sharrow/wrappers.py | 12 +- 11 files changed, 466 insertions(+), 151 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a021871..27bcc2e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,19 +13,14 @@ repos: hooks: - id: nbstripout -- repo: https://github.com/charliermarsh/ruff-pre-commit +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. rev: v0.1.11 hooks: + # Run the linter. - id: ruff - args: [--fix, --exit-non-zero-on-fix] - -- repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort - args: ["--profile", "black", "--filter-files"] - -- repo: https://github.com/psf/black - rev: 23.12.1 - hooks: - - id: black + types_or: [ python, pyi, jupyter ] + args: [ --fix, --exit-non-zero-on-fix ] + # Run the formatter. 
+ - id: ruff-format + types_or: [ python, pyi, jupyter ] diff --git a/pyproject.toml b/pyproject.toml index 8ad452e..3fee2d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ select = [ "F", # Pyflakes ] fix = true ignore-init-module-imports = true -line-length = 120 +line-length = 88 ignore = ["B905", "D1"] target-version = "py39" diff --git a/sharrow/dataset.py b/sharrow/dataset.py index 5f6cc25..85b7de2 100755 --- a/sharrow/dataset.py +++ b/sharrow/dataset.py @@ -70,7 +70,9 @@ def clean(s): cleaned = re.sub(r"\W|^(?=\d)", "_", s) if cleaned != s or len(cleaned) > 120: # digest size 15 creates a 24 character base32 string - h = base64.b32encode(hashlib.blake2b(s.encode(), digest_size=15).digest()).decode() + h = base64.b32encode( + hashlib.blake2b(s.encode(), digest_size=15).digest() + ).decode() cleaned = f"{cleaned[:90]}_{h}" return cleaned @@ -167,7 +169,9 @@ def dataset_from_dataframe_fast( if cannot_fix: break dupe_column_names = [f"- {i}" for i in dupe_column_names] - logger.error("DataFrame has non-unique columns\n" + "\n".join(dupe_column_names)) + logger.error( + "DataFrame has non-unique columns\n" + "\n".join(dupe_column_names) + ) if cannot_fix: raise ValueError("cannot convert DataFrame with non-unique columns") else: @@ -233,9 +237,13 @@ def from_table( index = pd.RangeIndex(len(tbl), name=index_name) else: if len(index) != len(tbl): - raise ValueError(f"length of index ({len(index)}) does not match length of table ({len(tbl)})") + raise ValueError( + f"length of index ({len(index)}) does not match length of table ({len(tbl)})" + ) if isinstance(index, pd.MultiIndex) and not index.is_unique: - raise ValueError("cannot attach a non-unique MultiIndex and convert into xarray") + raise ValueError( + "cannot attach a non-unique MultiIndex and convert into xarray" + ) arrays = [] metadata = {} for n in range(len(tbl.column_names)): @@ -253,7 +261,10 @@ def from_table( arrays.append((tbl.column_names[n], np.asarray(c))) result = xr.Dataset() if isinstance(index, pd.MultiIndex): - dims = tuple(name if name is not None else "level_%i" % n for n, name in enumerate(index.names)) + dims = tuple( + name if name is not None else "level_%i" % n + for n, name in enumerate(index.names) + ) for dim, lev in zip(dims, index.levels): result[dim] = (dim, lev) else: @@ -367,7 +378,10 @@ def from_omx( raise KeyError(f"{i} not found in OMX lookups") indexes = indexes_ if indexes is not None: - d["coords"] = {index_name: {"dims": index_name, "data": index} for index_name, index in indexes.items()} + d["coords"] = { + index_name: {"dims": index_name, "data": index} + for index_name, index in indexes.items() + } return xr.Dataset.from_dict(d) @@ -458,7 +472,9 @@ def from_omx_3d( elif indexes in set(omx_lookup._v_children): ranger = None else: - raise NotImplementedError("only one-based, zero-based, and named indexes are implemented") + raise NotImplementedError( + "only one-based, zero-based, and named indexes are implemented" + ) if ranger is not None: r1 = ranger(n1) r2 = ranger(n2) @@ -478,7 +494,9 @@ def from_omx_3d( base_k, time_k = k.split(time_period_sep, 1) if base_k not in pending_3d: pending_3d[base_k] = [None] * len(time_periods) - pending_3d[base_k][time_periods_map[time_k]] = dask.array.from_array(omx_data[omx_data_map[k]][k]) + pending_3d[base_k][time_periods_map[time_k]] = dask.array.from_array( + omx_data[omx_data_map[k]][k] + ) else: content[k] = xr.DataArray( dask.array.from_array(omx_data[omx_data_map[k]][k]), @@ -497,7 +515,9 @@ def from_omx_3d( break if prototype 
is None: raise ValueError("no prototype") - darrs_ = [(i if i is not None else dask.array.zeros_like(prototype)) for i in darrs] + darrs_ = [ + (i if i is not None else dask.array.zeros_like(prototype)) for i in darrs + ] content[base_k] = xr.DataArray( dask.array.stack(darrs_, axis=-1), dims=index_names, @@ -547,7 +567,10 @@ def from_amx( elif indexes.get(i) == "0": indexes[i] = zero_based(amx.shape[n]) if indexes is not None: - d["coords"] = {index_name: {"dims": index_name, "data": index} for index_name, index in indexes.items()} + d["coords"] = { + index_name: {"dims": index_name, "data": index} + for index_name, index in indexes.items() + } return xr.Dataset.from_dict(d) @@ -697,7 +720,10 @@ def _to_pydict(self): data = [] for k in columns: a = self.dataset._variables[k] - if "digital_encoding" in a.attrs and "dictionary" in a.attrs["digital_encoding"]: + if ( + "digital_encoding" in a.attrs + and "dictionary" in a.attrs["digital_encoding"] + ): de = a.attrs["digital_encoding"] data.append( pd.Categorical.from_codes( @@ -715,7 +741,10 @@ def to_pyarrow(self) -> pa.Table: data = [] for k in columns: a = self.dataset._variables[k] - if "digital_encoding" in a.attrs and "dictionary" in a.attrs["digital_encoding"]: + if ( + "digital_encoding" in a.attrs + and "dictionary" in a.attrs["digital_encoding"] + ): de = a.attrs["digital_encoding"] data.append( pa.DictionaryArray.from_arrays( @@ -857,7 +886,9 @@ def to_pandas(self) -> pd.Series: def to_pyarrow(self): if self.dataarray.cat.is_categorical(): - return pa.DictionaryArray.from_arrays(self.dataarray.data, self.dataarray.cat.categories) + return pa.DictionaryArray.from_arrays( + self.dataarray.data, self.dataarray.cat.categories + ) else: return pa.array(self.dataarray.data) @@ -887,7 +918,10 @@ def __getitem__(self, key: Mapping[Hashable, Any]) -> Dataset: dim_name = self.dataset.dims.__iter__().__next__() key = {dim_name: key} else: - raise TypeError("can only lookup dictionaries from Dataset.iloc, " "unless there is only one dimension") + raise TypeError( + "can only lookup dictionaries from Dataset.iloc, " + "unless there is only one dimension" + ) return self.dataset.isel(key) @@ -899,7 +933,9 @@ def rename_or_ignore(self, dims_dict=None, **dims_kwargs): from xarray.core.utils import either_dict_or_kwargs dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "rename_dims_and_coords") - dims_dict = {k: v for (k, v) in dims_dict.items() if (k in self.dims or k in self._variables)} + dims_dict = { + k: v for (k, v) in dims_dict.items() if (k in self.dims or k in self._variables) + } return self.rename(dims_dict) @@ -991,7 +1027,13 @@ def to_zarr_zip(self, *args, **kwargs): def _to_ast_literal(x): if isinstance(x, dict): - return "{" + ", ".join(f"{_to_ast_literal(k)}: {_to_ast_literal(v)}" for k, v in x.items()) + "}" + return ( + "{" + + ", ".join( + f"{_to_ast_literal(k)}: {_to_ast_literal(v)}" for k, v in x.items() + ) + + "}" + ) elif isinstance(x, list): return "[" + ", ".join(_to_ast_literal(i) for i in x) + "]" elif isinstance(x, tuple): diff --git a/sharrow/datastore.py b/sharrow/datastore.py index bb3379b..6a98d2c 100644 --- a/sharrow/datastore.py +++ b/sharrow/datastore.py @@ -128,7 +128,9 @@ def _update_dataset( if k in data.coords: continue assert v.name == k - partial_update = self._update_dataarray(name, v, last_checkpoint, partial_update=partial_update) + partial_update = self._update_dataarray( + name, v, last_checkpoint, partial_update=partial_update + ) for k, v in data.coords.items(): assert v.name == k 
partial_update = self._update_dataarray( @@ -156,8 +158,12 @@ def _update_dataarray( {data.name: data.assign_attrs(last_checkpoint=last_checkpoint)} ) else: - updated_dataset = base_data.assign({data.name: data.assign_attrs(last_checkpoint=last_checkpoint)}) - self._tree = self._tree.replace_datasets({name: updated_dataset}, redigitize=self._keep_digitized) + updated_dataset = base_data.assign( + {data.name: data.assign_attrs(last_checkpoint=last_checkpoint)} + ) + self._tree = self._tree.replace_datasets( + {name: updated_dataset}, redigitize=self._keep_digitized + ) return updated_dataset else: raise TypeError(type(data)) @@ -260,7 +266,9 @@ def _zarr_subdir(self, table_name, checkpoint_name): return self.directory.joinpath(table_name, checkpoint_name).with_suffix(".zarr") def _parquet_name(self, table_name, checkpoint_name): - return self.directory.joinpath(table_name, checkpoint_name).with_suffix(".parquet") + return self.directory.joinpath(table_name, checkpoint_name).with_suffix( + ".parquet" + ) def make_checkpoint(self, checkpoint_name: str, overwrite: bool = True): """ @@ -294,7 +302,9 @@ def make_checkpoint(self, checkpoint_name: str, overwrite: bool = True): os.unlink(target) target.parent.mkdir(parents=True, exist_ok=True) table_data.single_dim.to_parquet(str(target)) - elif self._storage_format == "zarr" or (self._storage_format == "parquet" and len(table_data.dims) > 1): + elif self._storage_format == "zarr" or ( + self._storage_format == "parquet" and len(table_data.dims) > 1 + ): # zarr is used if ndim > 1 target = self._zarr_subdir(table_name, checkpoint_name) if overwrite and target.is_dir(): @@ -304,7 +314,9 @@ def make_checkpoint(self, checkpoint_name: str, overwrite: bool = True): elif self._storage_format == "hdf5": raise NotImplementedError else: - raise ValueError(f"cannot write with storage format {self._storage_format!r}") + raise ValueError( + f"cannot write with storage format {self._storage_format!r}" + ) self.update(table_name, table_data, last_checkpoint=checkpoint_name) for table_name, table_data in self._tree.subspaces_iter(): inventory = {"data_vars": {}, "coords": {}} @@ -332,7 +344,9 @@ def make_checkpoint(self, checkpoint_name: str, overwrite: bool = True): def _write_checkpoint(self, name, checkpoint): if self._mode == "r": raise ReadOnlyError - checkpoint_metadata_target = self.directory.joinpath(self.checkpoint_subdir, f"{name}.yaml") + checkpoint_metadata_target = self.directory.joinpath( + self.checkpoint_subdir, f"{name}.yaml" + ) if checkpoint_metadata_target.exists(): n = 1 while checkpoint_metadata_target.with_suffix(f".{n}.yaml").exists(): @@ -386,7 +400,9 @@ def read_metadata(self, checkpoints=None): else: checkpoints = [checkpoints] for c in checkpoints: - with open(self.directory.joinpath(self.checkpoint_subdir, f"{c}.yaml")) as f: + with open( + self.directory.joinpath(self.checkpoint_subdir, f"{c}.yaml") + ) as f: self._checkpoints[c] = yaml.safe_load(f) def restore_checkpoint(self, checkpoint_name: str): @@ -414,7 +430,9 @@ def restore_checkpoint(self, checkpoint_name: str): opened_targets[target] = from_zarr_with_attr(target) else: # zarr not found, try parquet - target2 = self._parquet_name(table_name, coord_def["last_checkpoint"]) + target2 = self._parquet_name( + table_name, coord_def["last_checkpoint"] + ) if target2.exists(): if target not in opened_targets: opened_targets[target] = _read_parquet(target2, index_name) diff --git a/sharrow/digital_encoding.py b/sharrow/digital_encoding.py index a62fdef..95a1cc1 100644 --- 
a/sharrow/digital_encoding.py +++ b/sharrow/digital_encoding.py @@ -101,7 +101,8 @@ def array_decode(x, digital_encoding=None, aux_data=None): if offset_source: if aux_data is None: raise ValueDecodeError( - "cannot independently decode multivalue DataArray, " "provide aux_data or decode from dataset" + "cannot independently decode multivalue DataArray, " + "provide aux_data or decode from dataset" ) result = aux_data[offset_source].copy() result.data = x.to_numpy()[result.data] @@ -348,23 +349,35 @@ def multivalue_digitize_by_dictionary(ds, encode_vars=None, encoding_name=None): # check each name in encode_vars to make sure it's not already encoded # if you want to re-encode first decode - encode_vars = [v for v in encode_vars if "offset_source" not in ds[v].attrs.get("digital_encoding", {})] + encode_vars = [ + v + for v in encode_vars + if "offset_source" not in ds[v].attrs.get("digital_encoding", {}) + ] if len(encode_vars) == 0: return ds encode_var_dims = ds[encode_vars[0]].dims for v in encode_vars[1:]: - assert encode_var_dims == ds[v].dims, f"dims must match, {encode_var_dims} != {ds[v].dims}" + assert ( + encode_var_dims == ds[v].dims + ), f"dims must match, {encode_var_dims} != {ds[v].dims}" logger.info("assembling data stack") - conjoined = np.stack([array_decode(ds[v].compute(), aux_data=ds) for v in encode_vars], axis=-1) + conjoined = np.stack( + [array_decode(ds[v].compute(), aux_data=ds) for v in encode_vars], axis=-1 + ) logger.info("constructing stack view") baseshape = conjoined.shape[:-1] conjoined = conjoined.reshape([-1, conjoined.shape[-1]]) - voidview = np.ascontiguousarray(conjoined).view(np.dtype((np.void, conjoined.dtype.itemsize * conjoined.shape[1]))) + voidview = np.ascontiguousarray(conjoined).view( + np.dtype((np.void, conjoined.dtype.itemsize * conjoined.shape[1])) + ) logger.info("finding unique value combinations") unique_values, pointers = np.unique(voidview, return_inverse=True) pointers = pointers.reshape(baseshape) - unique_values = unique_values.view(np.dtype(conjoined.dtype)).reshape([-1, len(encode_vars)]) + unique_values = unique_values.view(np.dtype(conjoined.dtype)).reshape( + [-1, len(encode_vars)] + ) logger.info("downsampling offsets") if unique_values.shape[0] < 1 << 8: pointers = pointers.astype(np.uint8) @@ -397,6 +410,7 @@ def multivalue_digitize_by_dictionary(ds, encode_vars=None, encoding_name=None): bytes_saved = original_footprint - encoded_footprint savings_ratio = bytes_saved / original_footprint logger.info( - f"multivalue_digitize_by_dictionary {encoding_name} " f"saved {si_units(bytes_saved)} {savings_ratio:.1%}" + f"multivalue_digitize_by_dictionary {encoding_name} " + f"saved {si_units(bytes_saved)} {savings_ratio:.1%}" ) return out diff --git a/sharrow/flows.py b/sharrow/flows.py index e98c5e2..849cff8 100644 --- a/sharrow/flows.py +++ b/sharrow/flows.py @@ -80,7 +80,9 @@ def clean(s): cleaned = re.sub(r"\W|^(?=\d)", "_", s) if cleaned != s or len(cleaned) > 120: # digest size 15 creates a 24 character base32 string - h = base64.b32encode(hashlib.blake2b(s.encode(), digest_size=15).digest()).decode() + h = base64.b32encode( + hashlib.blake2b(s.encode(), digest_size=15).digest() + ).decode() cleaned = f"{cleaned[:90]}_{h}" return cleaned @@ -153,16 +155,29 @@ def visit_Call(self, node): if len(node.args) == 1: if isinstance(node.args[0], ast.Constant): if len(node.keywords) == 0: - self.required_get_tokens.add((node.func.value.id, node.args[0].value)) - elif len(node.keywords) == 1 and node.keywords[0].arg == "default": - 
self.optional_get_tokens.add((node.func.value.id, node.args[0].value)) + self.required_get_tokens.add( + (node.func.value.id, node.args[0].value) + ) + elif ( + len(node.keywords) == 1 + and node.keywords[0].arg == "default" + ): + self.optional_get_tokens.add( + (node.func.value.id, node.args[0].value) + ) else: - raise ValueError(f"{node.func.value.id}.get with unexpected keyword arguments") + raise ValueError( + f"{node.func.value.id}.get with unexpected keyword arguments" + ) if len(node.args) == 2: if isinstance(node.args[0], ast.Constant): - self.optional_get_tokens.add((node.func.value.id, node.args[0].value)) + self.optional_get_tokens.add( + (node.func.value.id, node.args[0].value) + ) if len(node.args) > 2: - raise ValueError(f"{node.func.value.id}.get with more than 2 positional arguments") + raise ValueError( + f"{node.func.value.id}.get with more than 2 positional arguments" + ) self.generic_visit(node) def check(self, node): @@ -1035,7 +1050,9 @@ def __initialize_1( all_raw_names |= attribute_pairs.get(self.tree.root_node_name, set()) all_raw_names |= subscript_pairs.get(self.tree.root_node_name, set()) - dimensions_ordered = presorted(self.tree.sizes, self.dim_order, self.dim_exclude) + dimensions_ordered = presorted( + self.tree.sizes, self.dim_order, self.dim_exclude + ) index_slots = {i: n for n, i in enumerate(dimensions_ordered)} self.arg_name_positions = index_slots self.arg_names = dimensions_ordered @@ -1063,12 +1080,17 @@ def __initialize_1( subspace_names.add(k) for k in self.tree.subspace_fallbacks: subspace_names.add(k) - optional_get_tokens = ExtractOptionalGetTokens(from_names=subspace_names).check(defs.values()) + optional_get_tokens = ExtractOptionalGetTokens(from_names=subspace_names).check( + defs.values() + ) self._optional_get_tokens = [] if optional_get_tokens: for _spacename, _varname in optional_get_tokens: found = False - if _spacename in self.tree.subspaces and _varname in self.tree.subspaces[_spacename]: + if ( + _spacename in self.tree.subspaces + and _varname in self.tree.subspaces[_spacename] + ): self._optional_get_tokens.append(f"__{_spacename}__{_varname}:True") found = True elif _spacename in self.tree.subspace_fallbacks: @@ -1080,7 +1102,9 @@ def __initialize_1( found = True break if not found: - self._optional_get_tokens.append(f"__{_spacename}__{_varname}:False") + self._optional_get_tokens.append( + f"__{_spacename}__{_varname}:False" + ) self._hashing_level = hashing_level if self._hashing_level > 1: @@ -1149,7 +1173,9 @@ def _flow_hash_push(x): parts = k.split("__") if len(parts) > 2: try: - digital_encoding = self.tree.subspaces[parts[1]]["__".join(parts[2:])].attrs["digital_encoding"] + digital_encoding = self.tree.subspaces[parts[1]][ + "__".join(parts[2:]) + ].attrs["digital_encoding"] except (AttributeError, KeyError) as err: pass print(f"$$$$/ndigital_encoding=ERR\n{err}\n\n\n$$$") @@ -1175,7 +1201,12 @@ def _flow_hash_push(x): self.flow_hash_audit = "]\n# [".join(flow_hash_audit) def _index_slots(self): - return {i: n for n, i in enumerate(presorted(self.tree.sizes, self.dim_order, self.dim_exclude))} + return { + i: n + for n, i in enumerate( + presorted(self.tree.sizes, self.dim_order, self.dim_exclude) + ) + } def init_sub_funcs( self, @@ -1187,7 +1218,12 @@ def init_sub_funcs( ): func_code = "" all_name_tokens = set() - index_slots = {i: n for n, i in enumerate(presorted(self.tree.sizes, self.dim_order, self.dim_exclude))} + index_slots = { + i: n + for n, i in enumerate( + presorted(self.tree.sizes, self.dim_order, 
self.dim_exclude) + ) + } self.arg_name_positions = index_slots candidate_names = self.tree.namespace_names() candidate_names |= set(f"__aux_var__{i}" for i in self.tree.aux_vars.keys()) @@ -1280,8 +1316,12 @@ def init_sub_funcs( else: raise # check if we can resolve this name on any other subspace - for other_spacename in self.tree.subspace_fallbacks.get(topkey, []): - dim_slots, digital_encodings, blenders = meta_data[other_spacename] + for other_spacename in self.tree.subspace_fallbacks.get( + topkey, [] + ): + dim_slots, digital_encodings, blenders = meta_data[ + other_spacename + ] try: expr = expression_for_numba( expr, @@ -1321,8 +1361,12 @@ def init_sub_funcs( # at least one variable was found in a get break # check if we can resolve this "get" on any other subspace - for other_spacename in self.tree.subspace_fallbacks.get(topkey, []): - dim_slots, digital_encodings, blenders = meta_data[other_spacename] + for other_spacename in self.tree.subspace_fallbacks.get( + topkey, [] + ): + dim_slots, digital_encodings, blenders = meta_data[ + other_spacename + ] try: expr = expression_for_numba( expr, @@ -1361,7 +1405,9 @@ def init_sub_funcs( actual_spacenames, ) in self.tree.subspace_fallbacks.items(): for actual_spacename in actual_spacenames: - dim_slots, digital_encodings, blenders = meta_data[actual_spacename] + dim_slots, digital_encodings, blenders = meta_data[ + actual_spacename + ] try: expr = expression_for_numba( expr, @@ -1404,7 +1450,10 @@ def init_sub_funcs( bool_wrapping=self.bool_wrapping, ) - aux_tokens = {k: ast.parse(f"__aux_var__{k}", mode="eval").body for k in self.tree.aux_vars.keys()} + aux_tokens = { + k: ast.parse(f"__aux_var__{k}", mode="eval").body + for k in self.tree.aux_vars.keys() + } # now handle aux vars expr = expression_for_numba( @@ -1505,7 +1554,9 @@ def __initialize_2( parts = k.split("__") if len(parts) > 2: try: - digital_encoding = self.tree.subspaces[parts[1]]["__".join(parts[2:])].attrs["digital_encoding"] + digital_encoding = self.tree.subspaces[parts[1]][ + "__".join(parts[2:]) + ].attrs["digital_encoding"] except (AttributeError, KeyError): pass else: @@ -1583,7 +1634,9 @@ def __initialize_2( if isinstance(x_var, (float, int, str)): buffer.write(f"{x_name} = {x_var!r}\n") else: - buffer.write(f"{x_name} = pickle.loads({repr(pickle.dumps(x_var))})\n") + buffer.write( + f"{x_name} = pickle.loads({repr(pickle.dumps(x_var))})\n" + ) dependencies.add("import pickle") with io.StringIO() as x_code: x_code.write("\n") @@ -1600,7 +1653,9 @@ def __initialize_2( import pickle buffer = io.StringIO() for x_name, x_dict in self.encoding_dictionaries.items(): - buffer.write(f"__encoding_dict{x_name} = pickle.loads({repr(pickle.dumps(x_dict))})\n") + buffer.write( + f"__encoding_dict{x_name} = pickle.loads({repr(pickle.dumps(x_dict))})\n" + ) with io.StringIO() as x_code: x_code.write("\n") x_code.write(buffer.getvalue()) @@ -1609,7 +1664,9 @@ def __initialize_2( # write the master module for this flow os.makedirs(os.path.join(self.cache_dir, self.name), exist_ok=True) - with rewrite(os.path.join(self.cache_dir, self.name, "__init__.py"), "wt") as f_code: + with rewrite( + os.path.join(self.cache_dir, self.name, "__init__.py"), "wt" + ) as f_code: f_code.write( textwrap.dedent( f""" @@ -1671,7 +1728,9 @@ def __initialize_2( root_dims = list( presorted( - self.tree._graph.nodes[with_root_node_name]["dataset"].sizes, + self.tree._graph.nodes[with_root_node_name][ + "dataset" + ].sizes, self.dim_order, self.dim_exclude, ) @@ -1683,7 +1742,9 @@ def __initialize_2( 
elif n_root_dims == 2: js = "j0, j1" else: - raise NotImplementedError(f"n_root_dims only supported up to 2, not {n_root_dims}") + raise NotImplementedError( + f"n_root_dims only supported up to 2, not {n_root_dims}" + ) meta_code = [] meta_code_dot = [] @@ -1700,24 +1761,48 @@ def __initialize_2( meta_code_dot.append( f"intermediate[{n}] = ({clean(k)}({f_args_j}intermediate, {f_name_tokens})).item()" ) - meta_code_stack = textwrap.indent("\n".join(meta_code), " " * 12).lstrip() - meta_code_stack_dot = textwrap.indent("\n".join(meta_code_dot), " " * 12).lstrip() + meta_code_stack = textwrap.indent( + "\n".join(meta_code), " " * 12 + ).lstrip() + meta_code_stack_dot = textwrap.indent( + "\n".join(meta_code_dot), " " * 12 + ).lstrip() len_self_raw_functions = len(self._raw_functions) - joined_namespace_names = "\n ".join(f"{nn}," for nn in self._namespace_names) + joined_namespace_names = "\n ".join( + f"{nn}," for nn in self._namespace_names + ) linefeed = "\n " if not meta_code_stack_dot: meta_code_stack_dot = "pass" if n_root_dims == 1: - meta_template = IRUNNER_1D_TEMPLATE.format(**locals()).format(**locals()) - meta_template_dot = IDOTTER_1D_TEMPLATE.format(**locals()).format(**locals()) - line_template = ILINER_1D_TEMPLATE.format(**locals()).format(**locals()) - mnl_template = MNL_1D_TEMPLATE.format(**locals()).format(**locals()) - nl_template = NL_1D_TEMPLATE.format(**locals()).format(**locals()) + meta_template = IRUNNER_1D_TEMPLATE.format(**locals()).format( + **locals() + ) + meta_template_dot = IDOTTER_1D_TEMPLATE.format( + **locals() + ).format(**locals()) + line_template = ILINER_1D_TEMPLATE.format(**locals()).format( + **locals() + ) + mnl_template = MNL_1D_TEMPLATE.format(**locals()).format( + **locals() + ) + nl_template = NL_1D_TEMPLATE.format(**locals()).format( + **locals() + ) elif n_root_dims == 2: - meta_template = IRUNNER_2D_TEMPLATE.format(**locals()).format(**locals()) - meta_template_dot = IDOTTER_2D_TEMPLATE.format(**locals()).format(**locals()) - line_template = ILINER_2D_TEMPLATE.format(**locals()).format(**locals()) - mnl_template = MNL_2D_TEMPLATE.format(**locals()).format(**locals()) + meta_template = IRUNNER_2D_TEMPLATE.format(**locals()).format( + **locals() + ) + meta_template_dot = IDOTTER_2D_TEMPLATE.format( + **locals() + ).format(**locals()) + line_template = ILINER_2D_TEMPLATE.format(**locals()).format( + **locals() + ) + mnl_template = MNL_2D_TEMPLATE.format(**locals()).format( + **locals() + ) nl_template = "" else: raise ValueError(f"invalid n_root_dims {n_root_dims}") @@ -1790,11 +1875,15 @@ def __initialize_2( def load_raw(self, rg, args, runner=None, dtype=None, dot=None): assert isinstance(rg, DataTree) with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=nb.NumbaExperimentalFeatureWarning) + warnings.filterwarnings( + "ignore", category=nb.NumbaExperimentalFeatureWarning + ) assembled_args = [args.get(k) for k in self.arg_name_positions.keys()] for aa in assembled_args: if aa.dtype.kind != "i": - warnings.warn("position arguments are not all integers", stacklevel=2) + warnings.warn( + "position arguments are not all integers", stacklevel=2 + ) try: if runner is None: if dot is None: @@ -1872,7 +1961,9 @@ def _iload_raw( ): assert isinstance(rg, DataTree) with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=nb.NumbaExperimentalFeatureWarning) + warnings.filterwarnings( + "ignore", category=nb.NumbaExperimentalFeatureWarning + ) try: known_arg_names = { "dtype", @@ -1908,7 +1999,11 @@ def 
_iload_raw( elif dot is None: runner_ = self._irunner known_arg_names.update({"mask"}) - if mask is not None and dtype is not None and not np.issubdtype(dtype, np.floating): + if ( + mask is not None + and dtype is not None + and not np.issubdtype(dtype, np.floating) + ): raise TypeError("cannot use mask unless dtype is float") else: runner_ = self._idotter @@ -1959,8 +2054,13 @@ def _iload_raw( if self.with_root_node_name is None: tree_root_dims = rg.root_dataset.sizes else: - tree_root_dims = rg._graph.nodes[self.with_root_node_name]["dataset"].sizes - argshape = [tree_root_dims[i] for i in presorted(tree_root_dims, self.dim_order, self.dim_exclude)] + tree_root_dims = rg._graph.nodes[self.with_root_node_name][ + "dataset" + ].sizes + argshape = [ + tree_root_dims[i] + for i in presorted(tree_root_dims, self.dim_order, self.dim_exclude) + ] if mnl is not None: if nesting is not None: n_alts = nesting["n_alts"] @@ -1973,7 +2073,9 @@ def _iload_raw( elif n_alts < 32768: kwargs["choice_dtype"] = np.int16 if logger.isEnabledFor(logging.DEBUG): - logger.debug("========= PASSING ARGUMENT TO SHARROW LOAD ==========") + logger.debug( + "========= PASSING ARGUMENT TO SHARROW LOAD ==========" + ) logger.debug(f"{argshape=}") for _name, _info in zip(_arguments_names, arguments): try: @@ -1993,10 +2095,14 @@ def _iload_raw( logger.debug(f"KWARG {_name}: {alt_repr}") else: logger.debug(f"KWARG {_name}: type={type(_info)}") - logger.debug("========= ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ==========") + logger.debug( + "========= ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ==========" + ) result = runner_(np.asarray(argshape), *arguments, **kwargs) if compile_watch: - self.check_cache_misses(runner_, log_details=compile_watch != "simple") + self.check_cache_misses( + runner_, log_details=compile_watch != "simple" + ) return result except nb.TypingError as err: _raw_functions = getattr(self, "_raw_functions", {}) @@ -2053,11 +2159,18 @@ def check_cache_misses(self, *funcs, fresh=True, log_details=True): for k, v in cache_misses.items(): if v > known_cache_misses.get(k, 0): if log_details: - warning_text = "\n".join(f" - {argname}: {sig}" for (sig, argname) in zip(k, named_args)) + warning_text = "\n".join( + f" - {argname}: {sig}" + for (sig, argname) in zip(k, named_args) + ) warning_text = f"\n{runner_name}(\n{warning_text}\n)" else: warning_text = "" - timers = f.overloads[k].metadata["timers"].get("compiler_lock", "N/A") + timers = ( + f.overloads[k] + .metadata["timers"] + .get("compiler_lock", "N/A") + ) if isinstance(timers, float): if timers < 1e-3: timers = f"{timers/1e-6:.0f} µs" @@ -2065,8 +2178,13 @@ def check_cache_misses(self, *funcs, fresh=True, log_details=True): timers = f"{timers/1e-3:.1f} ms" else: timers = f"{timers:.2f} s" - logger.warning(f"cache miss in {self.flow_hash}{warning_text}\n" f"Compile Time: {timers}") - warnings.warn(f"{self.flow_hash}", CacheMissWarning, stacklevel=1) + logger.warning( + f"cache miss in {self.flow_hash}{warning_text}\n" + f"Compile Time: {timers}" + ) + warnings.warn( + f"{self.flow_hash}", CacheMissWarning, stacklevel=1 + ) self.compiled_recently = True self._known_cache_misses[runner_name][k] = v return self.compiled_recently @@ -2151,7 +2269,9 @@ def _load( logit_draws = np.zeros(source.shape + (0,), dtype=dtype) if self.with_root_node_name is None: - use_dims = list(presorted(source.root_dataset.sizes, self.dim_order, self.dim_exclude)) + use_dims = list( + presorted(source.root_dataset.sizes, self.dim_order, self.dim_exclude) + ) else: use_dims = list( presorted( 
@@ -2208,7 +2328,11 @@ def _load( ): result_dims = use_dims[:-1] result_squeeze = (-1,) - elif dot.ndim == 2 and dot.shape[1] == 1 and logit_draws.ndim == len(use_dims): + elif ( + dot.ndim == 2 + and dot.shape[1] == 1 + and logit_draws.ndim == len(use_dims) + ): result_dims = use_dims[:-1] + logit_draws_trailing_dim elif dot.ndim == 2 and logit_draws.ndim == len(use_dims): result_dims = use_dims[:-1] + dot_trailing_dim @@ -2231,7 +2355,11 @@ def _load( and self._logit_ndims == 1 ): result_dims = use_dims + logit_draws_trailing_dim - elif dot.ndim == 2 and logit_draws.ndim == len(use_dims) + 1 and logit_draws.shape[-1] == 0: + elif ( + dot.ndim == 2 + and logit_draws.ndim == len(use_dims) + 1 + and logit_draws.shape[-1] == 0 + ): # logsums only result_dims = use_dims result_squeeze = (-1,) @@ -2309,20 +2437,30 @@ def _load( raise RuntimeError("please digitize") if as_dataframe: index = getattr(source.root_dataset, "index", None) - result = pd.DataFrame(result, index=index, columns=list(self._raw_functions.keys())) + result = pd.DataFrame( + result, index=index, columns=list(self._raw_functions.keys()) + ) elif as_table: - result = Table({k: result[:, n] for n, k in enumerate(self._raw_functions.keys())}) + result = Table( + {k: result[:, n] for n, k in enumerate(self._raw_functions.keys())} + ) elif as_dataarray: if result_squeeze: result = squeeze(result, result_squeeze) result_p = squeeze(result_p, result_squeeze) pick_count = squeeze(pick_count, result_squeeze) if self.with_root_node_name is None: - result_coords = {k: v for k, v in source.root_dataset.coords.items() if k in result_dims} + result_coords = { + k: v + for k, v in source.root_dataset.coords.items() + if k in result_dims + } else: result_coords = { k: v - for k, v in source._graph.nodes[self.with_root_node_name]["dataset"].coords.items() + for k, v in source._graph.nodes[self.with_root_node_name][ + "dataset" + ].coords.items() if k in result_dims } if result is not None: @@ -2349,7 +2487,11 @@ def _load( out_logsum = xr.DataArray( out_logsum, dims=result_dims[: out_logsum.ndim], - coords={k: v for k, v in source.root_dataset.coords.items() if k in result_dims[: out_logsum.ndim]}, + coords={ + k: v + for k, v in source.root_dataset.coords.items() + if k in result_dims[: out_logsum.ndim] + }, ) else: @@ -2413,7 +2555,9 @@ def load(self, source=None, dtype=None, compile_watch=False, mask=None): ------- numpy.array """ - return self._load(source=source, dtype=dtype, compile_watch=compile_watch, mask=mask) + return self._load( + source=source, dtype=dtype, compile_watch=compile_watch, mask=mask + ) def load_dataframe(self, source=None, dtype=None, compile_watch=False, mask=None): """ @@ -2708,7 +2852,10 @@ def init_streamer(self, source=None, dtype=None): selected_args = tuple(general_mapping[k] for k in named_args) len_self_raw_functions = len(self._raw_functions) tree_root_dims = source.root_dataset.sizes - argshape = tuple(tree_root_dims[i] for i in presorted(tree_root_dims, self.dim_order, self.dim_exclude)) + argshape = tuple( + tree_root_dims[i] + for i in presorted(tree_root_dims, self.dim_order, self.dim_exclude) + ) if len(argshape) == 1: linemaker = self._linemaker @@ -2740,6 +2887,8 @@ def streamer(c, out=None): return result else: - raise NotImplementedError(f"root tree with {len(argshape)} dims {argshape=}") + raise NotImplementedError( + f"root tree with {len(argshape)} dims {argshape=}" + ) return streamer diff --git a/sharrow/relationships.py b/sharrow/relationships.py index e2989e5..e86eb66 100644 --- 
a/sharrow/relationships.py +++ b/sharrow/relationships.py @@ -389,13 +389,17 @@ def shape(self): from .flows import presorted dim_order = presorted(self.root_dataset.dims, self.dim_order) - return tuple(self.root_dataset.dims[i] for i in dim_order if i not in self.dim_exclude) + return tuple( + self.root_dataset.dims[i] for i in dim_order if i not in self.dim_exclude + ) @property def root_dims(self): from .flows import presorted - return tuple(presorted(self.root_dataset.dims, self.dim_order, self.dim_exclude)) + return tuple( + presorted(self.root_dataset.dims, self.dim_order, self.dim_exclude) + ) def __shallow_copy_extras(self): return dict( @@ -422,7 +426,9 @@ def __repr__(self): if len(self._graph.edges): s += "\n relationships:" for e in self._graph.edges: - s += f"\n - {self._get_relationship(e)!r}".replace("") + s += f"\n - {self._get_relationship(e)!r}".replace( + "") else: s += "\n relationships: none" return s @@ -445,7 +451,9 @@ def _hash_features(self): h.append("datasets:none") if len(self._graph.edges): for e in self._graph.edges: - r = f"relationship:{self._get_relationship(e)!r}".replace("") + r = f"relationship:{self._get_relationship(e)!r}".replace( + "") h.append(r) else: h.append("relationships:none") @@ -468,7 +476,9 @@ def root_node_name(self, name): self._root_node_name = name return if not isinstance(name, str): - raise TypeError(f"root_node_name must be one of [str, None, False] not {type(name)}") + raise TypeError( + f"root_node_name must be one of [str, None, False] not {type(name)}" + ) if name not in self._graph.nodes: raise KeyError(name) self._root_node_name = name @@ -606,7 +616,9 @@ def root_dataset(self, x): self._graph.nodes[self.root_node_name]["dataset"] = x def _get_relationship(self, edge): - return Relationship(parent_data=edge[0], child_data=edge[1], **self._graph.edges[edge]) + return Relationship( + parent_data=edge[0], child_data=edge[1], **self._graph.edges[edge] + ) def __getitem__(self, item): return self.get(item) @@ -635,12 +647,19 @@ def get(self, item, default=None, broadcast=True, coords=True): if isinstance(item, (list, tuple)): from .dataset import Dataset - return Dataset({k: self.get(k, default=default, broadcast=broadcast, coords=coords) for k in item}) + return Dataset( + { + k: self.get(k, default=default, broadcast=broadcast, coords=coords) + for k in item + } + ) try: result = self._getitem(item, dim_names_from_top=True) except KeyError: try: - result = self._getitem(item, include_blank_dims=True, dim_names_from_top=True) + result = self._getitem( + item, include_blank_dims=True, dim_names_from_top=True + ) except KeyError: if default is None: raise @@ -721,7 +740,9 @@ def _getitem( return current_node if current_node == start_from: if by_dims: - return xr.DataArray(pd.RangeIndex(dataset.dims[item]), dims=item) + return xr.DataArray( + pd.RangeIndex(dataset.dims[item]), dims=item + ) else: return dataset[item] else: @@ -741,15 +762,22 @@ def _getitem( result = dataset[item] dims_in_result = set(result.dims) top_dim_names = {} - for path in nx.algorithms.simple_paths.all_simple_edge_paths(self._graph, start_from, current_node): + for path in nx.algorithms.simple_paths.all_simple_edge_paths( + self._graph, start_from, current_node + ): if dim_names_from_top: e = path[0] top_dim_name = self._graph.edges[e].get("parent_name") start_dataset = self._graph.nodes[start_from]["dataset"] # deconvert digitized dim names back to native dims - if top_dim_name not in start_dataset.dims and top_dim_name in start_dataset.variables: + if ( + 
top_dim_name not in start_dataset.dims + and top_dim_name in start_dataset.variables + ): if start_dataset.variables[top_dim_name].ndim == 1: - top_dim_name = start_dataset.variables[top_dim_name].dims[0] + top_dim_name = start_dataset.variables[ + top_dim_name + ].dims[0] else: top_dim_name = None path_dim = self._graph.edges[path[-1]].get("child_name") @@ -763,18 +791,36 @@ def _getitem( r_next = self._get_relationship(e_next) if t1 is None: t1 = self._graph.nodes[r.parent_data].get("dataset") - t2 = self._graph.nodes[r.child_data].get("dataset")[[r_next.parent_name]] + t2 = self._graph.nodes[r.child_data].get("dataset")[ + [r_next.parent_name] + ] if r.indexing == "label": - t1 = t2.sel({r.child_name: _dataarray_to_numpy(t1[r.parent_name])}) + t1 = t2.sel( + { + r.child_name: _dataarray_to_numpy( + t1[r.parent_name] + ) + } + ) else: # by position - t1 = t2.isel({r.child_name: _dataarray_to_numpy(t1[r.parent_name])}) + t1 = t2.isel( + { + r.child_name: _dataarray_to_numpy( + t1[r.parent_name] + ) + } + ) # final node in path e = path[-1] - r = Relationship(parent_data=e[0], child_data=e[1], **self._graph.edges[e]) + r = Relationship( + parent_data=e[0], child_data=e[1], **self._graph.edges[e] + ) if t1 is None: t1 = self._graph.nodes[r.parent_data].get("dataset") if r.indexing == "label": - _labels[r.child_name] = _dataarray_to_numpy(t1[r.parent_name]) + _labels[r.child_name] = _dataarray_to_numpy( + t1[r.parent_name] + ) else: # by position _idx = _dataarray_to_numpy(t1[r.parent_name]) if not np.issubdtype(_idx.dtype, np.integer): @@ -823,7 +869,11 @@ def get_expr(self, expression, engine="sharrow", allow_native=True): raise KeyError except (KeyError, IndexError): if engine == "sharrow": - result = self.setup_flow({expression: expression}).load_dataarray().isel(expressions=0) + result = ( + self.setup_flow({expression: expression}) + .load_dataarray() + .isel(expressions=0) + ) elif engine == "numexpr": from xarray import DataArray @@ -908,7 +958,9 @@ def dims(self): for name, length in v.dims.items(): if name in dims: if dims[name] != length: - raise ValueError("inconsistent dimensions\n" + self.dims_detail()) + raise ValueError( + "inconsistent dimensions\n" + self.dims_detail() + ) else: dims[name] = length return xr.core.utils.Frozen(dims) @@ -1006,7 +1058,11 @@ def get_indexes( return self._cached_indexes[(position_only, as_dict)] if not position_only: raise NotImplementedError - dims = [d for d in self.dims if d[-1:] != "_" or (d[-1:] == "_" and d[:-1] not in self.dims)] + dims = [ + d + for d in self.dims + if d[-1:] != "_" or (d[-1:] == "_" and d[:-1] not in self.dims) + ] if replacements is not None: obj = self.replace_datasets(replacements) else: @@ -1272,7 +1328,9 @@ def mapper_get(x, mapper=mapper): ) # candidate name for write back - r_parent_name_new = f"{self._BY_OFFSET}{r.parent_name}_{r.child_data}_{r.child_name}" + r_parent_name_new = ( + f"{self._BY_OFFSET}{r.parent_name}_{r.child_data}_{r.child_name}" + ) # it is common to have mirrored offsets in various dimensions. # we'd like to retain only the same data in memory once, so we'll @@ -1283,14 +1341,16 @@ def mapper_get(x, mapper=mapper): if p_dataset[k].equals(offsets): # we found a match, so we'll assign this name to # the match's memory storage instead of replicating it. 
- obj._graph.nodes[r.parent_data]["dataset"] = p_dataset.assign( - {r_parent_name_new: p_dataset[k]} - ) + obj._graph.nodes[r.parent_data][ + "dataset" + ] = p_dataset.assign({r_parent_name_new: p_dataset[k]}) # r_parent_name_new = k break else: # no existing offset arrays match, make this new one - obj._graph.nodes[r.parent_data]["dataset"] = p_dataset.assign({r_parent_name_new: offsets}) + obj._graph.nodes[r.parent_data]["dataset"] = p_dataset.assign( + {r_parent_name_new: offsets} + ) obj._graph.edges[e].update( dict( parent_name=r_parent_name_new, @@ -1311,7 +1371,9 @@ def relationships_are_digitized(self): return False return True - def _arg_tokenizer(self, spacename, spacearray, spacearrayname, exclude_dims=None, blends=None): + def _arg_tokenizer( + self, spacename, spacearray, spacearrayname, exclude_dims=None, blends=None + ): if blends is None: blends = {} @@ -1325,7 +1387,10 @@ def _arg_tokenizer(self, spacename, spacearray, spacearrayname, exclude_dims=Non else: from_dims = spacearray.dims return ( - tuple(ast.parse(f"_arg{root_dims.index(dim):02}", mode="eval").body for dim in from_dims), + tuple( + ast.parse(f"_arg{root_dims.index(dim):02}", mode="eval").body + for dim in from_dims + ), blends, ) @@ -1335,7 +1400,9 @@ def _arg_tokenizer(self, spacename, spacearray, spacearrayname, exclude_dims=Non spacearray_ = spacearray from_dims = spacearray_.dims - offset_source = spacearray_.attrs.get("digital_encoding", {}).get("offset_source", None) + offset_source = spacearray_.attrs.get("digital_encoding", {}).get( + "offset_source", None + ) if offset_source is not None: from_dims = self._graph.nodes[spacename]["dataset"][offset_source].dims @@ -1348,7 +1415,9 @@ def _arg_tokenizer(self, spacename, spacearray, spacearrayname, exclude_dims=Non this_dim_name = self._graph.edges[e]["child_name"] retarget = None if dimname != this_dim_name: - retarget = self._graph.nodes[spacename]["dataset"].redirection.target(this_dim_name) + retarget = self._graph.nodes[spacename][ + "dataset" + ].redirection.target(this_dim_name) if dimname != retarget: continue parent_name = self._graph.edges[e]["parent_name"] @@ -1423,7 +1492,9 @@ def get_index(self, dim): return subspace.indexes[dim] def copy(self): - return type(self)(self._graph.copy(), self.root_node_name, **self.__shallow_copy_extras()) + return type(self)( + self._graph.copy(), self.root_node_name, **self.__shallow_copy_extras() + ) def all_var_names(self, uniquify=False, _duplicated_names=None): ordered_names = [] diff --git a/sharrow/selectors.py b/sharrow/selectors.py index 927cfaa..b057f89 100644 --- a/sharrow/selectors.py +++ b/sharrow/selectors.py @@ -134,7 +134,6 @@ def _filter( ds_ = ds if _names: - result = ( getattr(ds_, _func)(**loaders) .digital_encoding.strip(_names) diff --git a/sharrow/shared_memory.py b/sharrow/shared_memory.py index dfc5ed7..d2b9e23 100644 --- a/sharrow/shared_memory.py +++ b/sharrow/shared_memory.py @@ -139,7 +139,9 @@ def open_shared_memory_array(key, mode="r+"): except FileNotFoundError: raise FileNotFoundError(f"sharrow_shared_memory_array:{key}") from None else: - logger.info(f"shared memory array from ephemeral memory, {si_units(result.size)}") + logger.info( + f"shared memory array from ephemeral memory, {si_units(result.size)}" + ) return result if backing.startswith("memmap:"): @@ -360,13 +362,17 @@ def emit(k, a, is_coord): mem_arr_p = np.ndarray( shape=(_size_p // ad.indptr.dtype.itemsize,), dtype=ad.indptr.dtype, - buffer=buffer[_pos + _size_d + _size_i : _pos + _size_d + _size_i + _size_p], + 
buffer=buffer[ + _pos + _size_d + _size_i : _pos + _size_d + _size_i + _size_p + ], ) mem_arr_d[:] = ad.data[:] mem_arr_i[:] = ad.indices[:] mem_arr_p[:] = ad.indptr[:] else: - mem_arr = np.ndarray(shape=a.shape, dtype=a.dtype, buffer=buffer[_pos : _pos + _size]) + mem_arr = np.ndarray( + shape=a.shape, dtype=a.dtype, buffer=buffer[_pos : _pos + _size] + ) if isinstance(a, xr.DataArray) and isinstance(a.data, da.Array): tasks.append(da.store(a.data, mem_arr, lock=False, compute=False)) else: @@ -377,7 +383,9 @@ def emit(k, a, is_coord): if key.startswith("memmap:"): mem.flush() - create_shared_list([pickle.dumps(self._obj.attrs)] + [pickle.dumps(i) for i in wrappers], key) + create_shared_list( + [pickle.dumps(self._obj.attrs)] + [pickle.dumps(i) for i in wrappers], key + ) return type(self).from_shared_memory(key, own_data=mem, mode=mode) @property @@ -464,7 +472,12 @@ def from_shared_memory(cls, key, own_data=False, mode="r+"): mem_arr_p = np.ndarray( _size_p // _dtype_p.itemsize, dtype=_dtype_p, - buffer=buffer[position + _size_d + _size_i : position + _size_d + _size_i + _size_p], + buffer=buffer[ + position + _size_d + _size_i : position + + _size_d + + _size_i + + _size_p + ], ) mem_arr = sparse.GCXS( ( @@ -476,7 +489,9 @@ def from_shared_memory(cls, key, own_data=False, mode="r+"): compressed_axes=(0,), ) else: - mem_arr = np.ndarray(shape, dtype=dtype, buffer=buffer[position : position + nbytes]) + mem_arr = np.ndarray( + shape, dtype=dtype, buffer=buffer[position : position + nbytes] + ) content[name] = DataArray(mem_arr, **t) obj = cls._parent_class(content) diff --git a/sharrow/sparse.py b/sharrow/sparse.py index 4fa7095..d96035f 100644 --- a/sharrow/sparse.py +++ b/sharrow/sparse.py @@ -70,7 +70,9 @@ def __init__(self, i, j, data, shape=None): if isinstance(data, scipy.sparse.csr_matrix): self._sparse_data = data else: - self._sparse_data = scipy.sparse.coo_matrix((data, (i, j)), shape=shape).tocsr() + self._sparse_data = scipy.sparse.coo_matrix( + (data, (i, j)), shape=shape + ).tocsr() self._sparse_data.sort_indices() def __getitem__(self, item): @@ -147,7 +149,9 @@ def apply_mapper(x): i_ = i j_ = j - sparse_data = sparse.GCXS(sparse.COO((i_, j_), data, shape=shape), compressed_axes=(0,)) + sparse_data = sparse.GCXS( + sparse.COO((i_, j_), data, shape=shape), compressed_axes=(0,) + ) self._obj[f"_s_{name}"] = xr.DataArray( sparse_data, dims=(i_dim, j_dim), @@ -237,7 +241,9 @@ def get_blended_2(backstop_value, indices, indptr, data, i, j, blend_limit=np.in @nb.njit -def get_blended_2_arr(backstop_values_, indices, indptr, data, i_, j_, blend_limit=np.inf): +def get_blended_2_arr( + backstop_values_, indices, indptr, data, i_, j_, blend_limit=np.inf +): out = np.zeros_like(backstop_values_) for z in range(backstop_values_.size): out[z] = get_blended_2( diff --git a/sharrow/wrappers.py b/sharrow/wrappers.py index 796b723..baeb1e5 100644 --- a/sharrow/wrappers.py +++ b/sharrow/wrappers.py @@ -109,10 +109,16 @@ def set_df(self, df): ------- self (to facilitate chaining) """ - assert self.orig_key in df, f"orig_key '{self.orig_key}' not in df columns: {list(df.columns)}" - assert self.dest_key in df, f"dest_key '{self.dest_key}' not in df columns: {list(df.columns)}" + assert ( + self.orig_key in df + ), f"orig_key '{self.orig_key}' not in df columns: {list(df.columns)}" + assert ( + self.dest_key in df + ), f"dest_key '{self.dest_key}' not in df columns: {list(df.columns)}" if self.time_key: - assert self.time_key in df, f"time_key '{self.time_key}' not in df columns: 
{list(df.columns)}" + assert ( + self.time_key in df + ), f"time_key '{self.time_key}' not in df columns: {list(df.columns)}" self.df = df # TODO allow non-1 offsets From 83ba6f58e51cb47c29918694ca760a28b3c7b327 Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Wed, 10 Jan 2024 13:49:58 -0600 Subject: [PATCH 04/14] no need for exit-non-zero-on-fix --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 27bcc2e..3e23f78 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: # Run the linter. - id: ruff types_or: [ python, pyi, jupyter ] - args: [ --fix, --exit-non-zero-on-fix ] + args: [ --fix ] # Run the formatter. - id: ruff-format types_or: [ python, pyi, jupyter ] From 7f4b3d38cdad546b2ef9eeabd02bd6de8b2f583d Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Wed, 10 Jan 2024 15:18:17 -0600 Subject: [PATCH 05/14] fix license in pyproject.toml --- pyproject.toml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3fee2d9..b9c060e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,6 @@ build-backend = "setuptools.build_meta" [project] name = "sharrow" -license = "BSD-3-Clause" requires-python = ">=3.9" dynamic = ["version"] dependencies = [ @@ -21,6 +20,13 @@ dependencies = [ "dask", "networkx", ] +classifiers = [ + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] [tool.setuptools_scm] fallback_version = "1999" From 2ff15d0103d0da2d28480244ccf434fc0deac8ca Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Wed, 10 Jan 2024 16:27:36 -0600 Subject: [PATCH 06/14] fix pyproject.toml --- README.md | 12 ++++++++++++ pyproject.toml | 10 ++++++++++ 2 files changed, 22 insertions(+) diff --git a/README.md b/README.md index b2e4e8f..c060d4e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,18 @@ # sharrow numba for ActivitySim-style spec files +## Building a Wheel + +To build a wheel for sharrow, you need to have `build` installed. You can +install it with `python -m pip install build`. + +Then run the builder: + +```shell +python -m build . +``` + + ## Building the documentation Building the documentation for sharrow requires JupyterBook. diff --git a/pyproject.toml b/pyproject.toml index b9c060e..1de31ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,16 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] +description = "numba for ActivitySim-style spec files" +readme = "README.md" +keywords = ["activitysim", "discrete choice"] + +[project.urls] +Documentation = "https://activitysim.github.io/sharrow/" +Repository = "https://github.com/activitysim/sharrow" + +[tool.setuptools] +packages = ["sharrow", "sharrow.utils"] [tool.setuptools_scm] fallback_version = "1999" From c45925340636fe6c162fb40ec83e161464f32ae8 Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Mon, 15 Jan 2024 16:42:50 -0600 Subject: [PATCH 07/14] fix pyproject --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1de31ab..ea92e39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,8 +43,8 @@ fallback_version = "1999" write_to = "sharrow/_version.py" [tool.ruff] -# Enable flake8-bugbear (`B`) and pyupgrade ('UP') rules. 
-select = [ "F", # Pyflakes +select = [ + "F", # Pyflakes "E", # Pycodestyle Errors "W", # Pycodestyle Warnings "I", # isort @@ -59,7 +59,7 @@ ignore = ["B905", "D1"] target-version = "py39" [tool.ruff.lint.isort] -known-first-party = ["larch"] +known-first-party = ["sharrow"] [tool.ruff.lint.pycodestyle] max-line-length = 120 From 33ac2a2da4f7866f738e5721ee581419c1552c42 Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Mon, 15 Jan 2024 20:00:19 -0600 Subject: [PATCH 08/14] update ci --- .github/workflows/run-tests.yml | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 79a20ba..04dec84 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -23,13 +23,9 @@ jobs: run: shell: bash -l {0} steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - uses: conda-incubator/setup-miniconda@v2 + - uses: actions/checkout@v4 + - name: Install Python and Dependencies + uses: conda-incubator/setup-miniconda@v3 with: miniforge-variant: Mambaforge miniforge-version: latest From b2ac54271c3b6bf96c1f9e2f9c633d6f876b6411 Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Tue, 16 Jan 2024 09:22:08 -0600 Subject: [PATCH 09/14] pre-commit note in readme --- README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/README.md b/README.md index c060d4e..829aef2 100644 --- a/README.md +++ b/README.md @@ -20,3 +20,23 @@ Building the documentation for sharrow requires JupyterBook. ```shell jupyterbook build docs ``` + +## Developer Note + +This repository's continuous integration testing will use `ruff` to check code +quality. There is a pre-commit hook that will run `ruff` on all staged files +to ensure that they pass the quality checks. To install and use this hook, +run the following commands: + +```shell +python -m pip install pre-commit # if needed +pre-commit install +``` + +Then, when you try to make a commit, your code will be checked locally to ensure +that your code passes the quality checks. 
If you want to run the checks manually, +you can do so with the following command: + +```shell +pre-commit run --all-files +``` From 820570197c5e5f93929818ea10eb8639227fd115 Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Tue, 16 Jan 2024 09:29:10 -0600 Subject: [PATCH 10/14] check notebooks but with fewer rules --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index ea92e39..86afc1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,11 @@ ignore-init-module-imports = true line-length = 88 ignore = ["B905", "D1"] target-version = "py39" +extend-include = ["*.ipynb"] +per-file-ignores = { "*.ipynb" = [ + "E402", # allow imports to appear anywhere in Jupyter Notebooks + "E501", # allow long lines in Jupyter Notebooks +] } [tool.ruff.lint.isort] known-first-party = ["sharrow"] From e9f5f6020b62b92840e80a374a8db9ca3ea4fc8e Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Tue, 16 Jan 2024 12:19:07 -0600 Subject: [PATCH 11/14] ruff apply to ipynb --- docs/walkthrough/encoding.ipynb | 229 ++++++++++++------- docs/walkthrough/one-dim.ipynb | 375 +++++++++++++++++++++----------- docs/walkthrough/sparse.ipynb | 245 ++++++++++++--------- docs/walkthrough/two-dim.ipynb | 342 ++++++++++++++++++++--------- sharrow/translate.py | 2 +- 5 files changed, 777 insertions(+), 416 deletions(-) diff --git a/docs/walkthrough/encoding.ipynb b/docs/walkthrough/encoding.ipynb index 7f07159..b1359a0 100644 --- a/docs/walkthrough/encoding.ipynb +++ b/docs/walkthrough/encoding.ipynb @@ -23,7 +23,8 @@ "source": [ "# HIDDEN\n", "import warnings\n", - "warnings.filterwarnings(\"ignore\", category=DeprecationWarning) " + "\n", + "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)" ] }, { @@ -38,9 +39,9 @@ "import numpy as np\n", "import pandas as pd\n", "import xarray as xr\n", - "from io import StringIO\n", "\n", "import sharrow as sh\n", + "\n", "sh.__version__" ] }, @@ -57,6 +58,7 @@ "source": [ "# check versions\n", "import packaging\n", + "\n", "assert packaging.version.parse(xr.__version__) >= packaging.version.parse(\"0.20.2\")" ] }, @@ -146,7 +148,7 @@ "metadata": {}, "outputs": [], "source": [ - "from sharrow.digital_encoding import array_encode, array_decode" + "from sharrow.digital_encoding import array_decode, array_encode" ] }, { @@ -165,7 +167,7 @@ "metadata": {}, "outputs": [], "source": [ - "skims.DIST.values[:2,:3]" + "skims.DIST.values[:2, :3]" ] }, { @@ -210,7 +212,7 @@ "outputs": [], "source": [ "distance_encoded = array_encode(skims.DIST, scale=0.01, offset=0)\n", - "distance_encoded.values[:2,:3]" + "distance_encoded.values[:2, :3]" ] }, { @@ -227,10 +229,14 @@ "# TEST encoding\n", "assert distance_encoded.dtype == np.int16\n", "np.testing.assert_array_equal(\n", - " distance_encoded.values[:2,:3],\n", - " np.array([[12, 24, 44], [37, 14, 28]], dtype=np.int16)\n", + " distance_encoded.values[:2, :3],\n", + " np.array([[12, 24, 44], [37, 14, 28]], dtype=np.int16),\n", ")\n", - "assert distance_encoded.attrs['digital_encoding'] == {'scale': 0.01, 'offset': 0, 'missing_value': None}" + "assert distance_encoded.attrs[\"digital_encoding\"] == {\n", + " \"scale\": 0.01,\n", + " \"offset\": 0,\n", + " \"missing_value\": None,\n", + "}" ] }, { @@ -249,9 +255,7 @@ "metadata": {}, "outputs": [], "source": [ - "skims_encoded = skims.assign(\n", - " {'DIST': array_encode(skims.DIST, scale=0.01, offset=0)}\n", - ")" + "skims_encoded = skims.assign({\"DIST\": array_encode(skims.DIST, scale=0.01, offset=0)})" ] }, { @@ -271,7 +275,9 @@ 
"metadata": {}, "outputs": [], "source": [ - "skims_encoded = skims_encoded.digital_encoding.set(['DISTWALK', 'DISTBIKE'], scale=0.01, offset=0)" + "skims_encoded = skims_encoded.digital_encoding.set(\n", + " [\"DISTWALK\", \"DISTBIKE\"], scale=0.01, offset=0\n", + ")" ] }, { @@ -305,9 +311,9 @@ "source": [ "# TEST\n", "assert skims_encoded.digital_encoding.info() == {\n", - " 'DIST': {'scale': 0.01, 'offset': 0, 'missing_value': None},\n", - " 'DISTBIKE': {'scale': 0.01, 'offset': 0, 'missing_value': None},\n", - " 'DISTWALK': {'scale': 0.01, 'offset': 0, 'missing_value': None},\n", + " \"DIST\": {\"scale\": 0.01, \"offset\": 0, \"missing_value\": None},\n", + " \"DISTBIKE\": {\"scale\": 0.01, \"offset\": 0, \"missing_value\": None},\n", + " \"DISTWALK\": {\"scale\": 0.01, \"offset\": 0, \"missing_value\": None},\n", "}" ] }, @@ -330,16 +336,16 @@ "metadata": {}, "outputs": [], "source": [ - "pairs = pd.DataFrame({'orig': [0,0,0,1,1,1], 'dest': [0,1,2,0,1,2]})\n", + "pairs = pd.DataFrame({\"orig\": [0, 0, 0, 1, 1, 1], \"dest\": [0, 1, 2, 0, 1, 2]})\n", "tree = sh.DataTree(\n", - " base=pairs, \n", - " skims=skims.drop_dims('time_period'), \n", + " base=pairs,\n", + " skims=skims.drop_dims(\"time_period\"),\n", " relationships=(\n", " \"base.orig -> skims.otaz\",\n", " \"base.dest -> skims.dtaz\",\n", " ),\n", ")\n", - "flow = tree.setup_flow({'d1': 'DIST', 'd2': 'DIST**2'})\n", + "flow = tree.setup_flow({\"d1\": \"DIST\", \"d2\": \"DIST**2\"})\n", "arr = flow.load()\n", "arr" ] @@ -361,14 +367,14 @@ "outputs": [], "source": [ "tree_enc = sh.DataTree(\n", - " base=pairs, \n", - " skims=skims_encoded.drop_dims('time_period'), \n", + " base=pairs,\n", + " skims=skims_encoded.drop_dims(\"time_period\"),\n", " relationships=(\n", " \"base.orig -> skims.otaz\",\n", " \"base.dest -> skims.dtaz\",\n", " ),\n", ")\n", - "flow_enc = tree_enc.setup_flow({'d1': 'DIST', 'd2': 'DIST**2'})\n", + "flow_enc = tree_enc.setup_flow({\"d1\": \"DIST\", \"d2\": \"DIST**2\"})\n", "arr_enc = flow_enc.load()\n", "arr_enc" ] @@ -440,7 +446,7 @@ "metadata": {}, "outputs": [], "source": [ - "skims.WLK_LOC_WLK_FAR.values[:2,:3,:]" + "skims.WLK_LOC_WLK_FAR.values[:2, :3, :]" ] }, { @@ -460,7 +466,7 @@ "outputs": [], "source": [ "wlwfare_enc = array_encode(skims.WLK_LOC_WLK_FAR, bitwidth=8, by_dict=True)\n", - "wlwfare_enc.values[:2,:3,:]" + "wlwfare_enc.values[:2, :3, :]" ] }, { @@ -470,7 +476,7 @@ "metadata": {}, "outputs": [], "source": [ - "wlwfare_enc.attrs['digital_encoding']['dictionary']" + "wlwfare_enc.attrs[\"digital_encoding\"][\"dictionary\"]" ] }, { @@ -487,18 +493,18 @@ "# TEST encoding\n", "assert wlwfare_enc.dtype == np.uint8\n", "np.testing.assert_array_equal(\n", - " wlwfare_enc.values[:2,:3,:],\n", - " np.array([[[0, 0, 0, 0, 0],\n", - " [1, 2, 2, 1, 2],\n", - " [1, 2, 2, 1, 2]],\n", - "\n", - " [[1, 1, 2, 2, 1],\n", - " [0, 0, 0, 0, 0],\n", - " [1, 2, 2, 1, 2]]], dtype=np.uint8)\n", + " wlwfare_enc.values[:2, :3, :],\n", + " np.array(\n", + " [\n", + " [[0, 0, 0, 0, 0], [1, 2, 2, 1, 2], [1, 2, 2, 1, 2]],\n", + " [[1, 1, 2, 2, 1], [0, 0, 0, 0, 0], [1, 2, 2, 1, 2]],\n", + " ],\n", + " dtype=np.uint8,\n", + " ),\n", ")\n", "np.testing.assert_array_equal(\n", - " wlwfare_enc.attrs['digital_encoding']['dictionary'],\n", - " np.array([ 0., 152., 474., 626.], dtype=np.float32)\n", + " wlwfare_enc.attrs[\"digital_encoding\"][\"dictionary\"],\n", + " np.array([0.0, 152.0, 474.0, 626.0], dtype=np.float32),\n", ")" ] }, @@ -561,12 +567,14 @@ "outputs": [], "source": [ "skims1 = skims.digital_encoding.set(\n", - 
" ['WLK_LOC_WLK_FAR', \n", - " 'WLK_EXP_WLK_FAR', \n", - " 'WLK_HVY_WLK_FAR', \n", - " 'DRV_LOC_WLK_FAR',\n", - " 'DRV_HVY_WLK_FAR',\n", - " 'DRV_EXP_WLK_FAR'],\n", + " [\n", + " \"WLK_LOC_WLK_FAR\",\n", + " \"WLK_EXP_WLK_FAR\",\n", + " \"WLK_HVY_WLK_FAR\",\n", + " \"DRV_LOC_WLK_FAR\",\n", + " \"DRV_HVY_WLK_FAR\",\n", + " \"DRV_EXP_WLK_FAR\",\n", + " ],\n", " joint_dict=True,\n", ")" ] @@ -591,8 +599,7 @@ "outputs": [], "source": [ "skims1 = skims1.digital_encoding.set(\n", - " ['DISTBIKE', \n", - " 'DISTWALK'],\n", + " [\"DISTBIKE\", \"DISTWALK\"],\n", " joint_dict=\"jointWB\",\n", ")" ] @@ -638,9 +645,9 @@ "outputs": [], "source": [ "tree1 = sh.DataTree(\n", - " base=pairs, \n", - " skims=skims1, \n", - " rskims=skims1, \n", + " base=pairs,\n", + " skims=skims1,\n", + " rskims=skims1,\n", " relationships=(\n", " \"base.orig -> skims.otaz\",\n", " \"base.dest -> skims.dtaz\",\n", @@ -648,15 +655,18 @@ " \"base.dest -> rskims.otaz\",\n", " ),\n", ")\n", - "flow1 = tree1.setup_flow({\n", - " 'd1': 'skims[\"WLK_LOC_WLK_FAR\", \"AM\"]', \n", - " 'd2': 'skims[\"WLK_LOC_WLK_FAR\", \"AM\"]**2',\n", - " 'w1': 'skims.DISTWALK',\n", - " 'w2': 'skims.reverse(\"DISTWALK\")',\n", - " 'w3': 'rskims.DISTWALK',\n", - " 'x1': 'skims.DIST',\n", - " 'x2': 'skims.reverse(\"DIST\")',\n", - "}, hashing_level=2)\n", + "flow1 = tree1.setup_flow(\n", + " {\n", + " \"d1\": 'skims[\"WLK_LOC_WLK_FAR\", \"AM\"]',\n", + " \"d2\": 'skims[\"WLK_LOC_WLK_FAR\", \"AM\"]**2',\n", + " \"w1\": \"skims.DISTWALK\",\n", + " \"w2\": 'skims.reverse(\"DISTWALK\")',\n", + " \"w3\": \"rskims.DISTWALK\",\n", + " \"x1\": \"skims.DIST\",\n", + " \"x2\": 'skims.reverse(\"DIST\")',\n", + " },\n", + " hashing_level=2,\n", + ")\n", "arr1 = flow1.load_dataframe()\n", "arr1" ] @@ -669,13 +679,72 @@ "outputs": [], "source": [ "# TEST\n", - "assert (arr1 == np.array([[ 0.00000e+00, 0.00000e+00, 1.20000e-01, 1.20000e-01, 1.20000e-01, 1.20000e-01, 1.20000e-01],\n", - " [ 4.74000e+02, 2.24676e+05, 2.40000e-01, 3.70000e-01, 3.70000e-01, 2.40000e-01, 3.70000e-01],\n", - " [ 4.74000e+02, 2.24676e+05, 4.40000e-01, 5.70000e-01, 5.70000e-01, 4.40000e-01, 5.70000e-01],\n", - " [ 1.52000e+02, 2.31040e+04, 3.70000e-01, 2.40000e-01, 2.40000e-01, 3.70000e-01, 2.40000e-01],\n", - " [ 0.00000e+00, 0.00000e+00, 1.40000e-01, 1.40000e-01, 1.40000e-01, 1.40000e-01, 1.40000e-01],\n", - " [ 4.74000e+02, 2.24676e+05, 2.80000e-01, 2.80000e-01, 2.80000e-01, 2.80000e-01, 2.80000e-01]],\n", - " dtype=np.float32)).all().all()" + "assert (\n", + " (\n", + " arr1\n", + " == np.array(\n", + " [\n", + " [\n", + " 0.00000e00,\n", + " 0.00000e00,\n", + " 1.20000e-01,\n", + " 1.20000e-01,\n", + " 1.20000e-01,\n", + " 1.20000e-01,\n", + " 1.20000e-01,\n", + " ],\n", + " [\n", + " 4.74000e02,\n", + " 2.24676e05,\n", + " 2.40000e-01,\n", + " 3.70000e-01,\n", + " 3.70000e-01,\n", + " 2.40000e-01,\n", + " 3.70000e-01,\n", + " ],\n", + " [\n", + " 4.74000e02,\n", + " 2.24676e05,\n", + " 4.40000e-01,\n", + " 5.70000e-01,\n", + " 5.70000e-01,\n", + " 4.40000e-01,\n", + " 5.70000e-01,\n", + " ],\n", + " [\n", + " 1.52000e02,\n", + " 2.31040e04,\n", + " 3.70000e-01,\n", + " 2.40000e-01,\n", + " 2.40000e-01,\n", + " 3.70000e-01,\n", + " 2.40000e-01,\n", + " ],\n", + " [\n", + " 0.00000e00,\n", + " 0.00000e00,\n", + " 1.40000e-01,\n", + " 1.40000e-01,\n", + " 1.40000e-01,\n", + " 1.40000e-01,\n", + " 1.40000e-01,\n", + " ],\n", + " [\n", + " 4.74000e02,\n", + " 2.24676e05,\n", + " 2.80000e-01,\n", + " 2.80000e-01,\n", + " 2.80000e-01,\n", + " 2.80000e-01,\n", + " 2.80000e-01,\n", + " ],\n", 
+ " ],\n", + " dtype=np.float32,\n", + " )\n", + " )\n", + " .all()\n", + " .all()\n", + ")" ] }, { @@ -686,11 +755,13 @@ "outputs": [], "source": [ "# TEST\n", - "assert skims1.digital_encoding.baggage(['WLK_LOC_WLK_FAR']) == {'joined_0_offsets'}\n", - "assert (skims1.iat(\n", - " otaz=[0,1,2], dtaz=[0,0,0], time_period=[1,1,1],\n", - " _name='WLK_LOC_WLK_FAR'\n", - ").to_series() == [0,152,474]).all()" + "assert skims1.digital_encoding.baggage([\"WLK_LOC_WLK_FAR\"]) == {\"joined_0_offsets\"}\n", + "assert (\n", + " skims1.iat(\n", + " otaz=[0, 1, 2], dtaz=[0, 0, 0], time_period=[1, 1, 1], _name=\"WLK_LOC_WLK_FAR\"\n", + " ).to_series()\n", + " == [0, 152, 474]\n", + ").all()" ] }, { @@ -720,8 +791,10 @@ "outputs": [], "source": [ "hh = sh.example_data.get_households()\n", - "hh[\"income_grp\"] = pd.cut(hh.income, bins=[-np.inf,30000,60000,np.inf], labels=['Low', \"Mid\", \"High\"])\n", - "hh = hh[[\"income\",\"income_grp\"]]\n", + "hh[\"income_grp\"] = pd.cut(\n", + " hh.income, bins=[-np.inf, 30000, 60000, np.inf], labels=[\"Low\", \"Mid\", \"High\"]\n", + ")\n", + "hh = hh[[\"income\", \"income_grp\"]]\n", "hh.head()" ] }, @@ -754,7 +827,7 @@ }, "outputs": [], "source": [ - "hh_dataset = sh.dataset.construct(hh[[\"income\",\"income_grp\"]])\n", + "hh_dataset = sh.dataset.construct(hh[[\"income\", \"income_grp\"]])\n", "hh_dataset" ] }, @@ -793,9 +866,12 @@ "source": [ "# TESTING\n", "assert hh_dataset[\"income_grp\"].dtype == \"int8\"\n", - "assert hh_dataset[\"income_grp\"].digital_encoding.keys() == {'dictionary', 'ordered'}\n", - "assert all(hh_dataset[\"income_grp\"].digital_encoding['dictionary'] == np.array(['Low', 'Mid', 'High'], dtype='= packaging.version.parse(\"0.20.2\")" ] }, @@ -84,7 +87,7 @@ "source": [ "# TEST households content\n", "assert len(households) == 5000\n", - "assert \"income\" in households \n", + "assert \"income\" in households\n", "assert households.index.name == \"HHID\"" ] }, @@ -112,7 +115,7 @@ "source": [ "assert len(persons) == 8212\n", "assert \"household_id\" in persons\n", - "assert persons.index.name == 'PERID'" + "assert persons.index.name == \"PERID\"" ] }, { @@ -178,13 +181,17 @@ "source": [ "def random_tours(n_tours=100_000, seed=42):\n", " rng = np.random.default_rng(seed)\n", - " n_zones = skims.dims['dtaz']\n", - " return pd.DataFrame({\n", - " 'PERID': rng.choice(persons.index, size=n_tours),\n", - " 'dest_taz_idx': rng.choice(n_zones, size=n_tours),\n", - " 'out_time_period': rng.choice(skims.time_period, size=n_tours),\n", - " 'in_time_period': rng.choice(skims.time_period, size=n_tours),\n", - " }).rename_axis(\"TOURIDX\")\n", + " n_zones = skims.dims[\"dtaz\"]\n", + " return pd.DataFrame(\n", + " {\n", + " \"PERID\": rng.choice(persons.index, size=n_tours),\n", + " \"dest_taz_idx\": rng.choice(n_zones, size=n_tours),\n", + " \"out_time_period\": rng.choice(skims.time_period, size=n_tours),\n", + " \"in_time_period\": rng.choice(skims.time_period, size=n_tours),\n", + " }\n", + " ).rename_axis(\"TOURIDX\")\n", + "\n", + "\n", "tours = random_tours()\n", "tours.head()" ] @@ -269,7 +276,7 @@ "metadata": {}, "outputs": [], "source": [ - "spec = pd.read_csv(StringIO(mini_spec), index_col='Label')\n", + "spec = pd.read_csv(StringIO(mini_spec), index_col=\"Label\")\n", "spec" ] }, @@ -286,7 +293,7 @@ "source": [ "# TEST check spec\n", "assert spec.index.name == \"Label\"\n", - "assert all(spec.columns == ['Expression', 'DRIVE', 'WALK', 'TRANSIT'])" + "assert all(spec.columns == [\"Expression\", \"DRIVE\", \"WALK\", \"TRANSIT\"])" ] }, { @@ 
-309,7 +316,7 @@ "metadata": {}, "outputs": [], "source": [ - "income_breakpoints = nb.typed.Dict.empty(nb.types.int32,nb.types.int32)\n", + "income_breakpoints = nb.typed.Dict.empty(nb.types.int32, nb.types.int32)\n", "income_breakpoints[0] = 15000\n", "income_breakpoints[1] = 30000\n", "income_breakpoints[2] = 60000\n", @@ -331,12 +338,12 @@ " \"tour.in_time_period @ dot_skims.time_period\",\n", " ),\n", " extra_vars={\n", - " 'shortwait': 3,\n", - " 'one': 1,\n", + " \"shortwait\": 3,\n", + " \"one\": 1,\n", " },\n", " aux_vars={\n", - " 'short_i_wait_mult': 0.75,\n", - " 'income_breakpoints': income_breakpoints,\n", + " \"short_i_wait_mult\": 0.75,\n", + " \"income_breakpoints\": income_breakpoints,\n", " },\n", ")" ] @@ -410,9 +417,9 @@ "outputs": [], "source": [ "# TEST\n", - "from pytest import approx\n", - "assert flow.tree.aux_vars['short_i_wait_mult'] == 0.75\n", - "assert flow.tree.aux_vars['income_breakpoints'][2] == 60000" + "\n", + "assert flow.tree.aux_vars[\"short_i_wait_mult\"] == 0.75\n", + "assert flow.tree.aux_vars[\"income_breakpoints\"][2] == 60000" ] }, { @@ -439,16 +446,21 @@ "# TEST utility data\n", "assert flow.check_cache_misses(fresh=False)\n", "actual = flow.load()\n", - "expected = np.array([[ 9.4 , 16.9572 , 4.5 , 0. , 1. ],\n", - " [ 9.32 , 14.3628 , 4.5 , 1. , 1. ],\n", - " [ 7.62 , 11.0129 , 4.5 , 1. , 1. ],\n", - " [ 4.25 , 7.6692 , 2.50065 , 0. , 1. ],\n", - " [ 6.16 , 8.2186 , 3.387825, 0. , 1. ],\n", - " [ 4.86 , 4.9288 , 4.5 , 0. , 1. ],\n", - " [ 1.07 , 0. , 0. , 0. , 1. ],\n", - " [ 8.52 , 11.615499, 3.260325, 0. , 1. ],\n", - " [ 11.74 , 16.2798 , 3.440325, 0. , 1. ],\n", - " [ 10.48 , 13.3974 , 3.942825, 0. , 1. ]], dtype=np.float32)\n", + "expected = np.array(\n", + " [\n", + " [9.4, 16.9572, 4.5, 0.0, 1.0],\n", + " [9.32, 14.3628, 4.5, 1.0, 1.0],\n", + " [7.62, 11.0129, 4.5, 1.0, 1.0],\n", + " [4.25, 7.6692, 2.50065, 0.0, 1.0],\n", + " [6.16, 8.2186, 3.387825, 0.0, 1.0],\n", + " [4.86, 4.9288, 4.5, 0.0, 1.0],\n", + " [1.07, 0.0, 0.0, 0.0, 1.0],\n", + " [8.52, 11.615499, 3.260325, 0.0, 1.0],\n", + " [11.74, 16.2798, 3.440325, 0.0, 1.0],\n", + " [10.48, 13.3974, 3.942825, 0.0, 1.0],\n", + " ],\n", + " dtype=np.float32,\n", + ")\n", "\n", "np.testing.assert_array_almost_equal(actual[:5], expected[:5])\n", "np.testing.assert_array_almost_equal(actual[-5:], expected[-5:])\n", @@ -483,8 +495,11 @@ "# TEST compile flags\n", "flow.load(compile_watch=False)\n", "import pytest\n", + "\n", "with pytest.raises(AttributeError):\n", - " flow.compiled_recently # attribute does not exist if compile_watch flag is off" + " compiled_recently = (\n", + " flow.compiled_recently\n", + " ) # attribute does not exist if compile_watch flag is off" ] }, { @@ -542,8 +557,9 @@ "source": [ "# TEST\n", "from pytest import approx\n", - "assert tree_2.aux_vars['short_i_wait_mult'] == 0.75\n", - "assert tree_2.aux_vars['income_breakpoints'][2] == approx(60000)" + "\n", + "assert tree_2.aux_vars[\"short_i_wait_mult\"] == 0.75\n", + "assert tree_2.aux_vars[\"income_breakpoints\"][2] == approx(60000)" ] }, { @@ -565,18 +581,23 @@ "source": [ "# TEST that aux_vars also work with arrays\n", "tree_a = tree_2.replace_datasets(tour=tours)\n", - "tree_a.aux_vars['income_breakpoints'] = np.asarray([1,2,60000])\n", + "tree_a.aux_vars[\"income_breakpoints\"] = np.asarray([1, 2, 60000])\n", "actual = flow.load(tree_a)\n", - "expected = np.array([[ 9.4 , 16.9572 , 4.5 , 0. , 1. ],\n", - " [ 9.32 , 14.3628 , 4.5 , 1. , 1. ],\n", - " [ 7.62 , 11.0129 , 4.5 , 1. , 1. 
],\n", - " [ 4.25 , 7.6692 , 2.50065 , 0. , 1. ],\n", - " [ 6.16 , 8.2186 , 3.387825, 0. , 1. ],\n", - " [ 4.86 , 4.9288 , 4.5 , 0. , 1. ],\n", - " [ 1.07 , 0. , 0. , 0. , 1. ],\n", - " [ 8.52 , 11.615499, 3.260325, 0. , 1. ],\n", - " [ 11.74 , 16.2798 , 3.440325, 0. , 1. ],\n", - " [ 10.48 , 13.3974 , 3.942825, 0. , 1. ]], dtype=np.float32)\n", + "expected = np.array(\n", + " [\n", + " [9.4, 16.9572, 4.5, 0.0, 1.0],\n", + " [9.32, 14.3628, 4.5, 1.0, 1.0],\n", + " [7.62, 11.0129, 4.5, 1.0, 1.0],\n", + " [4.25, 7.6692, 2.50065, 0.0, 1.0],\n", + " [6.16, 8.2186, 3.387825, 0.0, 1.0],\n", + " [4.86, 4.9288, 4.5, 0.0, 1.0],\n", + " [1.07, 0.0, 0.0, 0.0, 1.0],\n", + " [8.52, 11.615499, 3.260325, 0.0, 1.0],\n", + " [11.74, 16.2798, 3.440325, 0.0, 1.0],\n", + " [10.48, 13.3974, 3.942825, 0.0, 1.0],\n", + " ],\n", + " dtype=np.float32,\n", + ")\n", "\n", "np.testing.assert_array_almost_equal(actual[:5], expected[:5])\n", "np.testing.assert_array_almost_equal(actual[-5:], expected[-5:])\n", @@ -617,15 +638,20 @@ "# TEST df\n", "assert len(df) == len(tours)\n", "pd.testing.assert_index_equal(\n", - " df.columns, \n", - " pd.Index(['Drive Time', 'Transit IVT', 'Transit Wait Time', 'Income', 'Constant']),\n", + " df.columns,\n", + " pd.Index([\"Drive Time\", \"Transit IVT\", \"Transit Wait Time\", \"Income\", \"Constant\"]),\n", ")\n", - "expected_df_head = pd.read_csv(StringIO(''',Drive Time,Transit IVT,Transit Wait Time,Income,Constant\n", + "expected_df_head = pd.read_csv(\n", + " StringIO(\n", + " \"\"\",Drive Time,Transit IVT,Transit Wait Time,Income,Constant\n", "0,9.4,16.9572,4.5,0.0,1.0\n", "1,9.32,14.3628,4.5,1.0,1.0\n", "2,7.62,11.0129,4.5,1.0,1.0\n", "3,4.25,7.6692,2.50065,0.0,1.0\n", - "4,6.16,8.2186,3.387825,0.0,1.0'''), index_col=0).astype(np.float32)\n", + "4,6.16,8.2186,3.387825,0.0,1.0\"\"\"\n", + " ),\n", + " index_col=0,\n", + ").astype(np.float32)\n", "pd.testing.assert_frame_equal(df.head(), expected_df_head)" ] }, @@ -651,7 +677,7 @@ "outputs": [], "source": [ "x = flow.load()\n", - "b = spec.iloc[:,1:].fillna(0).astype(np.float32).values\n", + "b = spec.iloc[:, 1:].fillna(0).astype(np.float32).values\n", "np.dot(x, b)" ] }, @@ -672,7 +698,17 @@ "metadata": {}, "outputs": [], "source": [ - "%time u = flow.dot(b)\n", + "%time flow.dot(b)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5776822fb0889df", + "metadata": {}, + "outputs": [], + "source": [ + "u = flow.dot(b)\n", "u" ] }, @@ -747,8 +783,7 @@ "outputs": [], "source": [ "B = xr.DataArray(\n", - " spec.iloc[:,1:].fillna(0).astype(np.float32), \n", - " dims=('expressions','modes')\n", + " spec.iloc[:, 1:].fillna(0).astype(np.float32), dims=(\"expressions\", \"modes\")\n", ")\n", "flow.dot_dataarray(B, source=tree_2)" ] @@ -788,6 +823,16 @@ "was computed for each chosen alternative. " ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "d54d71021951470b", + "metadata": {}, + "outputs": [], + "source": [ + "choices, choice_probs = flow.logit_draws(b, draws)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -809,6 +854,16 @@ "milliseconds more time than just computing the utilities." 
] }, + { + "cell_type": "code", + "execution_count": null, + "id": "eec9ebd14ff646eb", + "metadata": {}, + "outputs": [], + "source": [ + "choices2, choice_probs2 = flow.logit_draws(b, draws, source=tree_2)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -868,9 +923,9 @@ "source": [ "# TEST mnl choices\n", "uz = np.exp(flow.dot(b))\n", - "uz = uz / uz.sum(1)[:,None]\n", + "uz = uz / uz.sum(1)[:, None]\n", "np.testing.assert_array_almost_equal(\n", - " uz[range(uz.shape[0]),choices.ravel()],\n", + " uz[range(uz.shape[0]), choices.ravel()],\n", " choice_probs.ravel(),\n", ")" ] @@ -926,12 +981,12 @@ "\"\"\"\n", "\n", "import yaml\n", + "\n", "from sharrow.nested_logit import construct_nesting_tree\n", "\n", - "nesting_settings = yaml.safe_load(nesting_settings)['NESTS']\n", + "nesting_settings = yaml.safe_load(nesting_settings)[\"NESTS\"]\n", "nest_tree = construct_nesting_tree(\n", - " alternatives=spec.columns[1:],\n", - " nesting_settings=nesting_settings\n", + " alternatives=spec.columns[1:], nesting_settings=nesting_settings\n", ")" ] }, @@ -965,7 +1020,9 @@ "metadata": {}, "outputs": [], "source": [ - "nesting = nest_tree.as_arrays(trim=True, parameter_dict={'coef_nest_motor': 0.5, 'coef_nest_root': 1.0})" + "nesting = nest_tree.as_arrays(\n", + " trim=True, parameter_dict={\"coef_nest_motor\": 0.5, \"coef_nest_root\": 1.0}\n", + ")" ] }, { @@ -1023,8 +1080,11 @@ "source": [ "# TEST devolve NL to MNL\n", "choices_nl_1, choice_probs_nl_1 = flow.logit_draws(\n", - " b, draws, \n", - " nesting=nest_tree.as_arrays(trim=True, parameter_dict={'coef_nest_motor': 1.0, 'coef_nest_root': 1.0}),\n", + " b,\n", + " draws,\n", + " nesting=nest_tree.as_arrays(\n", + " trim=True, parameter_dict={\"coef_nest_motor\": 1.0, \"coef_nest_root\": 1.0}\n", + " ),\n", ")\n", "assert (choices_nl_1 == choices).all()\n", "assert choice_probs == approx(choice_probs_nl_1)" @@ -1055,23 +1115,28 @@ "metadata": {}, "outputs": [], "source": [ - "# TEST \n", - "_ch, _pr, _pc, _ls = flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=1)\n", + "# TEST\n", + "_ch, _pr, _pc, _ls = flow.logit_draws(\n", + " b, draws, source=tree_2, nesting=nesting, logsums=1\n", + ")\n", "assert _ch is None\n", "assert _pr is None\n", "assert _pc is None\n", "assert _ls.size == 100000\n", "np.testing.assert_array_almost_equal(\n", - " _ls[:5],\n", - " [ 0.532791, 0.490935, 0.557529, 0.556371, 0.54812 ]\n", + " _ls[:5], [0.532791, 0.490935, 0.557529, 0.556371, 0.54812]\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " _ls[-5:],\n", - " [ 0.452682, 0.465422, 0.554312, 0.525064, 0.515226 ]\n", + " _ls[-5:], [0.452682, 0.465422, 0.554312, 0.525064, 0.515226]\n", ")\n", "\n", "_ch, _pr, _pc, _ls = flow.logit_draws(\n", - " b, draws, source=tree_2, nesting=nesting, logsums=1, as_dataarray=True,\n", + " b,\n", + " draws,\n", + " source=tree_2,\n", + " nesting=nesting,\n", + " logsums=1,\n", + " as_dataarray=True,\n", ")\n", "assert _ch is None\n", "assert _pr is None\n", @@ -1091,7 +1156,9 @@ "# TEST masking\n", "masker = np.zeros(draws.shape, dtype=np.int8)\n", "masker[::2] = 1\n", - "_ch_m, _pr_m, _pc_m, _ls_m = flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=1, mask=masker)\n", + "_ch_m, _pr_m, _pc_m, _ls_m = flow.logit_draws(\n", + " b, draws, source=tree_2, nesting=nesting, logsums=1, mask=masker\n", + ")\n", "\n", "assert _ls_m == approx(np.where(masker, _ls, 0))\n", "assert (_ch_m, _pr_m, _pc_m) == (None, None, None)" @@ -1126,37 +1193,31 @@ "metadata": {}, "outputs": [], "source": [ - 
"# TEST \n", - "_ch, _pr, _pc, _ls = flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=2)\n", + "# TEST\n", + "_ch, _pr, _pc, _ls = flow.logit_draws(\n", + " b, draws, source=tree_2, nesting=nesting, logsums=2\n", + ")\n", "assert _ch.size == 100000\n", "assert _pr.size == 100000\n", "assert _pc is None\n", "assert _ls.size == 100000\n", + "np.testing.assert_array_almost_equal(_ch[:5], [1, 2, 1, 1, 1])\n", + "np.testing.assert_array_almost_equal(_ch[-5:], [0, 1, 0, 1, 0])\n", "np.testing.assert_array_almost_equal(\n", - " _ch[:5],\n", - " [ 1, 2, 1, 1, 1 ]\n", - ")\n", - "np.testing.assert_array_almost_equal(\n", - " _ch[-5:],\n", - " [ 0, 1, 0, 1, 0 ]\n", + " _pr[:5], [0.393454, 0.16956, 0.38384, 0.384285, 0.387469]\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " _pr[:5],\n", - " [ 0.393454, 0.16956 , 0.38384 , 0.384285, 0.387469 ]\n", + " _pr[-5:], [0.503606, 0.420874, 0.478898, 0.396506, 0.468742]\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " _pr[-5:],\n", - " [ 0.503606, 0.420874, 0.478898, 0.396506, 0.468742 ]\n", + " _ls[:5], [0.532791, 0.490935, 0.557529, 0.556371, 0.54812]\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " _ls[:5],\n", - " [ 0.532791, 0.490935, 0.557529, 0.556371, 0.54812 ]\n", + " _ls[-5:], [0.452682, 0.465422, 0.554312, 0.525064, 0.515226]\n", ")\n", - "np.testing.assert_array_almost_equal(\n", - " _ls[-5:],\n", - " [ 0.452682, 0.465422, 0.554312, 0.525064, 0.515226 ]\n", + "_ch, _pr, _pc, _ls = flow.logit_draws(\n", + " b, draws, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True\n", ")\n", - "_ch, _pr, _pc, _ls = flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True)\n", "assert _ch.size == 100000\n", "assert _ch.dims == (\"TOURIDX\",)\n", "assert _ch.shape == (100000,)\n", @@ -1177,23 +1238,33 @@ "source": [ "# TEST\n", "draws_many = np.random.default_rng(42).random(size=(tree.shape[0], 5))\n", - "_ch, _pr, _pc, _ls = flow.logit_draws(b, draws_many, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True)\n", - "assert _ch.dims == ('TOURIDX', 'DRAW')\n", + "_ch, _pr, _pc, _ls = flow.logit_draws(\n", + " b, draws_many, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True\n", + ")\n", + "assert _ch.dims == (\"TOURIDX\", \"DRAW\")\n", "assert _ch.shape == (100000, 5)\n", - "assert _pr.dims == ('TOURIDX', 'DRAW')\n", + "assert _pr.dims == (\"TOURIDX\", \"DRAW\")\n", "assert _pr.shape == (100000, 5)\n", - "assert _ls.dims == ('TOURIDX', )\n", - "assert _ls.shape == (100000, )\n", + "assert _ls.dims == (\"TOURIDX\",)\n", + "assert _ls.shape == (100000,)\n", "assert _pc is None\n", "\n", - "_ch, _pr, _pc, _ls = flow.logit_draws(b, draws_many, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True, pick_counted=True)\n", - "assert _ch.dims == ('TOURIDX', 'DRAW')\n", + "_ch, _pr, _pc, _ls = flow.logit_draws(\n", + " b,\n", + " draws_many,\n", + " source=tree_2,\n", + " nesting=nesting,\n", + " logsums=2,\n", + " as_dataarray=True,\n", + " pick_counted=True,\n", + ")\n", + "assert _ch.dims == (\"TOURIDX\", \"DRAW\")\n", "assert _ch.shape == (100000, 5)\n", - "assert _pr.dims == ('TOURIDX', 'DRAW')\n", + "assert _pr.dims == (\"TOURIDX\", \"DRAW\")\n", "assert _pr.shape == (100000, 5)\n", - "assert _ls.dims == ('TOURIDX', )\n", - "assert _ls.shape == (100000, )\n", - "assert _pc.dims == ('TOURIDX', 'DRAW')\n", + "assert _ls.dims == (\"TOURIDX\",)\n", + "assert _ls.shape == (100000,)\n", + "assert _pc.dims == (\"TOURIDX\", \"DRAW\")\n", "assert _pc.shape == (100000, 
5)" ] }, @@ -1209,7 +1280,14 @@ "masker[::3] = 1\n", "\n", "_ch_m, _pr_m, _pc_m, _ls_m = flow.logit_draws(\n", - " b, draws_many, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True, mask=masker, pick_counted=True\n", + " b,\n", + " draws_many,\n", + " source=tree_2,\n", + " nesting=nesting,\n", + " logsums=2,\n", + " as_dataarray=True,\n", + " mask=masker,\n", + " pick_counted=True,\n", ")\n", "\n", "assert (_ch_m.values == (np.where(np.expand_dims(masker, -1), _ch, -1))).all()\n", @@ -1241,8 +1319,10 @@ "metadata": {}, "outputs": [], "source": [ - "tour_by_dest = tree.subspaces['tour']\n", - "tour_by_dest = tour_by_dest.assign_coords({'CAND_DEST': xr.DataArray(np.arange(25), dims='CAND_DEST')})\n", + "tour_by_dest = tree.subspaces[\"tour\"]\n", + "tour_by_dest = tour_by_dest.assign_coords(\n", + " {\"CAND_DEST\": xr.DataArray(np.arange(25), dims=\"CAND_DEST\")}\n", + ")\n", "tour_by_dest" ] }, @@ -1278,14 +1358,14 @@ " \"tour.in_time_period @ dot_skims.time_period\",\n", " ),\n", " extra_vars={\n", - " 'shortwait': 3,\n", - " 'one': 1,\n", + " \"shortwait\": 3,\n", + " \"one\": 1,\n", " },\n", " aux_vars={\n", - " 'short_i_wait_mult': 0.75,\n", - " 'income_breakpoints': income_breakpoints,\n", + " \"short_i_wait_mult\": 0.75,\n", + " \"income_breakpoints\": income_breakpoints,\n", " },\n", - " dim_order=('TOURIDX', 'CAND_DEST')\n", + " dim_order=(\"TOURIDX\", \"CAND_DEST\"),\n", ")\n", "wide_flow = wide_tree.setup_flow(spec.Expression)" ] @@ -1297,7 +1377,7 @@ "metadata": {}, "outputs": [], "source": [ - "%time wide_logsums = wide_flow.logit_draws(b, logsums=1, compile_watch=\"simple\")[-1]" + "wide_logsums = wide_flow.logit_draws(b, logsums=1, compile_watch=\"simple\")[-1]" ] }, { @@ -1320,20 +1400,30 @@ "source": [ "# TEST\n", "np.testing.assert_array_almost_equal(\n", - " wide_logsums[:5,:5],\n", - " np.array([[ 0.759222, 0.75862 , 0.744936, 0.758251, 0.737007],\n", - " [ 0.671698, 0.671504, 0.663015, 0.661482, 0.667133],\n", - " [ 0.670188, 0.678498, 0.687647, 0.691152, 0.715783],\n", - " [ 0.760743, 0.769123, 0.763733, 0.784487, 0.802356],\n", - " [ 0.73474 , 0.743051, 0.751439, 0.754731, 0.778121]], dtype=np.float32)\n", + " wide_logsums[:5, :5],\n", + " np.array(\n", + " [\n", + " [0.759222, 0.75862, 0.744936, 0.758251, 0.737007],\n", + " [0.671698, 0.671504, 0.663015, 0.661482, 0.667133],\n", + " [0.670188, 0.678498, 0.687647, 0.691152, 0.715783],\n", + " [0.760743, 0.769123, 0.763733, 0.784487, 0.802356],\n", + " [0.73474, 0.743051, 0.751439, 0.754731, 0.778121],\n", + " ],\n", + " dtype=np.float32,\n", + " ),\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " wide_logsums[-5:,-5:],\n", - " np.array([[ 0.719523, 0.755152, 0.739368, 0.762664, 0.764388],\n", - " [ 0.740303, 0.678783, 0.649964, 0.694407, 0.681555],\n", - " [ 0.758865, 0.663663, 0.637266, 0.673351, 0.65875 ],\n", - " [ 0.765125, 0.706478, 0.676878, 0.717814, 0.713912],\n", - " [ 0.73348 , 0.683626, 0.647698, 0.69146 , 0.673006]], dtype=np.float32)\n", + " wide_logsums[-5:, -5:],\n", + " np.array(\n", + " [\n", + " [0.719523, 0.755152, 0.739368, 0.762664, 0.764388],\n", + " [0.740303, 0.678783, 0.649964, 0.694407, 0.681555],\n", + " [0.758865, 0.663663, 0.637266, 0.673351, 0.65875],\n", + " [0.765125, 0.706478, 0.676878, 0.717814, 0.713912],\n", + " [0.73348, 0.683626, 0.647698, 0.69146, 0.673006],\n", + " ],\n", + " dtype=np.float32,\n", + " ),\n", ")" ] }, @@ -1346,8 +1436,8 @@ "source": [ "# TEST\n", "np.testing.assert_array_almost_equal(\n", - " wide_logsums[np.arange(len(tours)), 
tours['dest_taz_idx'].to_numpy()],\n", - " flow.logit_draws(b, logsums=1)[-1]\n", + " wide_logsums[np.arange(len(tours)), tours[\"dest_taz_idx\"].to_numpy()],\n", + " flow.logit_draws(b, logsums=1)[-1],\n", ")" ] }, @@ -1359,7 +1449,9 @@ "outputs": [], "source": [ "# TEST\n", - "wide_logsums_ = wide_flow.logit_draws(b, logsums=1, compile_watch=True, as_dataarray=True)[-1]\n", + "wide_logsums_ = wide_flow.logit_draws(\n", + " b, logsums=1, compile_watch=True, as_dataarray=True\n", + ")[-1]\n", "assert wide_logsums_.dims == (\"TOURIDX\", \"CAND_DEST\")\n", "assert wide_logsums_.shape == (100000, 25)" ] @@ -1392,7 +1484,9 @@ "source": [ "# TEST\n", "wide_draws = np.random.default_rng(42).random(size=wide_tree.shape + (2,))\n", - "wide_logsums_plus = wide_flow.logit_draws(b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws)\n", + "wide_logsums_plus = wide_flow.logit_draws(\n", + " b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws\n", + ")\n", "assert wide_logsums_plus[0].dims == (\"TOURIDX\", \"CAND_DEST\", \"DRAW\")\n", "assert wide_logsums_plus[0].shape == (100000, 25, 2)\n", "assert wide_logsums_plus[3].dims == (\"TOURIDX\", \"CAND_DEST\")\n", @@ -1418,8 +1512,12 @@ "assert wide_logsums_mask[3].dims == (\"TOURIDX\", \"CAND_DEST\")\n", "assert wide_logsums_mask[3].shape == (100000, 25)\n", "\n", - "assert (wide_logsums_plus[0].where(np.expand_dims(mask, -1), -1) == wide_logsums_mask[0]).all()\n", - "assert (wide_logsums_plus[1].where(np.expand_dims(mask, -1), 0) == wide_logsums_mask[1]).all()\n", + "assert (\n", + " wide_logsums_plus[0].where(np.expand_dims(mask, -1), -1) == wide_logsums_mask[0]\n", + ").all()\n", + "assert (\n", + " wide_logsums_plus[1].where(np.expand_dims(mask, -1), 0) == wide_logsums_mask[1]\n", + ").all()\n", "assert (wide_logsums_plus[3].where(mask, 0) == wide_logsums_mask[3]).all()" ] }, @@ -1431,17 +1529,30 @@ "outputs": [], "source": [ "# TEST masking performance\n", - "import timeit, warnings\n", + "import timeit\n", + "import warnings\n", + "\n", "with warnings.catch_warnings():\n", " warnings.simplefilter(\"error\")\n", - " masked_time = timeit.timeit(lambda: wide_flow.logit_draws(\n", - " b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws, mask=mask\n", - " ), number=1)\n", - " raw_time = timeit.timeit(lambda: wide_flow.logit_draws(\n", - " b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws\n", - " ), number=1)\n", + " masked_time = timeit.timeit(\n", + " lambda: wide_flow.logit_draws(\n", + " b,\n", + " logsums=2,\n", + " compile_watch=True,\n", + " as_dataarray=True,\n", + " draws=wide_draws,\n", + " mask=mask,\n", + " ),\n", + " number=1,\n", + " )\n", + " raw_time = timeit.timeit(\n", + " lambda: wide_flow.logit_draws(\n", + " b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws\n", + " ),\n", + " number=1,\n", + " )\n", "assert masked_time * 2 < raw_time # generous buffer, should be nearly 7 times faster\n", - "assert len(wide_flow.cache_misses['_imnl_plus1d']) == 3" + "assert len(wide_flow.cache_misses[\"_imnl_plus1d\"]) == 3" ] } ], diff --git a/docs/walkthrough/sparse.ipynb b/docs/walkthrough/sparse.ipynb index 73565bf..dc4415d 100644 --- a/docs/walkthrough/sparse.ipynb +++ b/docs/walkthrough/sparse.ipynb @@ -17,7 +17,7 @@ "source": [ "import numpy as np\n", "import pandas as pd\n", - "import xarray as xr\n", + "\n", "import sharrow as sh" ] }, @@ -106,10 +106,10 @@ "outputs": [], "source": [ "skims.redirection.set(\n", - " maz_taz, \n", - " map_to='otaz', \n", + 
" maz_taz,\n", + " map_to=\"otaz\",\n", " name=\"omaz\",\n", - " map_also={'dtaz': \"dmaz\"}, \n", + " map_also={\"dtaz\": \"dmaz\"},\n", ")" ] }, @@ -141,9 +141,9 @@ "outputs": [], "source": [ "skims.redirection.sparse_blender(\n", - " 'DISTWALK', \n", - " maz_to_maz_walk.OMAZ, \n", - " maz_to_maz_walk.DMAZ, \n", + " \"DISTWALK\",\n", + " maz_to_maz_walk.OMAZ,\n", + " maz_to_maz_walk.DMAZ,\n", " maz_to_maz_walk.DISTWALK,\n", " max_blend_distance=1.0,\n", " index=maz_taz.index,\n", @@ -170,10 +170,12 @@ "metadata": {}, "outputs": [], "source": [ - "trips = pd.DataFrame({\n", - " 'orig_maz': [100, 100, 100, 200, 200],\n", - " 'dest_maz': [100, 101, 103, 201, 202],\n", - "})\n", + "trips = pd.DataFrame(\n", + " {\n", + " \"orig_maz\": [100, 100, 100, 200, 200],\n", + " \"dest_maz\": [100, 101, 103, 201, 202],\n", + " }\n", + ")\n", "trips" ] }, @@ -199,7 +201,7 @@ " relationships=(\n", " \"base.orig_maz @ skims.omaz\",\n", " \"base.dest_maz @ skims.dmaz\",\n", - " )\n", + " ),\n", ")" ] }, @@ -218,9 +220,12 @@ "metadata": {}, "outputs": [], "source": [ - "flow = tree.setup_flow({\n", - " 'plain_distance': 'DISTWALK',\n", - "}, boundscheck=True)" + "flow = tree.setup_flow(\n", + " {\n", + " \"plain_distance\": \"DISTWALK\",\n", + " },\n", + " boundscheck=True,\n", + ")" ] }, { @@ -252,15 +257,20 @@ "source": [ "# TEST\n", "from pytest import approx\n", + "\n", "sparse_dat = np.array([0.01, 0.2, np.nan, 3.2, np.nan])\n", - "dense_dat = np.array([0.12,0.12,0.12,0.17,0.17])\n", - "def blend(s,d, max_s):\n", + "dense_dat = np.array([0.12, 0.12, 0.12, 0.17, 0.17])\n", + "\n", + "\n", + "def blend(s, d, max_s):\n", " out = np.zeros_like(d)\n", - " ratio = s/max_s\n", - " out = d*ratio + s*(1-ratio)\n", - " out = np.where(s>max_s, d, out)\n", + " ratio = s / max_s\n", + " out = d * ratio + s * (1 - ratio)\n", + " out = np.where(s > max_s, d, out)\n", " out = np.where(np.isnan(s), d, out)\n", " return out\n", + "\n", + "\n", "assert blend(sparse_dat, dense_dat, 1.0) == approx(flow.load().ravel())" ] }, @@ -279,11 +289,13 @@ "metadata": {}, "outputs": [], "source": [ - "flow2 = tree.setup_flow({\n", - " 'plain_distance': 'DISTWALK',\n", - " 'clip_distance': 'DISTWALK.clip(upper=0.15)',\n", - " 'square_distance': 'DISTWALK**2',\n", - "})" + "flow2 = tree.setup_flow(\n", + " {\n", + " \"plain_distance\": \"DISTWALK\",\n", + " \"clip_distance\": \"DISTWALK.clip(upper=0.15)\",\n", + " \"square_distance\": \"DISTWALK**2\",\n", + " }\n", + ")" ] }, { @@ -304,12 +316,17 @@ "outputs": [], "source": [ "# TEST\n", - "assert flow2.load_dataframe().values == approx(np.array([\n", - " [ 1.1100e-02, 1.1100e-02, 1.2321e-04],\n", - " [ 1.8400e-01, 1.5000e-01, 3.3856e-02],\n", - " [ 1.2000e-01, 1.2000e-01, 1.4400e-02],\n", - " [ 1.7000e-01, 1.5000e-01, 2.8900e-02],\n", - " [ 1.7000e-01, 1.5000e-01, 2.8900e-02]], dtype=np.float32)\n", + "assert flow2.load_dataframe().values == approx(\n", + " np.array(\n", + " [\n", + " [1.1100e-02, 1.1100e-02, 1.2321e-04],\n", + " [1.8400e-01, 1.5000e-01, 3.3856e-02],\n", + " [1.2000e-01, 1.2000e-01, 1.4400e-02],\n", + " [1.7000e-01, 1.5000e-01, 2.8900e-02],\n", + " [1.7000e-01, 1.5000e-01, 2.8900e-02],\n", + " ],\n", + " dtype=np.float32,\n", + " )\n", ")" ] }, @@ -340,7 +357,7 @@ "skims.at(\n", " omaz=trips.orig_maz,\n", " dmaz=trips.dest_maz,\n", - " _names=['DIST', 'DISTWALK'],\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", ")" ] }, @@ -355,24 +372,26 @@ "out = skims.at(\n", " omaz=trips.orig_maz,\n", " dmaz=trips.dest_maz,\n", - " _names=['DIST', 'DISTWALK'], _load=True,\n", + " 
_names=[\"DIST\", \"DISTWALK\"],\n", + " _load=True,\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " out['DIST'].to_numpy(), \n", - " np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32)\n", + " out[\"DIST\"].to_numpy(), np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32)\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " out['DISTWALK'].to_numpy(), \n", - " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32)\n", + " out[\"DISTWALK\"].to_numpy(),\n", + " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32),\n", ")\n", "\n", "from pytest import raises\n", + "\n", "with raises(NotImplementedError):\n", " skims.at(\n", " omaz=trips.orig_maz,\n", " dmaz=trips.dest_maz,\n", - " time_period=['AM', 'AM', 'AM', 'AM', 'AM'],\n", - " _names=['DIST', 'DISTWALK', 'SOV_TIME'], _load=True,\n", + " time_period=[\"AM\", \"AM\", \"AM\", \"AM\", \"AM\"],\n", + " _names=[\"DIST\", \"DISTWALK\", \"SOV_TIME\"],\n", + " _load=True,\n", " )" ] }, @@ -384,9 +403,9 @@ "outputs": [], "source": [ "skims.iat(\n", - " omaz=[ 0, 0, 0, 100, 100],\n", - " dmaz=[ 0, 1, 3, 101, 102],\n", - " _names=['DIST', 'DISTWALK'],\n", + " omaz=[0, 0, 0, 100, 100],\n", + " dmaz=[0, 1, 3, 101, 102],\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", ")" ] }, @@ -399,18 +418,18 @@ "source": [ "# TEST\n", "out = skims.iat(\n", - " omaz=[ 0, 0, 0, 100, 100],\n", - " dmaz=[ 0, 1, 3, 101, 102],\n", - " _names=['DIST', 'DISTWALK'], _load=True,\n", + " omaz=[0, 0, 0, 100, 100],\n", + " dmaz=[0, 1, 3, 101, 102],\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", + " _load=True,\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " out['DIST'].to_numpy(), \n", - " np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32)\n", + " out[\"DIST\"].to_numpy(), np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32)\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " out['DISTWALK'].to_numpy(), \n", - " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32)\n", - ")\n" + " out[\"DISTWALK\"].to_numpy(),\n", + " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32),\n", + ")" ] }, { @@ -430,9 +449,10 @@ "outputs": [], "source": [ "skims.at(\n", - " otaz=[1,1,1,16,16],\n", - " dtaz=[1,1,1,16,16],\n", - " _names=['DIST', 'DISTWALK'], _load=True,\n", + " otaz=[1, 1, 1, 16, 16],\n", + " dtaz=[1, 1, 1, 16, 16],\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", + " _load=True,\n", ")" ] }, @@ -444,9 +464,9 @@ "outputs": [], "source": [ "skims.at(\n", - " otaz=[1,1,1,16,16],\n", - " dtaz=[1,1,1,16,16],\n", - " _name='DISTWALK',\n", + " otaz=[1, 1, 1, 16, 16],\n", + " dtaz=[1, 1, 1, 16, 16],\n", + " _name=\"DISTWALK\",\n", ")" ] }, @@ -458,44 +478,47 @@ "outputs": [], "source": [ "# TEST\n", - "import sys\n", - "if sys.version_info > (3,8):\n", - " import secrets\n", - " token = \"skims-with-sparse\" + secrets.token_hex(5)\n", - " readback0 = skims.shm.to_shared_memory(token)\n", - " assert readback0.attrs == skims.attrs\n", - " readback = sh.Dataset.shm.from_shared_memory(token)\n", - " assert readback.attrs == skims.attrs\n", - " \n", - " out = readback.iat(\n", - " omaz=[ 0, 0, 0, 100, 100],\n", - " dmaz=[ 0, 1, 3, 101, 102],\n", - " _names=['DIST', 'DISTWALK'], _load=True,\n", - " )\n", - " np.testing.assert_array_almost_equal(\n", - " out['DIST'].to_numpy(), \n", - " np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32)\n", - " )\n", - " np.testing.assert_array_almost_equal(\n", - " out['DISTWALK'].to_numpy(), \n", - " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32)\n", - " 
)\n", + "import secrets\n", "\n", - " out = readback.at(\n", - " omaz=trips.orig_maz,\n", - " dmaz=trips.dest_maz,\n", - " _names=['DIST', 'DISTWALK'], _load=True,\n", - " )\n", - " np.testing.assert_array_almost_equal(\n", - " out['DIST'].to_numpy(), \n", - " np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32)\n", - " )\n", - " np.testing.assert_array_almost_equal(\n", - " out['DISTWALK'].to_numpy(), \n", - " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32)\n", - " )\n", - " \n", - " assert readback.redirection.blenders == {'DISTWALK': {'max_blend_distance': 1.0, 'blend_distance_name': None}}\n" + "token = \"skims-with-sparse\" + secrets.token_hex(5)\n", + "readback0 = skims.shm.to_shared_memory(token)\n", + "assert readback0.attrs == skims.attrs\n", + "readback = sh.Dataset.shm.from_shared_memory(token)\n", + "assert readback.attrs == skims.attrs\n", + "\n", + "out = readback.iat(\n", + " omaz=[0, 0, 0, 100, 100],\n", + " dmaz=[0, 1, 3, 101, 102],\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", + " _load=True,\n", + ")\n", + "np.testing.assert_array_almost_equal(\n", + " out[\"DIST\"].to_numpy(),\n", + " np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32),\n", + ")\n", + "np.testing.assert_array_almost_equal(\n", + " out[\"DISTWALK\"].to_numpy(),\n", + " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32),\n", + ")\n", + "\n", + "out = readback.at(\n", + " omaz=trips.orig_maz,\n", + " dmaz=trips.dest_maz,\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", + " _load=True,\n", + ")\n", + "np.testing.assert_array_almost_equal(\n", + " out[\"DIST\"].to_numpy(),\n", + " np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32),\n", + ")\n", + "np.testing.assert_array_almost_equal(\n", + " out[\"DISTWALK\"].to_numpy(),\n", + " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32),\n", + ")\n", + "\n", + "assert readback.redirection.blenders == {\n", + " \"DISTWALK\": {\"max_blend_distance\": 1.0, \"blend_distance_name\": None}\n", + "}" ] }, { @@ -506,7 +529,9 @@ "outputs": [], "source": [ "# TEST\n", - "assert skims.redirection.blenders == {'DISTWALK': {'max_blend_distance': 1.0, 'blend_distance_name': None}}" + "assert skims.redirection.blenders == {\n", + " \"DISTWALK\": {\"max_blend_distance\": 1.0, \"blend_distance_name\": None}\n", + "}" ] }, { @@ -518,24 +543,28 @@ "source": [ "# TEST\n", "# reverse skims in sparse\n", - "flow3 = tree.setup_flow({\n", - " 'plain_distance': 'DISTWALK',\n", - " 'reverse_distance': 'skims.reverse(\"DISTWALK\")',\n", - "})\n", + "flow3 = tree.setup_flow(\n", + " {\n", + " \"plain_distance\": \"DISTWALK\",\n", + " \"reverse_distance\": 'skims.reverse(\"DISTWALK\")',\n", + " }\n", + ")\n", "\n", - "assert flow3.load() == approx(np.array([[ 0.0111, 0.0111],\n", - " [ 0.184 , 0.12 ],\n", - " [ 0.12 , 0.12 ],\n", - " [ 0.17 , 0.17 ],\n", - " [ 0.17 , 0.17 ]], dtype=np.float32))\n", + "assert flow3.load() == approx(\n", + " np.array(\n", + " [[0.0111, 0.0111], [0.184, 0.12], [0.12, 0.12], [0.17, 0.17], [0.17, 0.17]],\n", + " dtype=np.float32,\n", + " )\n", + ")\n", "\n", "z = skims.iat(\n", - " omaz=[ 0, 1, 3, 101, 102],\n", - " dmaz=[ 0, 0, 0, 100, 100],\n", - " _names=['DIST', 'DISTWALK'], _load=True,\n", + " omaz=[0, 1, 3, 101, 102],\n", + " dmaz=[0, 0, 0, 100, 100],\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", + " _load=True,\n", ")\n", - "assert z['DISTWALK'].data == approx(np.array([ 0.0111, 0.12 , 0.12 , 0.17 , 0.17 ]))\n", - "assert z['DIST'].data == approx(np.array([ 0.12, 0.12 , 0.12 , 0.17 , 0.17 ]))" + "assert 
z[\"DISTWALK\"].data == approx(np.array([0.0111, 0.12, 0.12, 0.17, 0.17]))\n", + "assert z[\"DIST\"].data == approx(np.array([0.12, 0.12, 0.12, 0.17, 0.17]))" ] } ], diff --git a/docs/walkthrough/two-dim.ipynb b/docs/walkthrough/two-dim.ipynb index b7fb5b6..4b68245 100644 --- a/docs/walkthrough/two-dim.ipynb +++ b/docs/walkthrough/two-dim.ipynb @@ -19,10 +19,10 @@ "outputs": [], "source": [ "import numpy as np\n", - "import pandas as pd\n", "import xarray as xr\n", "\n", "import sharrow as sh\n", + "\n", "sh.__version__" ] }, @@ -39,6 +39,7 @@ "source": [ "# TEST check versions\n", "import packaging\n", + "\n", "assert packaging.version.parse(xr.__version__) >= packaging.version.parse(\"0.20.2\")" ] }, @@ -83,7 +84,7 @@ "source": [ "# test households content\n", "assert len(households) == 5000\n", - "assert \"income\" in households \n", + "assert \"income\" in households\n", "assert households.index.name == \"HHID\"" ] }, @@ -111,7 +112,7 @@ "source": [ "assert len(persons) == 8212\n", "assert \"household_id\" in persons\n", - "assert persons.index.name == 'PERID'" + "assert persons.index.name == \"PERID\"" ] }, { @@ -180,7 +181,7 @@ "metadata": {}, "outputs": [], "source": [ - "workers = persons.query(\"pemploy in [1,2]\").rename_axis(index='WORKERID')\n", + "workers = persons.query(\"pemploy in [1,2]\").rename_axis(index=\"WORKERID\")\n", "workers" ] }, @@ -215,8 +216,8 @@ "metadata": {}, "outputs": [], "source": [ - "skims_am = skims.sel(time_period='AM')\n", - "skims_pm = skims.sel(time_period='PM')" + "skims_am = skims.sel(time_period=\"AM\")\n", + "skims_pm = skims.sel(time_period=\"PM\")" ] }, { @@ -246,7 +247,7 @@ "outputs": [], "source": [ "base = sh.dataset.from_named_objects(\n", - " workers.index, \n", + " workers.index,\n", " landuse.index,\n", ")" ] @@ -279,7 +280,7 @@ "metadata": {}, "outputs": [], "source": [ - "tree = sh.DataTree(base=base, dim_order=('WORKERID', 'TAZ'))" + "tree = sh.DataTree(base=base, dim_order=(\"WORKERID\", \"TAZ\"))" ] }, { @@ -294,7 +295,7 @@ "outputs": [], "source": [ "# TEST tree_dest attributes\n", - "assert tree.dim_order == ('WORKERID', 'TAZ')\n", + "assert tree.dim_order == (\"WORKERID\", \"TAZ\")\n", "assert tree.shape == (4361, 25)" ] }, @@ -317,7 +318,7 @@ "metadata": {}, "outputs": [], "source": [ - "tree.add_dataset('person', persons, \"base.WORKERID @ person.PERID\")" + "tree.add_dataset(\"person\", persons, \"base.WORKERID @ person.PERID\")" ] }, { @@ -337,8 +338,8 @@ "metadata": {}, "outputs": [], "source": [ - "tree.add_dataset('landuse', landuse, \"base.TAZ @ landuse.TAZ\")\n", - "tree.add_dataset('hh', households, \"person.household_id @ hh.HHID\")" + "tree.add_dataset(\"landuse\", landuse, \"base.TAZ @ landuse.TAZ\")\n", + "tree.add_dataset(\"hh\", households, \"person.household_id @ hh.HHID\")" ] }, { @@ -360,17 +361,17 @@ "outputs": [], "source": [ "tree.add_dataset(\n", - " 'odskims', \n", - " skims_am, \n", + " \"odskims\",\n", + " skims_am,\n", " relationships=(\n", - " \"hh.TAZ @ odskims.otaz\", \n", + " \"hh.TAZ @ odskims.otaz\",\n", " \"base.TAZ @ odskims.dtaz\",\n", " ),\n", ")\n", "\n", "tree.add_dataset(\n", - " 'doskims', \n", - " skims_pm, \n", + " \"doskims\",\n", + " skims_pm,\n", " relationships=(\n", " \"base.TAZ @ doskims.otaz\",\n", " \"hh.TAZ @ doskims.dtaz\",\n", @@ -399,10 +400,10 @@ "outputs": [], "source": [ "definition = {\n", - " 'round_trip_dist': 'odskims.DIST + doskims.DIST',\n", - " 'round_trip_dist_first_mile': 'clip(odskims.DIST, 0, 1) + clip(doskims.DIST, 0, 1)',\n", - " 
'round_trip_dist_addl_miles': 'clip(odskims.DIST-1, 0, None) + clip(doskims.DIST-1, 0, None)',\n", - " 'size_term': 'log(TOTPOP + 0.5*EMPRES)',\n", + " \"round_trip_dist\": \"odskims.DIST + doskims.DIST\",\n", + " \"round_trip_dist_first_mile\": \"clip(odskims.DIST, 0, 1) + clip(doskims.DIST, 0, 1)\",\n", + " \"round_trip_dist_addl_miles\": \"clip(odskims.DIST-1, 0, None) + clip(doskims.DIST-1, 0, None)\",\n", + " \"size_term\": \"log(TOTPOP + 0.5*EMPRES)\",\n", "}\n", "\n", "flow = tree.setup_flow(definition)" @@ -440,37 +441,46 @@ "source": [ "# TEST\n", "assert arr.shape == (4361, 25, 4)\n", - "expected = np.array([\n", - " [[ 0.61 , 0.61 , 0. , 4.610157],\n", - " [ 0.28 , 0.28 , 0. , 5.681878],\n", - " [ 0.56 , 0.56 , 0. , 6.368187],\n", - " [ 0.53 , 0.53 , 0. , 5.741399],\n", - " [ 1.23 , 1.23 , 0. , 7.17549 ]],\n", - "\n", - " [[ 1.19 , 1.19 , 0. , 4.610157],\n", - " [ 1.49 , 1.49 , 0. , 5.681878],\n", - " [ 1.88 , 1.85 , 0.03 , 6.368187],\n", - " [ 1.36 , 1.36 , 0. , 5.741399],\n", - " [ 1.93 , 1.93 , 0. , 7.17549 ]],\n", - "\n", - " [[ 1.19 , 1.19 , 0. , 4.610157],\n", - " [ 1.49 , 1.49 , 0. , 5.681878],\n", - " [ 1.88 , 1.85 , 0.03 , 6.368187],\n", - " [ 1.36 , 1.36 , 0. , 5.741399],\n", - " [ 1.93 , 1.93 , 0. , 7.17549 ]],\n", - "\n", - " [[ 0.24 , 0.24 , 0. , 4.610157],\n", - " [ 0.61 , 0.61 , 0. , 5.681878],\n", - " [ 1.01 , 1.01 , 0. , 6.368187],\n", - " [ 0.75 , 0.75 , 0. , 5.741399],\n", - " [ 1.38 , 1.38 , 0. , 7.17549 ]],\n", - "\n", - " [[ 0.61 , 0.61 , 0. , 4.610157],\n", - " [ 0.28 , 0.28 , 0. , 5.681878],\n", - " [ 0.56 , 0.56 , 0. , 6.368187],\n", - " [ 0.53 , 0.53 , 0. , 5.741399],\n", - " [ 1.23 , 1.23 , 0. , 7.17549 ]],\n", - "], dtype=np.float32)\n", + "expected = np.array(\n", + " [\n", + " [\n", + " [0.61, 0.61, 0.0, 4.610157],\n", + " [0.28, 0.28, 0.0, 5.681878],\n", + " [0.56, 0.56, 0.0, 6.368187],\n", + " [0.53, 0.53, 0.0, 5.741399],\n", + " [1.23, 1.23, 0.0, 7.17549],\n", + " ],\n", + " [\n", + " [1.19, 1.19, 0.0, 4.610157],\n", + " [1.49, 1.49, 0.0, 5.681878],\n", + " [1.88, 1.85, 0.03, 6.368187],\n", + " [1.36, 1.36, 0.0, 5.741399],\n", + " [1.93, 1.93, 0.0, 7.17549],\n", + " ],\n", + " [\n", + " [1.19, 1.19, 0.0, 4.610157],\n", + " [1.49, 1.49, 0.0, 5.681878],\n", + " [1.88, 1.85, 0.03, 6.368187],\n", + " [1.36, 1.36, 0.0, 5.741399],\n", + " [1.93, 1.93, 0.0, 7.17549],\n", + " ],\n", + " [\n", + " [0.24, 0.24, 0.0, 4.610157],\n", + " [0.61, 0.61, 0.0, 5.681878],\n", + " [1.01, 1.01, 0.0, 6.368187],\n", + " [0.75, 0.75, 0.0, 5.741399],\n", + " [1.38, 1.38, 0.0, 7.17549],\n", + " ],\n", + " [\n", + " [0.61, 0.61, 0.0, 4.610157],\n", + " [0.28, 0.28, 0.0, 5.681878],\n", + " [0.56, 0.56, 0.0, 6.368187],\n", + " [0.53, 0.53, 0.0, 5.741399],\n", + " [1.23, 1.23, 0.0, 7.17549],\n", + " ],\n", + " ],\n", + " dtype=np.float32,\n", + ")\n", "\n", "np.testing.assert_array_almost_equal(arr[:5, :5, :], expected)" ] @@ -529,10 +539,20 @@ "source": [ "# TEST\n", "assert isinstance(arr_pretty, xr.DataArray)\n", - "assert arr_pretty.dims == ('WORKERID', 'TAZ', 'expressions')\n", + "assert arr_pretty.dims == (\"WORKERID\", \"TAZ\", \"expressions\")\n", "assert arr_pretty.shape == (4361, 25, 4)\n", - "assert all(arr_pretty.expressions == np.array(['round_trip_dist', 'round_trip_dist_first_mile',\n", - " 'round_trip_dist_addl_miles', 'size_term'], dtype=' Date: Tue, 16 Jan 2024 12:19:21 -0600 Subject: [PATCH 12/14] ruff tests in CI --- .github/workflows/run-tests.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/run-tests.yml 
b/.github/workflows/run-tests.yml index 04dec84..ae23237 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -18,7 +18,7 @@ jobs: strategy: matrix: os: ["ubuntu-latest", "macos-latest", "windows-latest"] - python-version: ["3.9", "3.10"] + python-version: ["3.9", "3.10", "3.11"] defaults: run: shell: bash -l {0} @@ -44,10 +44,11 @@ jobs: conda list - name: Lint with Ruff run: | + # code quality check # stop the build if there are Python syntax errors or undefined names - ruff check . --select=E9,F63,F7,F82 --statistics - # exit-zero treats all errors as warnings. - ruff check . --exit-zero --statistics + ruff check . --select=E9,F63,F7,F82 --no-fix + # stop the build for any other configured Ruff linting errors + ruff check . --show-fixes --exit-non-zero-on-fix - name: Test with pytest run: | python -m pytest From 4f7fb9deef40a93632310e4559a7050a6ac7e7d3 Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Tue, 16 Jan 2024 12:35:20 -0600 Subject: [PATCH 13/14] pre-commit service --- .github/workflows/run-tests.yml | 23 ++++++++++++++++++++++- README.md | 6 ++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index ae23237..3b9775e 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -12,13 +12,34 @@ on: workflow_dispatch: jobs: + + fmt: + name: formatting quality + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Install Ruff + run: | + python -m pip install ruff + - name: Lint with Ruff + run: | + # code quality check, stop the build for any errors + ruff check . --show-fixes --exit-non-zero-on-fix + test: + needs: fmt name: ${{ matrix.os }} py${{ matrix.python-version }} runs-on: ${{ matrix.os }} strategy: matrix: os: ["ubuntu-latest", "macos-latest", "windows-latest"] - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10"] defaults: run: shell: bash -l {0} diff --git a/README.md b/README.md index 829aef2..d34b1f8 100644 --- a/README.md +++ b/README.md @@ -40,3 +40,9 @@ you can do so with the following command: ```shell pre-commit run --all-files ``` + +If you don't use `pre-commit`, a service will run the checks for you when you +open a pull request, and make fixes to your code when possible. + + + From 690b6a5b6b9bd70b6fc765441345e78d9a63c66a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jan 2024 18:35:45 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/README.md b/README.md index d34b1f8..c72cb6c 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,3 @@ pre-commit run --all-files If you don't use `pre-commit`, a service will run the checks for you when you open a pull request, and make fixes to your code when possible. - - -
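
Taken together, the packaging and linting patches above imply a simple local developer workflow. The sketch below only assembles commands that the patches themselves introduce (the `pre-commit`, `ruff check`, and `python -m build` invocations); the one assumption added here is the `dist/` path, which is the default output directory of `python -m build` rather than something spelled out in the patches.

```shell
# Sketch of the local workflow implied by the patches above.
# Assumption: dist/ is python -m build's default output directory (not stated in the patches).

# One-time setup: install the pre-commit hook so ruff checks staged files on commit.
python -m pip install pre-commit
pre-commit install

# Run the same checks the CI "fmt" job runs, across the whole repository.
pre-commit run --all-files
python -m pip install ruff
ruff check . --show-fixes --exit-non-zero-on-fix

# Build the package from pyproject.toml alone (setup.py and setup.cfg were removed).
python -m pip install build
python -m build .   # wheel and sdist are written to ./dist by default
```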