Skip to content

Commit

Permalink
More detailed typing on file openers (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt authored Nov 15, 2024
1 parent 97ec1ec commit 39dda18
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 18 deletions.
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ ignore =
W503
S410
S320
# overload operator causes this
E704
exclude =
.tox,
.git,
Expand Down
65 changes: 57 additions & 8 deletions src/pystow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,19 @@

import sqlite3
from contextlib import contextmanager
from io import BytesIO, StringIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generator, Mapping, Optional, Sequence, Union
from typing import (
TYPE_CHECKING,
Any,
Generator,
Literal,
Mapping,
Optional,
Sequence,
Union,
overload,
)

from .constants import JSON, BytesOpener, Opener, Provider
from .impl import Module
Expand Down Expand Up @@ -109,15 +120,40 @@ def join(key: str, *subkeys: str, name: Optional[str] = None, ensure_exists: boo
return _module.join(*subkeys, name=name, ensure_exists=ensure_exists)


# docstr-coverage:excused `overload`
@overload
@contextmanager
def open(
key: str,
*subkeys: str,
name: str,
mode: str = "r",
mode: Literal["r", "rt", "w", "wt"] = "r",
open_kwargs: Optional[Mapping[str, Any]] = None,
) -> Opener:
"""Open a file that exists already.
) -> Generator[StringIO, None, None]: ...


# docstr-coverage:excused `overload`
@overload
@contextmanager
def open(
key: str,
*subkeys: str,
name: str,
mode: Literal["rb", "wb"],
open_kwargs: Optional[Mapping[str, Any]] = None,
) -> Generator[BytesIO, None, None]: ...


@contextmanager
def open(
key: str,
*subkeys: str,
name: str,
mode: Literal["r", "rb", "rt", "w", "wb", "wt"] = "r",
open_kwargs: Optional[Mapping[str, Any]] = None,
ensure_exists: bool = False,
) -> Generator[Union[StringIO, BytesIO], None, None]:
"""Open a file.
:param key:
The name of the module. No funny characters. The envvar
Expand All @@ -127,13 +163,26 @@ def open(
A sequence of additional strings to join. If none are given,
returns the directory for this module.
:param name: The name of the file to open
:param mode: The read mode, passed to :func:`open`
:param mode: The read or write mode, passed to :func:`open`
:param open_kwargs: Additional keyword arguments passed to :func:`open`
:param ensure_exists: Should the directory the file is in be made? Set to true on write operations.
:yields: An open file object
This function should be called inside a context manager like in the following
.. code-block:: python
import pystow
with pystow.open("test", name="test.tsv", mode="w") as file:
print("Test text!", file=file)
"""
_module = Module.from_key(key, ensure_exists=True)
with _module.open(*subkeys, name=name, mode=mode, open_kwargs=open_kwargs) as file:
with _module.open(
*subkeys, name=name, mode=mode, open_kwargs=open_kwargs, ensure_exists=ensure_exists
) as file:
yield file


Expand Down Expand Up @@ -924,7 +973,7 @@ def load_pickle(
key: str,
*subkeys: str,
name: str,
mode: str = "rb",
mode: Literal["rb"] = "rb",
open_kwargs: Optional[Mapping[str, Any]] = None,
pickle_load_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
Expand Down Expand Up @@ -955,7 +1004,7 @@ def dump_pickle(
*subkeys: str,
name: str,
obj: Any,
mode: str = "wb",
mode: Literal["wb"] = "wb",
open_kwargs: Optional[Mapping[str, Any]] = None,
pickle_dump_kwargs: Optional[Mapping[str, Any]] = None,
) -> None:
Expand Down
69 changes: 59 additions & 10 deletions src/pystow/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,19 @@
import tarfile
import zipfile
from contextlib import closing, contextmanager
from io import BytesIO, StringIO
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
Literal,
Mapping,
Optional,
Sequence,
Union,
overload,
)

from . import utils
Expand Down Expand Up @@ -356,27 +359,71 @@ def ensure_open(
with path.open(**open_kwargs) as file:
yield file

# docstr-coverage:excused `overload`
@overload
@contextmanager
def open(
self,
*subkeys: str,
name: str,
mode: str = "r",
mode: Literal["r", "rt", "w", "wt"] = ...,
open_kwargs: Optional[Mapping[str, Any]] = None,
ensure_exists: bool,
) -> Generator[StringIO, None, None]: ...

# docstr-coverage:excused `overload`
@overload
@contextmanager
def open(
self,
*subkeys: str,
name: str,
mode: Literal["rb", "wb"] = ...,
open_kwargs: Optional[Mapping[str, Any]] = None,
ensure_exists: bool,
) -> Generator[BytesIO, None, None]: ...

@contextmanager
def open(
self,
*subkeys: str,
name: str,
mode: Union[Literal["r", "rt", "w", "wt"], Literal["rb", "wb"]] = "r",
open_kwargs: Optional[Mapping[str, Any]] = None,
ensure_exists: bool = False,
) -> Opener:
"""Open a file that exists already.
) -> Generator[Union[StringIO, BytesIO], None, None]:
"""Open a file.
:param subkeys:
A sequence of additional strings to join. If none are given,
returns the directory for this module.
:param name: The name of the file to open
:param mode: The read mode, passed to :func:`open`
:param open_kwargs: Additional keyword arguments passed to :func:`open`
:param ensure_exists: Should the file be made? Set to true on write operations.
:param ensure_exists: Should the directory the file is in be made? Set to true on write operations.
:raises ValueError: In the following situations:
1. If the file should be opened in write mode, and it is not ensured to exist
2. If the file should be opened in read mode, and it is ensured to exist. This is bad because
it will create a file when there previously wasn't one
:yields: An open file object.
This function should be called inside a context manager like in the following
.. code-block:: python
import pystow
with pystow.module("test").open(name="test.tsv", mode="w") as file:
print("Test text!", file=file)
:yields: An open file object
"""
if "w" in mode and not ensure_exists:
raise ValueError
if "r" in mode and ensure_exists:
raise ValueError

path = self.join(*subkeys, name=name, ensure_exists=ensure_exists)
open_kwargs = {} if open_kwargs is None else dict(open_kwargs)
open_kwargs.setdefault("mode", mode)
Expand Down Expand Up @@ -661,7 +708,7 @@ def load_df(
"""
import pandas as pd

with self.open(*subkeys, name=name) as file:
with self.open(*subkeys, name=name, mode="r", ensure_exists=False) as file:
return pd.read_csv(file, **_clean_csv_kwargs(read_csv_kwargs))

def dump_df(
Expand Down Expand Up @@ -787,7 +834,7 @@ def load_json(
:returns: A JSON object (list, dict, etc.)
"""
with self.open(
*subkeys, name=name, mode="r", open_kwargs=open_kwargs, ensure_exists=True
*subkeys, name=name, mode="r", open_kwargs=open_kwargs, ensure_exists=False
) as file:
return json.load(file, **(json_load_kwargs or {}))

Expand Down Expand Up @@ -859,7 +906,7 @@ def load_pickle(
self,
*subkeys: str,
name: str,
mode: str = "rb",
mode: Literal["rb"] = "rb",
open_kwargs: Optional[Mapping[str, Any]] = None,
pickle_load_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
Expand All @@ -879,6 +926,7 @@ def load_pickle(
name=name,
mode=mode,
open_kwargs=open_kwargs,
ensure_exists=False,
) as file:
return pickle.load(file, **(pickle_load_kwargs or {}))

Expand All @@ -887,7 +935,7 @@ def dump_pickle(
*subkeys: str,
name: str,
obj: Any,
mode: str = "wb",
mode: Literal["wb"] = "wb",
open_kwargs: Optional[Mapping[str, Any]] = None,
pickle_dump_kwargs: Optional[Mapping[str, Any]] = None,
) -> None:
Expand All @@ -907,6 +955,7 @@ def dump_pickle(
name=name,
mode=mode,
open_kwargs=open_kwargs,
ensure_exists=True,
) as file:
pickle.dump(obj, file, **(pickle_dump_kwargs or {}))

Expand Down Expand Up @@ -1103,7 +1152,7 @@ def load_xml(
"""
from lxml import etree

with self.open(*subkeys, name=name, ensure_exists=False) as file:
with self.open(*subkeys, mode="r", name=name, ensure_exists=False) as file:
return etree.parse(file, **(parse_kwargs or {}))

def dump_xml(
Expand Down

0 comments on commit 39dda18

Please sign in to comment.