Backend interface, now it uses subclassing #4836

Merged: 9 commits, Jan 28, 2021

xarray/backends/cfgrib_.py: 51 additions, 55 deletions

@@ -12,7 +12,7 @@
     BackendEntrypoint,
 )
 from .locks import SerializableLock, ensure_lock
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
 
 try:
     import cfgrib
@@ -86,62 +86,58 @@ def get_encoding(self):
         return encoding
 
 
-def guess_can_open_cfgrib(store_spec):
-    try:
-        _, ext = os.path.splitext(store_spec)
-    except TypeError:
-        return False
-    return ext in {".grib", ".grib2", ".grb", ".grb2"}
-
-
-def open_backend_dataset_cfgrib(
-    filename_or_obj,
-    *,
-    mask_and_scale=True,
-    decode_times=None,
-    concat_characters=None,
-    decode_coords=None,
-    drop_variables=None,
-    use_cftime=None,
-    decode_timedelta=None,
-    lock=None,
-    indexpath="{path}.{short_hash}.idx",
-    filter_by_keys={},
-    read_keys=[],
-    encode_cf=("parameter", "time", "geography", "vertical"),
-    squeeze=True,
-    time_dims=("time", "step"),
-):
-
-    store = CfGribDataStore(
-        filename_or_obj,
-        indexpath=indexpath,
-        filter_by_keys=filter_by_keys,
-        read_keys=read_keys,
-        encode_cf=encode_cf,
-        squeeze=squeeze,
-        time_dims=time_dims,
-        lock=lock,
-    )
-
-    with close_on_error(store):
-        ds = open_backend_dataset_store(
-            store,
-            mask_and_scale=mask_and_scale,
-            decode_times=decode_times,
-            concat_characters=concat_characters,
-            decode_coords=decode_coords,
-            drop_variables=drop_variables,
-            use_cftime=use_cftime,
-            decode_timedelta=decode_timedelta,
-        )
-    return ds
-
-
-cfgrib_backend = BackendEntrypoint(
-    open_dataset=open_backend_dataset_cfgrib, guess_can_open=guess_can_open_cfgrib
-)
+class CfgribfBackendEntrypoint(BackendEntrypoint):
+    def guess_can_open(self, store_spec):
+        try:
+            _, ext = os.path.splitext(store_spec)
+        except TypeError:
+            return False
+        return ext in {".grib", ".grib2", ".grb", ".grb2"}
+
+    def open_dataset(
+        self,
+        filename_or_obj,
+        *,
+        mask_and_scale=True,
+        decode_times=None,
+        concat_characters=None,
+        decode_coords=None,
+        drop_variables=None,
+        use_cftime=None,
+        decode_timedelta=None,
+        lock=None,
+        indexpath="{path}.{short_hash}.idx",
+        filter_by_keys={},
+        read_keys=[],
+        encode_cf=("parameter", "time", "geography", "vertical"),
+        squeeze=True,
+        time_dims=("time", "step"),
+    ):
+
+        store = CfGribDataStore(
+            filename_or_obj,
+            indexpath=indexpath,
+            filter_by_keys=filter_by_keys,
+            read_keys=read_keys,
+            encode_cf=encode_cf,
+            squeeze=squeeze,
+            time_dims=time_dims,
+            lock=lock,
+        )
+
+        store_entrypoint = StoreBackendEntrypoint()
+        with close_on_error(store):
+            ds = store_entrypoint.open_dataset(
+                store,
+                mask_and_scale=mask_and_scale,
+                decode_times=decode_times,
+                concat_characters=concat_characters,
+                decode_coords=decode_coords,
+                drop_variables=drop_variables,
+                use_cftime=use_cftime,
+                decode_timedelta=decode_timedelta,
+            )
+        return ds
 
 
 if has_cfgrib:
-    BACKEND_ENTRYPOINTS["cfgrib"] = cfgrib_backend
+    BACKEND_ENTRYPOINTS["cfgrib"] = CfgribfBackendEntrypoint
xarray/backends/common.py: 8 additions, 7 deletions

@@ -1,7 +1,7 @@
 import logging
 import time
 import traceback
-from typing import Dict
+from typing import Dict, Tuple, Type, Union
 
 import numpy as np
 
@@ -344,12 +344,13 @@ def encode(self, variables, attributes):
 
 
 class BackendEntrypoint:
-    __slots__ = ("guess_can_open", "open_dataset", "open_dataset_parameters")
+    open_dataset_parameters: Union[Tuple, None] = None
 
-    def __init__(self, open_dataset, open_dataset_parameters=None, guess_can_open=None):
-        self.open_dataset = open_dataset
-        self.open_dataset_parameters = open_dataset_parameters
-        self.guess_can_open = guess_can_open
+    def open_dataset(self):
+        raise NotImplementedError
+
+    def guess_can_open(self, store_spec):
+        return False
 
 
-BACKEND_ENTRYPOINTS: Dict[str, BackendEntrypoint] = {}
+BACKEND_ENTRYPOINTS: Dict[str, Type[BackendEntrypoint]] = {}
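
To illustrate the new contract, here is a minimal sketch of a hypothetical backend built on this base class; MyBackendEntrypoint and the ".my" extension are invented for illustration, and the import path follows the diff above:

    import os

    import xarray as xr
    from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint


    class MyBackendEntrypoint(BackendEntrypoint):
        def guess_can_open(self, store_spec):
            # store_spec may be a file-like object rather than a path, in
            # which case os.path.splitext raises TypeError and we decline.
            try:
                _, ext = os.path.splitext(store_spec)
            except TypeError:
                return False
            return ext == ".my"

        def open_dataset(self, filename_or_obj, *, drop_variables=None):
            # A real backend would parse filename_or_obj; an empty Dataset
            # stands in for the result here.
            return xr.Dataset()


    # As in the diffs in this PR, the registry maps engine names to classes,
    # not instances.
    BACKEND_ENTRYPOINTS["my_backend"] = MyBackendEntrypoint
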
Review thread on BACKEND_ENTRYPOINTS: Dict[str, Type[BackendEntrypoint]]:

[Member] From a typing perspective, won't this require all backends to subclass BackendEntrypoint? While we seem to have decided not to explicitly check for inheritance, this annotation would seem to require it.

[Collaborator] Mm, you are right, and in fact this is a point in favour of inheritance: it lets us control typing. Otherwise the type would be Any, which feels less correct to me.

[Member] Well, type checking is optional, and I doubt type checkers are aware of entry points :). So this would seem to impose that restriction only for xarray. That said, I'm not majorly concerned about making users inherit from a class. They already have a hard dependency on xarray, so there isn't a strong need for using a protocol.

We could also (optionally?) inherit BackendEntrypoint from typing.Protocol, although that requires Python 3.9 or typing_extensions.

[@aurghs (Collaborator, author), Jan 22, 2021] I don't see any advantage in forcing the user to inherit from our class. typing.Protocol seems to be the best solution, but I would avoid adding new requirements only for this.
xarray/backends/h5netcdf_.py: 53 additions, 54 deletions

@@ -23,7 +23,7 @@
     _get_datatype,
     _nc4_require_group,
 )
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
 
 try:
     import h5netcdf
@@ -328,62 +328,61 @@ def close(self, **kwargs):
         self._manager.close(**kwargs)
 
 
-def guess_can_open_h5netcdf(store_spec):
-    try:
-        return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n")
-    except TypeError:
-        pass
-
-    try:
-        _, ext = os.path.splitext(store_spec)
-    except TypeError:
-        return False
-
-    return ext in {".nc", ".nc4", ".cdf"}
-
-
-def open_backend_dataset_h5netcdf(
-    filename_or_obj,
-    *,
-    mask_and_scale=True,
-    decode_times=None,
-    concat_characters=None,
-    decode_coords=None,
-    drop_variables=None,
-    use_cftime=None,
-    decode_timedelta=None,
-    format=None,
-    group=None,
-    lock=None,
-    invalid_netcdf=None,
-    phony_dims=None,
-):
-
-    store = H5NetCDFStore.open(
-        filename_or_obj,
-        format=format,
-        group=group,
-        lock=lock,
-        invalid_netcdf=invalid_netcdf,
-        phony_dims=phony_dims,
-    )
-
-    ds = open_backend_dataset_store(
-        store,
-        mask_and_scale=mask_and_scale,
-        decode_times=decode_times,
-        concat_characters=concat_characters,
-        decode_coords=decode_coords,
-        drop_variables=drop_variables,
-        use_cftime=use_cftime,
-        decode_timedelta=decode_timedelta,
-    )
-    return ds
-
-
-h5netcdf_backend = BackendEntrypoint(
-    open_dataset=open_backend_dataset_h5netcdf, guess_can_open=guess_can_open_h5netcdf
-)
+class H5netcdfBackendEntrypoint(BackendEntrypoint):
+    def guess_can_open(self, store_spec):
+        try:
+            return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n")
+        except TypeError:
+            pass
+
+        try:
+            _, ext = os.path.splitext(store_spec)
+        except TypeError:
+            return False
+
+        return ext in {".nc", ".nc4", ".cdf"}
+
+    def open_dataset(
+        self,
+        filename_or_obj,
+        *,
+        mask_and_scale=True,
+        decode_times=None,
+        concat_characters=None,
+        decode_coords=None,
+        drop_variables=None,
+        use_cftime=None,
+        decode_timedelta=None,
+        format=None,
+        group=None,
+        lock=None,
+        invalid_netcdf=None,
+        phony_dims=None,
+    ):
+
+        store = H5NetCDFStore.open(
+            filename_or_obj,
+            format=format,
+            group=group,
+            lock=lock,
+            invalid_netcdf=invalid_netcdf,
+            phony_dims=phony_dims,
+        )
+
+        store_entrypoint = StoreBackendEntrypoint()
+
+        ds = store_entrypoint.open_dataset(
+            store,
+            mask_and_scale=mask_and_scale,
+            decode_times=decode_times,
+            concat_characters=concat_characters,
+            decode_coords=decode_coords,
+            drop_variables=drop_variables,
+            use_cftime=use_cftime,
+            decode_timedelta=decode_timedelta,
+        )
+        return ds
 
 
 if has_h5netcdf:
-    BACKEND_ENTRYPOINTS["h5netcdf"] = h5netcdf_backend
+    BACKEND_ENTRYPOINTS["h5netcdf"] = H5netcdfBackendEntrypoint