-
Notifications
You must be signed in to change notification settings - Fork 58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add delete function to API #662
Changes from all commits
650219d
c21bc63
4f10dfb
32bbfd7
098ec15
cc5a31a
ed0f41f
61a49f1
5976ca6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
import types | ||
from datetime import datetime | ||
from pathlib import Path | ||
from typing import List, Optional | ||
from typing import List, Optional, Union | ||
|
||
from lineapy.data.types import Artifact, NodeValue, PipelineType | ||
from lineapy.db.relational import SessionContextORM | ||
|
@@ -94,7 +94,11 @@ def save(reference: object, name: str) -> LineaArtifact: | |
node_id=value_node_id, execution_id=execution_id | ||
): | ||
# can raise ArtifactSaveException | ||
|
||
# pickles value of artifact and saves to filesystem | ||
pickled_path = _try_write_to_db(reference) | ||
|
||
# adds reference to pickled file inside database | ||
db.write_node_value( | ||
NodeValue( | ||
node_id=value_node_id, | ||
|
@@ -136,6 +140,60 @@ def save(reference: object, name: str) -> LineaArtifact: | |
return linea_artifact | ||
|
||
|
||
def delete(
    artifact_name: str, version: Optional[Union[int, str]] = None
) -> None:
    """
    Deletes an artifact from artifact store. If no other artifacts
    refer to the value, the value is also deleted from both the
    value node store and the pickle store.

    :param artifact_name: name of the artifact to delete.
    :param version: an integer version number, or one of the strings
        ``"latest"`` / ``"all"``. If version is not provided, latest
        version is used.
    """
    db = get_context().executor.db

    # Only an explicit integer selects a specific artifact for the value
    # cleanup below; None and string versions are looked up with
    # version=None.
    is_specific_version = isinstance(version, int)
    lookup_version = version if is_specific_version else None
    artifact = db.get_artifact_by_name(artifact_name, version=lookup_version)

    node_id = artifact.node_id
    execution_id = artifact.execution_id

    # NOTE(review): with version="all" only the single artifact looked up
    # above is considered for value/pickle cleanup, while every artifact
    # row with this name is removed at the end -- confirm whether the
    # values of the other versions should be cleaned up as well.

    # The stored value is removed only when this artifact is the sole
    # artifact pointing at the (node_id, execution_id) pair.
    if db.number_of_artifacts_per_node(node_id, execution_id) == 1:
        try:
            pickled_path = db.get_node_value_path(node_id, execution_id)
            db.delete_node_value_from_db(node_id, execution_id)
            if pickled_path is not None:
                try:
                    _try_delete_pickle_file(Path(pickled_path))
                except KeyError:
                    # A missing pickle file is not fatal for deletion.
                    logging.info(f"Pickle not found at {pickled_path}")
            else:
                logging.info(f"No pickle associated with {node_id}")
        except ValueError:
            logging.info(f"No pickle associated with {node_id}")

    # Integers (including 0) must be passed through unchanged: a plain
    # `version or "latest"` would silently turn version 0 into "latest"
    # and delete a different artifact than the one cleaned up above.
    if is_specific_version:
        delete_version = version
    else:
        delete_version = version or "latest"
    db.delete_artifact_by_name(artifact_name, version=delete_version)
|
||
|
||
def _try_delete_pickle_file(pickled_path: Path) -> None:
    """
    Removes the pickle file for a stored value.

    The file is first looked for at ``pickled_path``. If nothing is
    there, a second attempt is made using the same file name inside the
    directory configured under the ``artifact_storage_dir`` option.

    :param pickled_path: path to the pickle file as recorded for the value.
    :raises KeyError: if the file is found at neither location.
    """
    # Try the removal directly instead of checking exists() first, so a
    # file that disappears between the check and the unlink cannot
    # surface as an unexpected FileNotFoundError.
    try:
        pickled_path.unlink()
        return
    except FileNotFoundError:
        pass

    # Attempt to reconstruct path to pickle with current
    # linea folder and pickle base directory.
    new_pickled_path = Path(
        options.safe_get("artifact_storage_dir")
    ).joinpath(pickled_path.name)
    try:
        new_pickled_path.unlink()
    except FileNotFoundError:
        raise KeyError(f"Pickle not found at {pickled_path}") from None
|
||
|
||
def _try_write_to_db(value: object) -> Path: | ||
""" | ||
Saves the value to a random file inside linea folder. This file path is returned and eventually saved to the db. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
|
||
import logging | ||
from pathlib import Path | ||
from typing import Any, Dict, List, Optional, cast | ||
from typing import Any, Dict, List, Optional, Union, cast | ||
|
||
from sqlalchemy.orm import defaultload, scoped_session, sessionmaker | ||
from sqlalchemy.sql.expression import and_ | ||
|
@@ -467,6 +467,24 @@ def node_value_in_db( | |
.exists() | ||
).scalar() | ||
|
||
def number_of_artifacts_per_node(
    self, node_id: LineaID, execution_id: LineaID
) -> int:
    """
    Counts the artifacts whose node_id and execution_id both match
    the given pair, i.e. the artifacts that refer to the same
    execution node.
    """
    # NOTE(review): open question raised in review -- does calling
    # .save on the same node multiple times create a new pickle each
    # time? Confirm before relying on this count for pickle cleanup.
    same_node_and_execution = and_(
        ArtifactORM.node_id == node_id,
        ArtifactORM.execution_id == execution_id,
    )
    matching_artifacts = self.session.query(ArtifactORM).filter(
        same_node_and_execution
    )
    return matching_artifacts.count()
|
||
def get_libraries_for_session( | ||
self, session_id: LineaID | ||
) -> List[ImportNodeORM]: | ||
|
@@ -599,3 +617,67 @@ def get_source_code_for_session(self, session_id: LineaID) -> str: | |
if script_source_code_orms is not None | ||
else "" | ||
) | ||
|
||
def delete_artifact_by_name(
    self, artifact_name: str, version: Optional[Union[int, str]] = None
) -> None:
    """
    Deletes one or all artifacts with a certain name.

    ``version`` selects what is deleted:

    - an integer: the artifact with exactly that version;
    - ``"latest"`` or ``None`` (the default): the artifact with the
      highest version number;
    - ``"all"``: every artifact with this name.

    Raises a UserException wrapping a NameError if ``version`` is any
    other value, or if no matching artifact is found when deleting a
    single version.
    """
    # NOTE(review): the string sentinels ("all" / "latest") were flagged
    # in review as a stopgap design.

    # An omitted version means "latest", as documented; without this the
    # default argument would be rejected as an invalid version.
    if version is None:
        version = "latest"

    is_numbered = isinstance(version, int)
    if not is_numbered and version not in ("all", "latest"):
        track(ExceptionEvent("UserException", "Artifact version invalid"))
        raise UserException(NameError(f"{version} is an invalid version"))

    res_query = self.session.query(ArtifactORM).filter(
        ArtifactORM.name == artifact_name
    )
    if version == "all":
        res_query.delete()
    else:
        if is_numbered:
            res_query = res_query.filter(ArtifactORM.version == version)
        # Highest version first, so "latest" picks the newest version.
        res = res_query.order_by(ArtifactORM.version.desc()).first()
        if res is None:
            # Always report the requested version; a truthiness test
            # here would drop it from the message for version 0.
            msg = f"Artifact {artifact_name} (version {version})"
            track(ExceptionEvent("UserException", "Artifact not found"))
            raise UserException(
                NameError(
                    f"{msg} not found. Perhaps there was a typo. Please try lineapy.catalog() to inspect all your artifacts."
                )
            )
        self.session.delete(res)
    self.renew_session()
|
||
def delete_node_value_from_db(
    self, node_id: LineaID, execution_id: LineaID
):
    """
    Deletes the first NodeValueORM row matching both node_id and
    execution_id, then renews the session.

    Raises a UserException wrapping a NameError when no such row exists.
    """
    same_node_and_execution = and_(
        NodeValueORM.node_id == node_id,
        NodeValueORM.execution_id == execution_id,
    )
    matching_values = self.session.query(NodeValueORM).filter(
        same_node_and_execution
    )
    value_orm = matching_values.first()

    if value_orm is None:
        track(ExceptionEvent("UserException", "Value node not found"))
        not_found = NameError(
            f"NodeID {node_id} and ExecutionID {execution_id} does not exist"
        )
        raise UserException(not_found)

    self.session.delete(value_orm)
    self.renew_session()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We probably should have renamed this function, since it writes to a file path rather than to the db.