Created a new DerivaModel class and pulled out all of the model specific methods and put into this class.

Changed Dataset_Bag so it inherits from this model class.
carlkesselman committed Feb 14, 2025
1 parent dc5705c commit 9f702a7
Showing 6 changed files with 348 additions and 170 deletions.
75 changes: 57 additions & 18 deletions src/deriva_ml/database_model.py
@@ -12,11 +12,11 @@

from .deriva_definitions import ML_SCHEMA, MLVocab, RID, DerivaMLException
from .dataset_aux_classes import DatasetVersion, DatasetMinid

from .deriva_model import DerivaModel
from .dataset_bag import DatasetBag


class DatabaseModel:
class DatabaseModel(DerivaModel):
"""Read in the contents of a BDBag and create a local SQLite database.
As part of its initialization, this routine will create a SQLite database that has the contents of all the tables
@@ -29,6 +29,10 @@ class DatabaseModel:
to the table name using the convention SchemaName:TableName. Methods in DatasetBag that have table names as the
argument will perform the appropriate name mappings.
Because of nested datasets, it's possible that more than one dataset RID is in a bag, or that a dataset RID might
appear in more than one database. To help manage this, a global map of all the datasets that have been loaded
into DatabaseModels is kept in the class variable `_rid_map`.
Attributes:
bag_path (Path): path to the local copy of the BDBag
minid (DatasetMinid): Minid for the specified bag
@@ -38,25 +42,54 @@ class DatabaseModel:
dataset_table (Table): the dataset table in the ERMRest model.
"""

# Keep track of what databases we have loaded.
_paths_loaded: dict[str, "DatabaseModel"] = {}

# Maintain a global map of RIDS to versions and databases.
_rid_map: dict[RID, list[tuple[DatasetVersion, "DatabaseModel"]]] = {}

@classmethod
@validate_call
def register(cls, minid: DatasetMinid, bag_path: Path):
"""Register a new minid in the list of local databases if it's new, otherwise, return an existing DatabaseModel.
Args:
minid: MINID of the databag to be loaded.
bag_path: Path to the bag on the local filesystem.
Returns:
A DatabaseModel instance for the loaded bag.
"""
o = cls._paths_loaded.get(bag_path.as_posix())
if o:
return o
return cls(minid, bag_path)

@staticmethod
def rid_lookup(dataset_rid: RID) -> list[tuple[DatasetVersion, "DatabaseModel"]]:
"""Return a list of DatasetVersion/DatabaseModel instances corresponding to the given RID.
Args:
dataset_rid: RID to be looked up.
Returns:
List of DatasetVersion/DatabaseModel instances corresponding to the given RID.
Raises:
DerivaMLException: If the given RID is not found.
"""
try:
return DatabaseModel._rid_map[dataset_rid]
except KeyError:
raise DerivaMLException(f"Dataset {dataset_rid} not found")
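
Taken together, `register` and `rid_lookup` act as a small caching layer over loaded bags. A hypothetical usage sketch (the path and RID are invented, and `minid` is a DatasetMinid obtained elsewhere):

```python
from pathlib import Path

# `minid` is a DatasetMinid obtained elsewhere; the path and RID are invented.
db = DatabaseModel.register(minid, Path("/tmp/Dataset_1-ABC1"))
db2 = DatabaseModel.register(minid, Path("/tmp/Dataset_1-ABC1"))
assert db is db2  # the second call returns the cached instance for the same path

# Any dataset RID found in a loaded bag resolves to its version/database pairs.
for version, model in DatabaseModel.rid_lookup("1-ABC1"):
    print(version, model.dbase_file)
```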

def __init__(self, minid: DatasetMinid, bag_path: Path):
"""Create a new DatabaseModel. This should only be called via the static Register method
Args:
minid: Minid for the specified bag.
bag_path: Path to the local copy of the BDBag.
"""
DatabaseModel._paths_loaded[bag_path.as_posix()] = self

self.bag_path = bag_path
@@ -66,7 +99,10 @@ def __init__(self, minid: DatasetMinid, bag_path: Path):
self.dbase_file = dir_path / f"{minid.version_rid}.db"
self.dbase = sqlite3.connect(self.dbase_file)

self._model = Model.fromfile("file-system", self.bag_path / "data/schema.json")
super().__init__(
Model.fromfile("file-system", self.bag_path / "data/schema.json")
)

self._logger = logging.getLogger("deriva_ml")
self.domain_schema = self._guess_domain_schema()
self._load_model()
@@ -77,7 +113,7 @@ def __init__(self, minid: DatasetMinid, bag_path: Path):
self.dataset_rid,
self.dbase_file,
)
self.dataset_table = self._model.schemas[self.ml_schema].tables["Dataset"]
self.dataset_table = self.model.schemas[self.ml_schema].tables["Dataset"]
# Now go through the database and pick out all the dataset RIDs, along with their versions.
sql_dataset = self.normalize_table_name("Dataset_Version")
with self.dbase:
@@ -100,12 +136,11 @@ def __init__(self, minid: DatasetMinid, bag_path: Path):
version_list.append((dataset_version, self))
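
The loop above is what fills `_rid_map` from each bag's Dataset_Version table. A toy illustration of the resulting shape, with strings standing in for the real DatasetVersion and DatabaseModel objects:

```python
# Toy stand-in for DatabaseModel._rid_map; real entries hold
# (DatasetVersion, DatabaseModel) tuples rather than strings.
rid_map: dict[str, list[tuple[str, str]]] = {
    "1-ABC1": [("1.0.0", "bag_a.db"), ("1.1.0", "bag_b.db")],  # same RID in two bags
    "1-XYZ9": [("0.1.0", "bag_b.db")],  # nested dataset seen in one bag
}
print(rid_map["1-ABC1"])
```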

def _load_model(self) -> None:
# Create a sqlite database schema that contains all the tables within the catalog from which the
# BDBag was created.
"""Create a sqlite database schema that contains all the tables within the catalog from which the BDBag was created."""
with self.dbase:
for t in self._model.schemas[self.domain_schema].tables.values():
for t in self.model.schemas[self.domain_schema].tables.values():
self.dbase.execute(t.sqlite3_ddl())
for t in self._model.schemas["deriva-ml"].tables.values():
for t in self.model.schemas["deriva-ml"].tables.values():
self.dbase.execute(t.sqlite3_ddl())
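
Each `sqlite3_ddl()` call above presumably renders one catalog table as a CREATE TABLE statement keyed by the SchemaName:TableName convention. A minimal stand-in for the same idea, with invented table and column names:

```python
import sqlite3

def sqlite_ddl(schema: str, table: str, columns: list[str]) -> str:
    # Quote "schema:table" as a single SQLite identifier, one TEXT column per name.
    cols = ", ".join(f'"{c}" TEXT' for c in columns)
    return f'CREATE TABLE IF NOT EXISTS "{schema}:{table}" ({cols})'

conn = sqlite3.connect(":memory:")
conn.execute(sqlite_ddl("deriva-ml", "Dataset", ["RID", "Description"]))
conn.execute(sqlite_ddl("my_domain", "Image", ["RID", "Filename", "URL"]))
```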

def _load_sqllite(self) -> None:
@@ -123,7 +158,7 @@ def _load_sqllite(self) -> None:
table = csv_file.stem
schema = (
self.domain_schema
if table in self._model.schemas[self.domain_schema].tables
if table in self.model.schemas[self.domain_schema].tables
else self.ml_schema
)

@@ -173,10 +208,14 @@ def _localize_asset_table(self) -> dict[str, str]:
logging.info(f"No downloaded assets in bag {dataset_rid}")
return fetch_map

def _guess_domain_schema(self):
# Guess the domain schema name by eliminating all the "builtin" schema.
def _guess_domain_schema(self) -> str:
"""Guess the domain schema name by eliminating all the "builtin" schema.
Returns:
String for domain schema name.
"""
return [
s for s in self._model.schemas if s not in ["deriva-ml", "public", "www"]
s for s in self.model.schemas if s not in ["deriva-ml", "public", "www"]
][0]

def _is_asset(self, table_name: str) -> bool:
@@ -186,15 +225,15 @@ def _is_asset(self, table_name: str) -> bool:
table_name: Name of the table to check.
Returns:
Boolean that is true if the table looks like an asset table.
"""
asset_columns = {"Filename", "URL", "Length", "MD5", "Description"}
sname = (
self.domain_schema
if table_name in self._model.schemas[self.domain_schema].tables
if table_name in self.model.schemas[self.domain_schema].tables
else self.ml_schema
)
asset_table = self._model.schemas[sname].tables[table_name]
asset_table = self.model.schemas[sname].tables[table_name]
return asset_columns.issubset({c.name for c in asset_table.columns})
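
The asset heuristic is just a column-subset test; a self-contained illustration:

```python
asset_columns = {"Filename", "URL", "Length", "MD5", "Description"}

image_columns = {"RID", "Filename", "URL", "Length", "MD5", "Description", "Width"}
subject_columns = {"RID", "Name", "Age"}

print(asset_columns.issubset(image_columns))    # True: looks like an asset table
print(asset_columns.issubset(subject_columns))  # False: ordinary domain table
```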

@staticmethod
@@ -260,7 +299,7 @@ def find_datasets(self) -> list[dict[str, Any]]:
list of currently available datasets.
"""
atable = next(
self._model.schemas[ML_SCHEMA]
self.model.schemas[ML_SCHEMA]
.tables[MLVocab.dataset_type]
.find_associations()
).name
@@ -291,11 +330,11 @@ def normalize_table_name(self, table: str) -> str:
[sname, tname] = table.split(":")
except ValueError:
tname = table
for sname, s in self._model.schemas.items():
for sname, s in self.model.schemas.items():
if table in s.tables:
break
try:
_ = self._model.schemas[sname].tables[tname]
_ = self.model.schemas[sname].tables[tname]
return f"{sname}:{tname}"
except KeyError:
raise DerivaMLException(f'Table name "{table}" does not exist.')
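
A self-contained sketch of the normalization rule, with invented schema and table names:

```python
schemas = {"deriva-ml": {"Dataset", "Dataset_Version"}, "my_domain": {"Image"}}

def normalize(table: str) -> str:
    sname, _, tname = table.rpartition(":")
    if not sname:  # bare table name: search the schemas for it
        sname = next((s for s, ts in schemas.items() if tname in ts), "")
    if tname not in schemas.get(sname, ()):
        raise KeyError(f'Table name "{table}" does not exist.')
    return f"{sname}:{tname}"

print(normalize("Image"))               # my_domain:Image
print(normalize("deriva-ml:Dataset"))   # deriva-ml:Dataset
```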
40 changes: 11 additions & 29 deletions src/deriva_ml/dataset.py
@@ -9,7 +9,7 @@
from bdbag.fetch.fetcher import fetch_single_file
from bdbag import bdbag_api as bdb
from collections import defaultdict
from deriva.core.ermrest_model import Model, Table
from deriva.core.ermrest_model import Table
from deriva.core.datapath import DataPathException
from deriva.core.utils.core_utils import tag as deriva_tags, format_exception
from deriva.transfer.download.deriva_export import DerivaExport
@@ -41,6 +41,7 @@
from deriva_ml import DatasetBag
from .deriva_definitions import ML_SCHEMA, DerivaMLException, MLVocab, Status, RID
from .history import iso_to_snap
from .deriva_model import DerivaModel
from .database_model import DatabaseModel
from .dataset_aux_classes import (
DatasetVersion,
@@ -61,12 +62,9 @@ class Dataset:

_Logger = logging.getLogger("deriva_ml")

def __init__(self, model: Model, cache_dir: Path):
def __init__(self, model: DerivaModel, cache_dir: Path):
self._model = model
self._ml_schema = ML_SCHEMA
self._domain_schema = [
s for s in model.schemas if s not in ["deriva-ml", "www", "public"]
].pop()
self.dataset_table = self._model.schemas[self._ml_schema].tables["Dataset"]
self._cache_dir = cache_dir
self._logger = logging.getLogger("deriva_ml")
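
With this change, schema discovery and table lookup move behind the new DerivaModel class, which is not shown in this diff. A minimal sketch of what it presumably provides, inferred only from the call sites in this file (`domain_schema`, `get_table`, `model`):

```python
# Assumed sketch of DerivaModel, inferred from usage in this diff; the real
# class lives in src/deriva_ml/deriva_model.py and may differ.
from deriva.core.ermrest_model import Model, Table

class DerivaModel:
    _builtin_schemas = {"deriva-ml", "public", "www"}

    def __init__(self, model: Model) -> None:
        self.model = model

    @property
    def domain_schema(self) -> str:
        # First schema that is not one of the built-ins.
        return next(s for s in self.model.schemas if s not in self._builtin_schemas)

    def get_table(self, table: str | Table) -> Table:
        if isinstance(table, Table):
            return table
        for schema in self.model.schemas.values():
            if table in schema.tables:
                return schema.tables[table]
        raise KeyError(f"The table {table} doesn't exist.")
```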
@@ -360,7 +358,7 @@ def list_dataset_element_types(self) -> Iterable[Table]:

def domain_table(table: Table) -> bool:
return (
table.schema.name == self._domain_schema
table.schema.name == self._model.domain_schema
or table.name == self.dataset_table.name
)

@@ -370,22 +368,6 @@ def domain_table(table: Table) -> bool:
if domain_table(t := a.other_fkeys.pop().pk_table)
]

def _get_table(self, table: Table | str) -> Table:
# Add table to map
t = table
table_found = False
if not isinstance(table, Table):
for s in self._model.schemas.values():
try:
t = s.tables[t]
table_found = True
break
except KeyError:
pass
if not table_found:
raise DerivaMLException(f"The table {table} doesn't exist.")
return t

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def add_dataset_element_type(self, element: str | Table) -> Table:
"""A dataset_table is a heterogeneous collection of objects, each of which comes from a different table. This
@@ -399,14 +381,14 @@ def add_dataset_element_type(self, element: str | Table) -> Table:
The table object that was added to the dataset.
"""
# Add table to map
element_table = self._get_table(element)
table = self._model.schemas[self._domain_schema].create_table(
element_table = self._model.get_table(element)
table = self._model.schemas[self._model.domain_schema].create_table(
Table.define_association([self.dataset_table, element_table])
)

# self.model = self.catalog.getCatalogModel()
self.dataset_table.annotations.update(self._generate_dataset_annotations())
self._model.apply()
self._model.model.apply()
return table
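
A hypothetical call, assuming a catalog whose domain schema has an `Image` table and `ds` is a Dataset instance bound to that catalog:

```python
assoc_table = ds.add_dataset_element_type("Image")
print(assoc_table.name)  # the new Dataset <-> Image association table
# From here on, Image records can be added as members of any dataset.
```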

@validate_call
@@ -447,7 +429,7 @@ def list_dataset_members(
member_table = assoc_table.table

if (
target_table.schema.name != self._domain_schema
target_table.schema.name != self._model.domain_schema
and target_table != self.dataset_table
):
# Look at domain tables and nested datasets.
@@ -557,7 +539,7 @@ def check_dataset_cycle(member_rid, path=None):
pb = self._model.catalog.getPathBuilder()
for table, elements in dataset_elements.items():
schema_path = pb.schemas[
self._ml_schema if table == "Dataset" else self._domain_schema
self._ml_schema if table == "Dataset" else self._model.domain_schema
]
fk_column = "Nested_Dataset" if table == "Dataset" else table

@@ -583,7 +565,7 @@ def list_dataset_parents(self, dataset_rid: RID) -> list[RID]:
RIDs of any parent datasets.
"""
try:
rid_record = self._model.catalog.resolve_rid(dataset_rid, self._model)
rid_record = self._model.catalog.resolve_rid(dataset_rid, self._model.model)
except KeyError as _e:
raise DerivaMLException(f"Invalid RID {dataset_rid}")

@@ -828,7 +810,7 @@ def include_node(child: Table) -> bool:
return (
child != node
and child not in visited_nodes
and child.schema.name == self._domain_schema
and child.schema.name == self._model.domain_schema
)

# Get all the tables reachable from the end of the path avoiding loops from T1<->T2 via referenced_by
32 changes: 22 additions & 10 deletions src/deriva_ml/dataset_bag.py
@@ -14,6 +14,18 @@ class DatasetBag:
"""DatasetBag is a class that manages a materialized bag. It is created from a locally materialized BDBag for a
dataset_table, which is created either by DerivaML.create_execution, or directly by calling DerivaML.download_dataset.
A general a bag may contain multiple datasets, if the dataset is nested. The DatasetBag is used to represent only
one of the datasets in the bag.
All the metadata associated with the dataset is stored in a SQLLite database that can be queried using SQL.
Attributes:
dataset_rid (RID): RID for the specified dataset
version: The version of the dataset
model (DatabaseModel): The database model that has all the catalog metadata associated with this dataset.
database (DatabaseModel): The DatabaseModel for the bag.
dbase (Connection): Connection to the SQLite database holding table values.
domain_schema (str): Name of the domain schema.
"""

# @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
@@ -33,10 +45,10 @@ def __init__(
self.dataset_rid = dataset_rid or self.model.dataset_rid
self.model.rid_lookup(
self.dataset_rid
) # Check to make sure that this dataset is in the
) # Check to make sure that this dataset is in the bag.

self.version = self.model.dataset_version(self.dataset_rid)
self.dataset_table = self.model.dataset_table
self._dataset_table = self.model.dataset_table

def __repr__(self) -> str:
return f"<deriva_ml.DatasetBag object {self.dataset_rid} at {hex(id(self))}>"
@@ -60,32 +72,32 @@ def get_table_as_dict(self, table: str) -> Generator[dict[str, Any], None, None]:

@validate_call
def list_dataset_members(self, recurse: bool = False) -> dict[str, list[tuple]]:
"""Return a list of entities associated with a specific dataset_table.
"""Return a list of entities associated with a specific _dataset_table.
Args:
recurse: (Default value = False)
Returns:
Dictionary of entities associated with a specific dataset_table. Key is the table from which the elements
Dictionary of entities associated with a specific dataset. Key is the table from which the elements
were taken.
"""

# Look at each of the element types that might be in the dataset_table and get the list of rid for them from
# Look at each of the element types that might be in the dataset and get the list of RIDs for them from
# the appropriate association table.
members = defaultdict(list)
for assoc_table in self.dataset_table.find_associations():
for assoc_table in self._dataset_table.find_associations():
other_fkey = assoc_table.other_fkeys.pop()
self_fkey = assoc_table.self_fkey
target_table = other_fkey.pk_table
member_table = assoc_table.table

if (
target_table.schema.name != self.database.domain_schema
and target_table != self.dataset_table
and target_table != self._dataset_table
):
# Look at domain tables and nested datasets.
continue
if target_table == self.dataset_table:
if target_table == self._dataset_table:
# find_assoc gives us the keys in the wrong position, so swap.
self_fkey, other_fkey = other_fkey, self_fkey
sql_target = self.model.normalize_table_name(target_table.name)
Expand All @@ -107,7 +119,7 @@ def list_dataset_members(self, recurse: bool = False) -> dict[str, list[tuple]]:

target_entities = [] # path.entities().fetch()
members[target_table.name].extend(target_entities)
if recurse and target_table.name == self.dataset_table:
if recurse and target_table == self._dataset_table:
# Get the members for all the nested datasets and add to the member list.
nested_datasets = [d["RID"] for d in target_entities]
for ds in nested_datasets:
@@ -119,7 +131,7 @@ def list_dataset_members(self, recurse: bool = False) -> dict[str, list[tuple]]:

@validate_call
def list_dataset_children(self, recurse: bool = False) -> list["DatasetBag"]:
"""Given a dataset_table RID, return a list of RIDs of any nested datasets.
"""Given a _dataset_table RID, return a list of RIDs of any nested datasets.
Returns:
List of DatasetBag objects for the nested datasets.
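
Since each child is itself a DatasetBag, nested datasets can be walked recursively; a short sketch building on the earlier one:

```python
def walk(bag: DatasetBag, depth: int = 0) -> None:
    # Print the dataset tree rooted at `bag`, one RID per line.
    print("  " * depth + bag.dataset_rid)
    for child in bag.list_dataset_children():
        walk(child, depth + 1)

walk(bag)  # `bag` from the sketch above
```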
(Diffs for the remaining 3 changed files were not loaded.)
