Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add delete function to API #662

Merged
merged 9 commits into from
Jun 5, 2022
3 changes: 2 additions & 1 deletion lineapy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import atexit

from lineapy.api.api import catalog, get, save, to_pipeline
from lineapy.api.api import catalog, delete, get, save, to_pipeline
from lineapy.data.graph import Graph
from lineapy.data.types import SessionType, ValueType
from lineapy.editors.ipython import start, stop, visualize
Expand All @@ -15,6 +15,7 @@
"save",
"get",
"catalog",
"delete",
"to_pipeline",
"SessionType",
"ValueType",
Expand Down
60 changes: 59 additions & 1 deletion lineapy/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import types
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Union

from lineapy.data.types import Artifact, NodeValue, PipelineType
from lineapy.db.relational import SessionContextORM
Expand Down Expand Up @@ -94,7 +94,11 @@ def save(reference: object, name: str) -> LineaArtifact:
node_id=value_node_id, execution_id=execution_id
):
# can raise ArtifactSaveException

# pickles value of artifact and saves to filesystem
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prob should have renamed the function name since it's not db but rather file path

pickled_path = _try_write_to_db(reference)

# adds reference to pickled file inside database
db.write_node_value(
NodeValue(
node_id=value_node_id,
Expand Down Expand Up @@ -136,6 +140,60 @@ def save(reference: object, name: str) -> LineaArtifact:
return linea_artifact


def delete(
artifact_name: str, version: Optional[Union[int, str]] = None
) -> None:
"""
Deletes an artifact from artifact store. If no other artifacts
refer to the value, the value is also deleted from both the
value node store and the pickle store.

If version is not provided, latest version is used.
"""
execution_context = get_context()
executor = execution_context.executor
db = executor.db

get_version = None if not isinstance(version, int) else version
artifact = db.get_artifact_by_name(artifact_name, version=get_version)

node_id = artifact.node_id
execution_id = artifact.execution_id

num_artifacts = db.number_of_artifacts_per_node(node_id, execution_id)
if num_artifacts == 1:
try:
pickled_path = db.get_node_value_path(node_id, execution_id)
db.delete_node_value_from_db(node_id, execution_id)
if pickled_path is not None:
try:
_try_delete_pickle_file(Path(pickled_path))
except KeyError:
logging.info(f"Pickle not found at {pickled_path}")
else:
logging.info(f"No pickle associated with {node_id}")
except ValueError:
logging.info(f"No pickle associated with {node_id}")

delete_version = version or "latest"
db.delete_artifact_by_name(artifact_name, version=delete_version)


def _try_delete_pickle_file(pickled_path: Path) -> None:
if pickled_path.exists():
pickled_path.unlink()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else:
# Attempt to reconstruct path to pickle with current
# linea folder and picke base directory.
new_pickled_path = Path(
options.safe_get("artifact_storage_dir")
).joinpath(pickled_path.name)
if new_pickled_path.exists():
new_pickled_path.unlink()
else:
raise KeyError(f"Pickle not found at {pickled_path}")


def _try_write_to_db(value: object) -> Path:
"""
Saves the value to a random file inside linea folder. This file path is returned and eventually saved to the db.
Expand Down
84 changes: 83 additions & 1 deletion lineapy/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, Union, cast

from sqlalchemy.orm import defaultload, scoped_session, sessionmaker
from sqlalchemy.sql.expression import and_
Expand Down Expand Up @@ -467,6 +467,24 @@ def node_value_in_db(
.exists()
).scalar()

def number_of_artifacts_per_node(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we know whether a new pickle gets created when the .save is called on the same node multiple times?

self, node_id: LineaID, execution_id: LineaID
) -> int:
"""
Returns number of artifacts that refer to
the same execution node.
"""
return (
self.session.query(ArtifactORM)
.filter(
and_(
ArtifactORM.node_id == node_id,
ArtifactORM.execution_id == execution_id,
)
)
.count()
)

def get_libraries_for_session(
self, session_id: LineaID
) -> List[ImportNodeORM]:
Expand Down Expand Up @@ -599,3 +617,67 @@ def get_source_code_for_session(self, session_id: LineaID) -> str:
if script_source_code_orms is not None
else ""
)

def delete_artifact_by_name(
self, artifact_name: str, version: Union[int, str] = None
):
"""
Deletes the most recent artifact with a certain name.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The description doesn't include a reference to "all" and "latest"?

I don't really like the design using strings but it's fine for now.

If a version is not specified, it will delete the most recent
version sorted by date_created
"""
if (
not isinstance(version, int)
and version != "all"
and version != "latest"
):
track(ExceptionEvent("UserException", "Artifact version invalid"))
raise UserException(NameError(f"{version} is an invalid version"))

res_query = self.session.query(ArtifactORM).filter(
ArtifactORM.name == artifact_name
)
if version == "all":
res_query.delete()
else:
if isinstance(version, int):
res_query = res_query.filter(ArtifactORM.version == version)
res = res_query.order_by(ArtifactORM.version.desc()).first()
if res is None:
msg = (
f"Artifact {artifact_name} (version {version})"
if version
else f"Artifact {artifact_name}"
)
track(ExceptionEvent("UserException", "Artifact not found"))
raise UserException(
NameError(
f"{msg} not found. Perhaps there was a typo. Please try lineapy.catalog() to inspect all your artifacts."
)
)
self.session.delete(res)
self.renew_session()

def delete_node_value_from_db(
self, node_id: LineaID, execution_id: LineaID
):
value_orm = (
self.session.query(NodeValueORM)
.filter(
and_(
NodeValueORM.node_id == node_id,
NodeValueORM.execution_id == execution_id,
)
)
.first()
)
if value_orm is None:
track(ExceptionEvent("UserException", "Value node not found"))
raise UserException(
NameError(
f"NodeID {node_id} and ExecutionID {execution_id} does not exist"
)
)

self.session.delete(value_orm)
self.renew_session()
120 changes: 120 additions & 0 deletions tests/end_to_end/test_linea_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from lineapy.utils.utils import prettify


Expand Down Expand Up @@ -73,3 +75,121 @@ def test_save_twice(execute):
assert res.artifacts["x"] == "x = 100\n"
assert res.values["y"] == 100
assert res.artifacts["y"] == "y = 100\n"


def test_delete_artifact(execute):
res = execute(
"""import lineapy
x = 100
lineapy.save(x, 'x')
lineapy.delete('x')
""",
snapshot=False,
)

with pytest.raises(KeyError):
assert res.artifacts["x"]


def test_delete_artifact_latest(execute):
res = execute(
"""import lineapy
x = 100
lineapy.save(x, 'x')
x = 200
lineapy.save(x, 'x')
lineapy.delete('x')

catalog = lineapy.catalog()
versions = [x._version for x in catalog.artifacts if x.name=='x']
num_versions = len(versions)
""",
snapshot=False,
)

assert res.artifacts["x"] == "x = 100\n"
assert res.values["num_versions"] == 1


def test_delete_artifact_version_simple(execute):
res = execute(
"""import lineapy
x = 100
lineapy.save(x, 'x')
lineapy.delete('x', version=0)
""",
snapshot=False,
)

with pytest.raises(KeyError):
assert res.artifacts["x"]


def test_delete_artifact_version(execute):
res = execute(
"""import lineapy
x = 100
lineapy.save(x, 'x')
x = 200
lineapy.save(x, 'x')
lineapy.delete('x', version=1)

catalog = lineapy.catalog()
versions = [x._version for x in catalog.artifacts if x.name=='x']
num_versions = len(versions)
x_retrieve = lineapy.get('x').get_value()

""",
snapshot=False,
)

assert res.values["num_versions"] == 1
assert res.values["x_retrieve"] == 100


def test_delete_artifact_version_complex(execute):
res = execute(
"""import lineapy
x = 100
lineapy.save(x, 'x')
x = 200
lineapy.save(x, 'x')
x = 300
lineapy.save(x, 'x')

# We want to Delete version 1, but the code is executed twice in testing, causing no version 1 to be deleted in second execution
lineapy.delete('x', version=sorted([x._version for x in lineapy.catalog().artifacts if x.name=='x'])[-2])

num_versions = len([x._version for x in lineapy.catalog().artifacts if x.name=='x'])
x_retrieve = lineapy.get('x').get_value()
""",
snapshot=False,
)

assert res.values["num_versions"] == 2
assert res.values["x_retrieve"] == 300


def test_delete_artifact_all(execute):
res = execute(
"""import lineapy
x = 100
lineapy.save(x, 'x')
x = 200
lineapy.save(x, 'x')
x = 300
lineapy.save(x, 'x')
lineapy.delete('x', version='all')

catalog = lineapy.catalog()
versions = [x._version for x in catalog.artifacts if x.name=='x']
num_versions = len(versions)


""",
snapshot=False,
)

assert res.values["num_versions"] == 0
with pytest.raises(KeyError):
assert res.artifacts["x"]