Skip to content

Commit

Permalink
refactor(catalogs): create base methods in catalogs module that dynam…
Browse files Browse the repository at this point in the history
…ically import relevant submodules
  • Loading branch information
annehaley authored and johnkit committed Apr 30, 2024
1 parent 8aa8780 commit 633918d
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 17 deletions.
26 changes: 25 additions & 1 deletion pan3d/catalogs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib


def call_catalog_function(catalog_name, function_name, **kwargs):
def _call_catalog_function(catalog_name, function_name, **kwargs):
try:
module = importlib.import_module(f"pan3d.catalogs.{catalog_name}")
func = getattr(module, function_name)
Expand All @@ -12,3 +12,27 @@ def call_catalog_function(catalog_name, function_name, **kwargs):
)
except AttributeError:
raise ValueError(f"{catalog_name} is not a valid catalog module.")


def get(catalog_name):
return _call_catalog_function(catalog_name, "get_catalog")


def get_search_options(catalog_name):
return _call_catalog_function(catalog_name, "get_search_options")


def search(catalog_name, **filters):
return _call_catalog_function(catalog_name, "search", **filters)


def load_dataset(catalog_name, id):
return _call_catalog_function(catalog_name, "load_dataset", id=id)


__all__ = [
get,
get_search_options,
search,
load_dataset,
]
4 changes: 2 additions & 2 deletions pan3d/catalogs/esgf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def get_catalog():
}


def get_catalog_search_options():
def get_search_options():
catalog = ESGFCatalog()
# perform unfiltered search and get unique values for each column
results = catalog.search()
Expand All @@ -22,7 +22,7 @@ def get_catalog_search_options():
return search_options


def search_catalog(**kwargs):
def search(**kwargs):
group_name = "/".join([f'{k}:{",".join(v)}' for k, v in kwargs.items()])
if not group_name:
group_name = "All ESGF Datasets"
Expand Down
4 changes: 2 additions & 2 deletions pan3d/catalogs/pangeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def entry_filter_match(entry, filters):
return True


def get_catalog_search_options():
def get_search_options():
all_entries = get_all_entries()
search_options = {
"name": [],
Expand All @@ -93,7 +93,7 @@ def get_catalog_search_options():
return search_options


def search_catalog(**kwargs):
def search(**kwargs):
group_name = "/".join([f'{k}:{",".join(v)}' for k, v in kwargs.items()])
if not group_name:
group_name = "All Pangeo Datasets"
Expand Down
6 changes: 2 additions & 4 deletions pan3d/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import xarray

from pan3d.utils import coordinate_auto_selection
from pan3d.catalogs import call_catalog_function
from pan3d import catalogs
from pathlib import Path
from pvxarray.vtk_source import PyVistaXarraySource
from typing import Any, Dict, List, Optional, Union, Tuple
Expand Down Expand Up @@ -261,9 +261,7 @@ def _load_dataset(self, dataset_info):
if dataset_info is not None:
source = dataset_info.get("source")
if source in ["pangeo", "esgf"]:
ds = call_catalog_function(
source, "load_dataset", id=dataset_info["id"]
)
ds = catalogs.load_dataset(source, id=dataset_info["id"])
elif source == "xarray":
ds = xarray.tutorial.load_dataset(dataset_info["id"])
else:
Expand Down
13 changes: 5 additions & 8 deletions pan3d/dataset_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from trame_server.controller import Controller
from trame_vuetify.ui.vuetify3 import VAppLayout

from pan3d.catalogs import call_catalog_function
from pan3d import catalogs
from pan3d.dataset_builder import DatasetBuilder
from pan3d.ui import AxisDrawer, MainDrawer, Toolbar, RenderOptions
from pan3d.utils import (
Expand Down Expand Up @@ -69,8 +69,7 @@ def __init__(

if catalogs:
self.state.available_catalogs = [
call_catalog_function(catalog_name, "get_catalog")
for catalog_name in catalogs
catalogs.get(catalog_name) for catalog_name in catalogs
]

self._force_local_rendering = not has_gpu_rendering()
Expand Down Expand Up @@ -159,8 +158,8 @@ def _update_catalog_search_term(self, term_key, term_value):
def _catalog_search(self):
def load_results():
catalog_id = self.state.catalog.get("id")
results, group_name, message = call_catalog_function(
catalog_id, "search_catalog", **self.state.catalog_current_search
results, group_name, message = catalogs.search(
catalog_id, **self.state.catalog_current_search
)

if len(results) > 0:
Expand All @@ -185,9 +184,7 @@ def load_results():
def _catalog_term_option_search(self):
def load_terms():
catalog_id = self.state.catalog.get("id")
search_options = call_catalog_function(
catalog_id, "get_catalog_search_options"
)
search_options = catalogs.get_search_options(catalog_id)
self.state.available_catalogs = [
{
**catalog,
Expand Down

0 comments on commit 633918d

Please sign in to comment.