From 92e8f50ad520c152817a892e33046d49aac137b2 Mon Sep 17 00:00:00 2001 From: Nicola Soranzo Date: Thu, 20 Feb 2025 14:29:14 +0000 Subject: [PATCH 1/5] Remove unused function --- lib/galaxy/tool_shed/util/repository_util.py | 26 -------------------- lib/tool_shed/util/repository_util.py | 2 -- 2 files changed, 28 deletions(-) diff --git a/lib/galaxy/tool_shed/util/repository_util.py b/lib/galaxy/tool_shed/util/repository_util.py index 5f9a4cd0ff7b..2c92c061e06f 100644 --- a/lib/galaxy/tool_shed/util/repository_util.py +++ b/lib/galaxy/tool_shed/util/repository_util.py @@ -17,7 +17,6 @@ from sqlalchemy import ( and_, false, - or_, ) from sqlalchemy.orm import joinedload @@ -242,30 +241,6 @@ def get_absolute_path_to_file_in_repository(repo_files_dir, file_name): return file_path -def get_ids_of_tool_shed_repositories_being_installed(app, as_string=False): - installing_repository_ids = [] - new_status = app.install_model.ToolShedRepository.installation_status.NEW - cloning_status = app.install_model.ToolShedRepository.installation_status.CLONING - setting_tool_versions_status = app.install_model.ToolShedRepository.installation_status.SETTING_TOOL_VERSIONS - installing_dependencies_status = ( - app.install_model.ToolShedRepository.installation_status.INSTALLING_TOOL_DEPENDENCIES - ) - loading_datatypes_status = app.install_model.ToolShedRepository.installation_status.LOADING_PROPRIETARY_DATATYPES - for tool_shed_repository in app.install_model.context.query(app.install_model.ToolShedRepository).filter( - or_( - app.install_model.ToolShedRepository.status == new_status, - app.install_model.ToolShedRepository.status == cloning_status, - app.install_model.ToolShedRepository.status == setting_tool_versions_status, - app.install_model.ToolShedRepository.status == installing_dependencies_status, - app.install_model.ToolShedRepository.status == loading_datatypes_status, - ) - ): - installing_repository_ids.append(app.security.encode_id(tool_shed_repository.id)) - if as_string: - return ",".join(installing_repository_ids) - return installing_repository_ids - - def get_installed_repository( app, tool_shed=None, @@ -773,7 +748,6 @@ def set_repository_attributes(app, repository, status, error_message, deleted, u "extract_components_from_tuple", "generate_tool_shed_repository_install_dir", "get_absolute_path_to_file_in_repository", - "get_ids_of_tool_shed_repositories_being_installed", "get_installed_repository", "get_installed_tool_shed_repository", "get_prior_import_or_install_required_dict", diff --git a/lib/tool_shed/util/repository_util.py b/lib/tool_shed/util/repository_util.py index 5f198a131fe9..1ae90d7e488c 100644 --- a/lib/tool_shed/util/repository_util.py +++ b/lib/tool_shed/util/repository_util.py @@ -28,7 +28,6 @@ extract_components_from_tuple, generate_tool_shed_repository_install_dir, get_absolute_path_to_file_in_repository, - get_ids_of_tool_shed_repositories_being_installed, get_installed_repository, get_installed_tool_shed_repository, get_prior_import_or_install_required_dict, @@ -601,7 +600,6 @@ def delete_repository_category_associations(session, repository_category_assoc_m "generate_sharable_link_for_repository_in_tool_shed", "generate_tool_shed_repository_install_dir", "get_absolute_path_to_file_in_repository", - "get_ids_of_tool_shed_repositories_being_installed", "get_installed_repository", "get_installed_tool_shed_repository", "get_prior_import_or_install_required_dict", From dd80698e4b046bf4b0d4e900505003cb4bb5db58 Mon Sep 17 00:00:00 2001 From: Nicola Soranzo Date: Mon, 10 Feb 2025 14:34:47 +0000 Subject: [PATCH 2/5] Removed deprecated ``current`` attribute of scoped_session --- lib/galaxy/app_unittest_utils/tools_support.py | 1 - lib/galaxy/model/base.py | 3 --- lib/tool_shed/managers/repositories.py | 2 +- scripts/cleanup_datasets/cleanup_datasets.py | 2 +- scripts/grt/export.py | 2 +- scripts/set_dataset_sizes.py | 2 +- scripts/set_user_disk_usage.py | 2 +- scripts/tool_shed/deprecate_repositories_without_metadata.py | 2 +- test/integration/test_save_job_id_on_datasets.py | 2 +- 9 files changed, 7 insertions(+), 11 deletions(-) diff --git a/lib/galaxy/app_unittest_utils/tools_support.py b/lib/galaxy/app_unittest_utils/tools_support.py index 9342f80c2b34..e87f70c9f82f 100644 --- a/lib/galaxy/app_unittest_utils/tools_support.py +++ b/lib/galaxy/app_unittest_utils/tools_support.py @@ -130,7 +130,6 @@ def __init__(self, model_objects=None): self.flushed = False self.model_objects = model_objects or defaultdict(dict) self.created_objects = [] - self.current = self def expunge_all(self): self.expunged_all = True diff --git a/lib/galaxy/model/base.py b/lib/galaxy/model/base.py index b9e81bdfe017..eedebec80053 100644 --- a/lib/galaxy/model/base.py +++ b/lib/galaxy/model/base.py @@ -59,9 +59,6 @@ def __init__(self, model_modules, engine): self._SessionLocal = sessionmaker(autoflush=False) versioned_session(self._SessionLocal) context = scoped_session(self._SessionLocal, scopefunc=self.request_scopefunc) - # For backward compatibility with "context.current" - # deprecated? - context.current = context self.session = context self.scoped_registry = context.registry diff --git a/lib/tool_shed/managers/repositories.py b/lib/tool_shed/managers/repositories.py index 2a9b0e36e270..5e21ba862674 100644 --- a/lib/tool_shed/managers/repositories.py +++ b/lib/tool_shed/managers/repositories.py @@ -211,7 +211,7 @@ def index_tool_ids(app: ToolShedApp, tool_ids: List[str]) -> Dict[str, Any]: owner = repository.user.username name = repository.name assert name - repository = _get_repository_by_name_and_owner(app.model.session().current, name, owner, app.model.User) + repository = _get_repository_by_name_and_owner(app.model.session, name, owner, app.model.User) if not repository: log.warning(f"Repository {owner}/{name} does not exist, skipping") continue diff --git a/scripts/cleanup_datasets/cleanup_datasets.py b/scripts/cleanup_datasets/cleanup_datasets.py index 5cd7330ea7ea..4a04f2ad90f6 100755 --- a/scripts/cleanup_datasets/cleanup_datasets.py +++ b/scripts/cleanup_datasets/cleanup_datasets.py @@ -706,7 +706,7 @@ def sa_session(self): session from the threadlocal session context, but this is provided to allow migration toward a more SQLAlchemy 0.4 style of use. """ - return self.model.context.current + return self.model.context def shutdown(self): self.object_store.shutdown() diff --git a/scripts/grt/export.py b/scripts/grt/export.py index aaa0da2d502d..12a117777934 100644 --- a/scripts/grt/export.py +++ b/scripts/grt/export.py @@ -124,7 +124,7 @@ def annotate(label, human_label=None): # Galaxy overrides our logging level. logging.getLogger().setLevel(getattr(logging, args.loglevel.upper())) - sa_session = model.context.current + sa_session = model.context annotate("galaxy_end") # Fetch jobs COMPLETED with status OK that have not yet been sent. diff --git a/scripts/set_dataset_sizes.py b/scripts/set_dataset_sizes.py index 0e235af24e1a..169e70aec979 100644 --- a/scripts/set_dataset_sizes.py +++ b/scripts/set_dataset_sizes.py @@ -31,7 +31,7 @@ def init(): if __name__ == "__main__": print("Loading Galaxy model...") model, object_store = init() - sa_session = model.context.current + sa_session = model.context session = sa_session() set = 0 diff --git a/scripts/set_user_disk_usage.py b/scripts/set_user_disk_usage.py index da42f7a55902..43c1f2a9fe01 100755 --- a/scripts/set_user_disk_usage.py +++ b/scripts/set_user_disk_usage.py @@ -70,7 +70,7 @@ def quotacheck(sa_session, users, engine, object_store): if __name__ == "__main__": print("Loading Galaxy model...") model, object_store, engine = init() - sa_session = model.context.current + sa_session = model.context if not args.username and not args.email: user_count = sa_session.query(model.User).count() diff --git a/scripts/tool_shed/deprecate_repositories_without_metadata.py b/scripts/tool_shed/deprecate_repositories_without_metadata.py index badb49ad3725..c5112411670f 100644 --- a/scripts/tool_shed/deprecate_repositories_without_metadata.py +++ b/scripts/tool_shed/deprecate_repositories_without_metadata.py @@ -205,7 +205,7 @@ def sa_session(self): session from the threadlocal session context, but this is provided to allow migration toward a more SQLAlchemy 0.4 style of use. """ - return self.model.context.current + return self.model.context def shutdown(self): pass diff --git a/test/integration/test_save_job_id_on_datasets.py b/test/integration/test_save_job_id_on_datasets.py index 99685b4ed034..12d681e3fb62 100644 --- a/test/integration/test_save_job_id_on_datasets.py +++ b/test/integration/test_save_job_id_on_datasets.py @@ -39,7 +39,7 @@ def test_driver(): @pytest.mark.parametrize("tool_id", TEST_TOOL_IDS) def test_tool_datasets(tool_id, test_driver): test_driver.run_tool_test(tool_id) - session = test_driver.app.model.context.current + session = test_driver.app.model.context job = session.scalars(select(model.Job).order_by(model.Job.id.desc()).limit(1)).first() datasets = session.scalars(select(model.Dataset).filter(model.Dataset.job_id == job.id)).all() From 0f1bfb6839d9a62036522afc5fe95680551243e0 Mon Sep 17 00:00:00 2001 From: Nicola Soranzo Date: Mon, 10 Feb 2025 15:30:39 +0000 Subject: [PATCH 3/5] Type annotation improvements --- lib/galaxy/app.py | 3 +- lib/galaxy/celery/base_task.py | 10 +- lib/galaxy/managers/context.py | 5 +- lib/galaxy/managers/histories.py | 10 +- lib/galaxy/managers/history_contents.py | 4 +- lib/galaxy/managers/job_connections.py | 8 +- lib/galaxy/managers/model_stores.py | 1 + lib/galaxy/managers/pages.py | 18 +- lib/galaxy/managers/roles.py | 2 +- lib/galaxy/managers/secured.py | 6 +- lib/galaxy/managers/sharable.py | 5 +- lib/galaxy/managers/users.py | 2 +- lib/galaxy/managers/visualizations.py | 12 +- lib/galaxy/managers/workflows.py | 11 +- lib/galaxy/metadata/set_metadata.py | 2 + lib/galaxy/model/__init__.py | 9 +- lib/galaxy/model/base.py | 12 +- lib/galaxy/model/item_attrs.py | 3 +- lib/galaxy/model/mapping.py | 3 - lib/galaxy/model/metadata.py | 5 +- lib/galaxy/model/store/__init__.py | 12 +- lib/galaxy/model/tags.py | 46 ++--- .../galaxy_install/install_manager.py | 4 +- .../repository_dependency_manager.py | 5 +- lib/galaxy/tool_shed/util/repository_util.py | 165 ++++++------------ lib/galaxy/tool_shed/util/shed_util_common.py | 8 +- lib/galaxy/tools/parameters/wrapped.py | 4 +- lib/galaxy/tools/remote_tool_eval.py | 1 + lib/galaxy/webapps/galaxy/api/users.py | 1 + .../webapps/galaxy/services/histories.py | 2 +- lib/galaxy/webapps/galaxy/services/pages.py | 7 +- .../webapps/galaxy/services/sharable.py | 2 +- .../webapps/galaxy/services/visualizations.py | 2 +- lib/galaxy/webapps/reports/app.py | 2 +- .../dependencies/attribute_handlers.py | 4 +- .../repository/relation_builder.py | 104 ++++++----- lib/tool_shed/grids/repository_grids.py | 6 +- lib/tool_shed/managers/repositories.py | 10 +- .../metadata/repository_metadata_manager.py | 3 +- lib/tool_shed/tools/tool_validator.py | 9 + lib/tool_shed/tools/tool_version_manager.py | 6 +- lib/tool_shed/util/metadata_util.py | 19 +- lib/tool_shed/util/repository_util.py | 144 +++++++-------- lib/tool_shed/util/shed_util_common.py | 4 +- lib/tool_shed/webapp/api/repositories.py | 5 +- .../webapp/api/repository_revisions.py | 8 +- lib/tool_shed/webapp/app.py | 2 +- lib/tool_shed/webapp/controllers/hg.py | 4 +- .../webapp/controllers/repository.py | 42 +++-- lib/tool_shed/webapp/model/__init__.py | 21 ++- lib/tool_shed/webapp/model/db/__init__.py | 36 ++++ .../webapps/tool_shed/repository/common.mako | 4 +- scripts/grt/export.py | 8 +- scripts/set_dataset_sizes.py | 17 +- scripts/set_user_disk_usage.py | 6 +- ...deprecate_repositories_without_metadata.py | 30 ++-- .../webapps/tool_shed/repository/common.mako | 4 +- test/integration/oidc/test_auth_oidc.py | 7 +- test/integration/test_kubernetes_runner.py | 5 +- .../test_page_revision_json_encoding.py | 2 + test/integration/test_remote_files_posix.py | 1 + .../test_workflow_handler_configuration.py | 4 +- test/unit/app/authnz/test_custos_authnz.py | 11 +- test/unit/app/jobs/test_job_wrapper.py | 6 +- test/unit/app/jobs/test_runner_local.py | 11 +- .../managers/test_JobConnectionsManager.py | 10 +- .../app/managers/test_user_file_sources.py | 2 +- .../app/managers/test_user_object_stores.py | 5 +- test/unit/data/model/test_model_store.py | 2 + test/unit/data/test_galaxy_mapping.py | 40 ++--- test/unit/data/test_quota.py | 57 +++--- .../unit/tool_shed/test_tool_panel_manager.py | 7 +- 72 files changed, 575 insertions(+), 483 deletions(-) create mode 100644 lib/tool_shed/webapp/model/db/__init__.py diff --git a/lib/galaxy/app.py b/lib/galaxy/app.py index a576adaec40b..d1b9aa8d46ec 100644 --- a/lib/galaxy/app.py +++ b/lib/galaxy/app.py @@ -608,7 +608,8 @@ def __init__(self, configure_logging=True, use_converters=True, use_display_appl self._register_singleton(ShortTermStorageMonitor, short_term_storage_manager) # type: ignore[type-abstract] # Tag handler - self.tag_handler = self._register_singleton(GalaxyTagHandler) + tag_handler = GalaxyTagHandler(self.model.context) + self.tag_handler = self._register_singleton(GalaxyTagHandler, tag_handler) self.user_manager = self._register_singleton(UserManager) self._register_singleton(GalaxySessionManager) self.hda_manager = self._register_singleton(HDAManager) diff --git a/lib/galaxy/celery/base_task.py b/lib/galaxy/celery/base_task.py index 9ab36837c1d8..af5d5accdc94 100644 --- a/lib/galaxy/celery/base_task.py +++ b/lib/galaxy/celery/base_task.py @@ -11,9 +11,9 @@ ) from sqlalchemy.dialects.postgresql import insert as ps_insert from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import scoped_session from galaxy.model import CeleryUserRateLimit -from galaxy.model.scoped_session import galaxy_scoped_session class GalaxyTaskBeforeStart: @@ -45,7 +45,7 @@ class GalaxyTaskBeforeStartUserRateLimit(GalaxyTaskBeforeStart): def __init__( self, tasks_per_user_per_sec: float, - ga_scoped_session: galaxy_scoped_session, + ga_scoped_session: scoped_session, ): try: self.task_exec_countdown_secs = 1 / tasks_per_user_per_sec @@ -68,7 +68,7 @@ def __call__(self, task: Task, task_id, args, kwargs): @abstractmethod def calculate_task_start_time( - self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime + self, user_id: int, sa_session: scoped_session, task_interval_secs: float, now: datetime.datetime ) -> datetime.datetime: return now @@ -80,7 +80,7 @@ class GalaxyTaskBeforeStartUserRateLimitPostgres(GalaxyTaskBeforeStartUserRateLi """ def calculate_task_start_time( - self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime + self, user_id: int, sa_session: scoped_session, task_interval_secs: float, now: datetime.datetime ) -> datetime.datetime: update_stmt = ( update(CeleryUserRateLimit) @@ -125,7 +125,7 @@ class GalaxyTaskBeforeStartUserRateLimitStandard(GalaxyTaskBeforeStartUserRateLi ) def calculate_task_start_time( - self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime + self, user_id: int, sa_session: scoped_session, task_interval_secs: float, now: datetime.datetime ) -> datetime.datetime: last_scheduled_time = None last_scheduled_time = sa_session.scalars(self._select_stmt, {"userid": user_id}).first() diff --git a/lib/galaxy/managers/context.py b/lib/galaxy/managers/context.py index eef80a5fe675..8446c27e5b86 100644 --- a/lib/galaxy/managers/context.py +++ b/lib/galaxy/managers/context.py @@ -63,7 +63,6 @@ User, ) from galaxy.model.base import ModelMapping -from galaxy.model.scoped_session import galaxy_scoped_session from galaxy.model.tags import GalaxyTagHandlerSession from galaxy.schema.tasks import RequestUser from galaxy.security.idencoding import IdEncodingHelper @@ -155,10 +154,10 @@ def log_event(self, message, tool_id=None, **kwargs): self.sa_session.commit() @property - def sa_session(self) -> galaxy_scoped_session: + def sa_session(self): """Provide access to Galaxy's SQLAlchemy session. - :rtype: galaxy.model.scoped_session.galaxy_scoped_session + :rtype: sqlalchemy.orm.scoped_session """ return self.app.model.session diff --git a/lib/galaxy/managers/histories.py b/lib/galaxy/managers/histories.py index 8ccd80b88b27..f074e8a64788 100644 --- a/lib/galaxy/managers/histories.py +++ b/lib/galaxy/managers/histories.py @@ -14,6 +14,7 @@ Optional, Set, Tuple, + TYPE_CHECKING, Union, ) @@ -85,6 +86,9 @@ RawTextTerm, ) +if TYPE_CHECKING: + from sqlalchemy.engine import ScalarResult + log = logging.getLogger(__name__) INDEX_SEARCH_FILTERS = { @@ -95,7 +99,7 @@ } -class HistoryManager(sharable.SharableModelManager, deletable.PurgableManagerMixin, SortableManager): +class HistoryManager(sharable.SharableModelManager[model.History], deletable.PurgableManagerMixin, SortableManager): model_class = model.History foreign_key_name = "history" user_share_model = model.HistoryUserShareAssociation @@ -120,7 +124,7 @@ def __init__( def index_query( self, trans: ProvidesUserContext, payload: HistoryIndexQueryPayload, include_total_count: bool = False - ) -> Tuple[List[model.History], int]: + ) -> Tuple["ScalarResult[model.History]", Union[int, None]]: show_deleted = False show_own = payload.show_own show_published = payload.show_published @@ -234,7 +238,7 @@ def p_tag_filter(term_text: str, quoted: bool): stmt = stmt.limit(payload.limit) if payload.offset is not None: stmt = stmt.offset(payload.offset) - return trans.sa_session.scalars(stmt), total_matches # type:ignore[return-value] + return trans.sa_session.scalars(stmt), total_matches # .... sharable # overriding to handle anonymous users' current histories in both cases diff --git a/lib/galaxy/managers/history_contents.py b/lib/galaxy/managers/history_contents.py index 0db49a4536e8..f588a083e82b 100644 --- a/lib/galaxy/managers/history_contents.py +++ b/lib/galaxy/managers/history_contents.py @@ -168,8 +168,8 @@ def state_counts(self, history): statement: Select = ( select(sql.column("state"), func.count()).select_from(contents_subquery).group_by(sql.column("state")) ) - counts = self.app.model.session().execute(statement).fetchall() - return dict(counts) + counts = self.app.model.session.execute(statement).fetchall() + return dict(counts) # type:ignore[arg-type] def active_counts(self, history): """ diff --git a/lib/galaxy/managers/job_connections.py b/lib/galaxy/managers/job_connections.py index c3317a0e8a52..f90f3c7c849b 100644 --- a/lib/galaxy/managers/job_connections.py +++ b/lib/galaxy/managers/job_connections.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + from sqlalchemy import ( literal, union, @@ -9,13 +11,15 @@ from galaxy import model from galaxy.managers.base import get_class -from galaxy.model.scoped_session import galaxy_scoped_session + +if TYPE_CHECKING: + from sqlalchemy.orm import scoped_session class JobConnectionsManager: """Get connections graph of inputs and outputs for given item""" - def __init__(self, sa_session: galaxy_scoped_session): + def __init__(self, sa_session: "scoped_session"): self.sa_session = sa_session def get_connections_graph(self, id: int, src: str): diff --git a/lib/galaxy/managers/model_stores.py b/lib/galaxy/managers/model_stores.py index c5c8a93a1455..1f397a5c16e9 100644 --- a/lib/galaxy/managers/model_stores.py +++ b/lib/galaxy/managers/model_stores.py @@ -317,6 +317,7 @@ def import_model_store(self, request: ImportModelStoreTaskRequest): def _build_user_context(self, user_id: int): user = self._user_manager.by_id(user_id) + assert user is not None user_context = ModelStoreUserContext(self._app, user) return user_context diff --git a/lib/galaxy/managers/pages.py b/lib/galaxy/managers/pages.py index 7913fa30375a..6c13d183e558 100644 --- a/lib/galaxy/managers/pages.py +++ b/lib/galaxy/managers/pages.py @@ -14,9 +14,10 @@ Callable, Optional, Tuple, + TYPE_CHECKING, + Union, ) -import sqlalchemy from sqlalchemy import ( desc, false, @@ -76,6 +77,9 @@ RawTextTerm, ) +if TYPE_CHECKING: + from sqlalchemy.engine import ScalarResult + log = logging.getLogger(__name__) # Copied from https://github.com/kurtmckee/feedparser @@ -121,7 +125,7 @@ } -class PageManager(sharable.SharableModelManager, UsesAnnotations): +class PageManager(sharable.SharableModelManager[model.Page], UsesAnnotations): """Provides operations for managing a Page.""" model_class = model.Page @@ -139,7 +143,7 @@ def __init__(self, app: MinimalManagerApp): def index_query( self, trans: ProvidesUserContext, payload: PageIndexQueryPayload, include_total_count: bool = False - ) -> Tuple[sqlalchemy.engine.Result, int]: + ) -> Tuple["ScalarResult[model.Page]", Union[int, None]]: show_deleted = payload.deleted show_own = payload.show_own show_published = payload.show_published @@ -242,7 +246,7 @@ def p_tag_filter(term_text: str, quoted: bool): stmt = stmt.limit(payload.limit) if payload.offset is not None: stmt = stmt.offset(payload.offset) - return trans.sa_session.scalars(stmt), total_matches # type:ignore[return-value] + return trans.sa_session.scalars(stmt), total_matches def create_page(self, trans, payload: CreatePagePayload): user = trans.get_user() @@ -269,7 +273,7 @@ def create_page(self, trans, payload: CreatePagePayload): content = self.rewrite_content_for_import(trans, content, content_format) # Create the new stored page - page = trans.app.model.Page() + page = model.Page() page.title = payload.title page.slug = payload.slug if (page_annotation := payload.annotation) is not None: @@ -278,7 +282,7 @@ def create_page(self, trans, payload: CreatePagePayload): page.user = user # And the first (empty) page revision - page_revision = trans.app.model.PageRevision() + page_revision = model.PageRevision() page_revision.title = payload.title page_revision.page = page page.latest_revision = page_revision @@ -310,7 +314,7 @@ def save_new_revision(self, trans, page, payload): content_format = page.latest_revision.content_format content = self.rewrite_content_for_import(trans, content, content_format=content_format) - page_revision = trans.app.model.PageRevision() + page_revision = model.PageRevision() page_revision.title = title page_revision.page = page page.latest_revision = page_revision diff --git a/lib/galaxy/managers/roles.py b/lib/galaxy/managers/roles.py index 8d15c9ee0368..0de7e9f02aa4 100644 --- a/lib/galaxy/managers/roles.py +++ b/lib/galaxy/managers/roles.py @@ -128,7 +128,7 @@ def purge(self, trans: ProvidesUserContext, role: model.Role) -> model.Role: raise RequestParameterInvalidException(f"Role '{role.name}' has not been deleted, so it cannot be purged.") # Delete UserRoleAssociations for ura in role.users: - user = sa_session.get(trans.app.model.User, ura.user_id) + user = sa_session.get(model.User, ura.user_id) assert user # Delete DefaultUserPermissions for associated users for dup in user.default_permissions: diff --git a/lib/galaxy/managers/secured.py b/lib/galaxy/managers/secured.py index 64721d8b137e..479337fbe3ca 100644 --- a/lib/galaxy/managers/secured.py +++ b/lib/galaxy/managers/secured.py @@ -33,14 +33,14 @@ class AccessibleManagerMixin: def by_id(self, id: int): ... # don't want to override by_id since consumers will also want to fetch w/o any security checks - def is_accessible(self, item, user: model.User, **kwargs: Any) -> bool: + def is_accessible(self, item, user: Optional[model.User], **kwargs: Any) -> bool: """ Return True if the item accessible to user. """ # override in subclasses raise exceptions.NotImplemented("Abstract interface Method") - def get_accessible(self, id: int, user: model.User, **kwargs: Any): + def get_accessible(self, id: int, user: Optional[model.User], **kwargs: Any): """ Return the item with the given id if it's accessible to user, otherwise raise an error. @@ -50,7 +50,7 @@ def get_accessible(self, id: int, user: model.User, **kwargs: Any): item = self.by_id(id) return self.error_unless_accessible(item, user, **kwargs) - def error_unless_accessible(self, item: "Query", user, **kwargs): + def error_unless_accessible(self, item: "Query", user: Optional[model.User], **kwargs): """ Raise an error if the item is NOT accessible to user, otherwise return the item. diff --git a/lib/galaxy/managers/sharable.py b/lib/galaxy/managers/sharable.py index e45332380877..bcc0d2489746 100644 --- a/lib/galaxy/managers/sharable.py +++ b/lib/galaxy/managers/sharable.py @@ -17,6 +17,7 @@ Optional, Set, Type, + TypeVar, ) from slugify import slugify @@ -54,10 +55,12 @@ from galaxy.util.hash_util import md5_hash_str log = logging.getLogger(__name__) +# Only model classes that have `users_shared_with` field +U = TypeVar("U", model.History, model.Page, model.StoredWorkflow, model.Visualization) class SharableModelManager( - base.ModelManager, + base.ModelManager[U], secured.OwnableManagerMixin, secured.AccessibleManagerMixin, annotatable.AnnotatableManagerMixin, diff --git a/lib/galaxy/managers/users.py b/lib/galaxy/managers/users.py index 6027d1cba007..25d3f6463a9a 100644 --- a/lib/galaxy/managers/users.py +++ b/lib/galaxy/managers/users.py @@ -278,7 +278,7 @@ def _error_on_duplicate_email(self, email: str) -> None: if self.by_email(email) is not None: raise exceptions.Conflict("Email must be unique", email=email) - def by_id(self, user_id: int) -> model.User: + def by_id(self, user_id: int) -> Optional[model.User]: return self.app.model.session.get(self.model_class, user_id) # ---- filters diff --git a/lib/galaxy/managers/visualizations.py b/lib/galaxy/managers/visualizations.py index 80fbf13348c8..03c36b9d0c24 100644 --- a/lib/galaxy/managers/visualizations.py +++ b/lib/galaxy/managers/visualizations.py @@ -8,8 +8,9 @@ import logging from typing import ( Dict, - List, Tuple, + TYPE_CHECKING, + Union, ) from sqlalchemy import ( @@ -44,6 +45,9 @@ RawTextTerm, ) +if TYPE_CHECKING: + from sqlalchemy.engine import ScalarResult + log = logging.getLogger(__name__) @@ -59,7 +63,7 @@ } -class VisualizationManager(sharable.SharableModelManager): +class VisualizationManager(sharable.SharableModelManager[model.Visualization]): """ Handle operations outside and between visualizations and other models. """ @@ -76,7 +80,7 @@ class VisualizationManager(sharable.SharableModelManager): def index_query( self, trans: ProvidesUserContext, payload: VisualizationIndexQueryPayload, include_total_count: bool = False - ) -> Tuple[List[model.Visualization], int]: + ) -> Tuple["ScalarResult[model.Visualization]", Union[int, None]]: show_deleted = payload.deleted show_own = payload.show_own show_published = payload.show_published @@ -171,7 +175,7 @@ def p_tag_filter(term_text: str, quoted: bool): stmt = stmt.limit(payload.limit) if payload.offset is not None: stmt = stmt.offset(payload.offset) - return trans.sa_session.scalars(stmt), total_matches # type:ignore[return-value] + return trans.sa_session.scalars(stmt), total_matches class VisualizationSerializer(sharable.SharableModelSerializer): diff --git a/lib/galaxy/managers/workflows.py b/lib/galaxy/managers/workflows.py index 60680798057b..1f5ec656cce3 100644 --- a/lib/galaxy/managers/workflows.py +++ b/lib/galaxy/managers/workflows.py @@ -10,10 +10,10 @@ NamedTuple, Optional, Tuple, + TYPE_CHECKING, Union, ) -import sqlalchemy import yaml from gxformat2 import ( from_galaxy_native, @@ -140,6 +140,9 @@ ) from galaxy.workflow.trs_proxy import TrsProxy +if TYPE_CHECKING: + from sqlalchemy.engine import ScalarResult + log = logging.getLogger(__name__) @@ -154,7 +157,7 @@ } -class WorkflowsManager(sharable.SharableModelManager, deletable.DeletableManagerMixin): +class WorkflowsManager(sharable.SharableModelManager[model.StoredWorkflow], deletable.DeletableManagerMixin): """Handle CRUD type operations related to workflows. More interesting stuff regarding workflow execution, step sorting, etc... can be found in the galaxy.workflow module. @@ -170,7 +173,7 @@ def __init__(self, app: MinimalManagerApp): def index_query( self, trans: ProvidesUserContext, payload: WorkflowIndexQueryPayload, include_total_count: bool = False - ) -> Tuple[sqlalchemy.engine.Result, Optional[int]]: + ) -> Tuple["ScalarResult[model.StoredWorkflow]", Optional[int]]: show_published = payload.show_published show_hidden = payload.show_hidden show_deleted = payload.show_deleted @@ -291,7 +294,7 @@ def name_filter(term): if payload.offset is not None: stmt = stmt.offset(payload.offset) result = trans.sa_session.scalars(stmt).unique() - return result, total_matches # type:ignore[return-value] + return result, total_matches def get_stored_workflow(self, trans, workflow_id, by_stored_id=True) -> StoredWorkflow: """Use a supplied ID (UUID or encoded stored workflow ID) to find diff --git a/lib/galaxy/metadata/set_metadata.py b/lib/galaxy/metadata/set_metadata.py index 38f3f7917729..d53bf2306c94 100644 --- a/lib/galaxy/metadata/set_metadata.py +++ b/lib/galaxy/metadata/set_metadata.py @@ -56,6 +56,7 @@ ) from galaxy.model.custom_types import total_size from galaxy.model.metadata import MetadataTempFile +from galaxy.model.store import SessionlessContext from galaxy.model.store.discover import MaxDiscoveredFilesExceededError from galaxy.objectstore import ( build_object_store_from_config, @@ -303,6 +304,7 @@ def set_meta(new_dataset_instance, file_dict): import_model_store = store.imported_store_for_metadata( tool_job_working_directory / "metadata/outputs_new", object_store=object_store ) + assert isinstance(import_model_store.sa_session, SessionlessContext) tool_script_file = tool_job_working_directory / "tool_script.sh" job: Optional[Job] = None diff --git a/lib/galaxy/model/__init__.py b/lib/galaxy/model/__init__.py index 119454c93963..d13ab2f8cbfc 100644 --- a/lib/galaxy/model/__init__.py +++ b/lib/galaxy/model/__init__.py @@ -138,7 +138,6 @@ import galaxy.exceptions import galaxy.model.metadata -import galaxy.model.tags import galaxy.security.passwords import galaxy.util from galaxy.files.templates import ( @@ -6111,6 +6110,8 @@ def __init__( def to_history_dataset_association( self, target_history, parent_id=None, add_to_history=False, visible=None, commit=True ): + from galaxy.model.tags import GalaxyTagHandler + sa_session = object_session(self) hda = HistoryDatasetAssociation( name=self.name, @@ -6128,7 +6129,7 @@ def to_history_dataset_association( history=target_history, ) - tag_manager = galaxy.model.tags.GalaxyTagHandler(sa_session) + tag_manager = GalaxyTagHandler(sa_session) src_ldda_tags = tag_manager.get_tags_str(self.tags) tag_manager.apply_item_tags(user=self.user, item=hda, tags_str=src_ldda_tags, flush=False) sa_session.add(hda) @@ -6142,6 +6143,8 @@ def to_history_dataset_association( return hda def copy(self, parent_id=None, target_folder=None, flush=True): + from galaxy.model.tags import GalaxyTagHandler + sa_session = object_session(self) ldda = LibraryDatasetDatasetAssociation( name=self.name, @@ -6159,7 +6162,7 @@ def copy(self, parent_id=None, target_folder=None, flush=True): folder=target_folder, ) - tag_manager = galaxy.model.tags.GalaxyTagHandler(sa_session) + tag_manager = GalaxyTagHandler(sa_session) src_ldda_tags = tag_manager.get_tags_str(self.tags) tag_manager.apply_item_tags(user=self.user, item=ldda, tags_str=src_ldda_tags) diff --git a/lib/galaxy/model/base.py b/lib/galaxy/model/base.py index eedebec80053..06448c76ee5e 100644 --- a/lib/galaxy/model/base.py +++ b/lib/galaxy/model/base.py @@ -11,8 +11,10 @@ getmembers, isclass, ) +from types import ModuleType from typing import ( Dict, + List, Type, Union, ) @@ -54,7 +56,7 @@ def check_database_connection(session): # TODO: Refactor this to be a proper class, not a bunch. class ModelMapping(Bunch): - def __init__(self, model_modules, engine): + def __init__(self, model_modules: List[ModuleType], engine): self.engine = engine self._SessionLocal = sessionmaker(autoflush=False) versioned_session(self._SessionLocal) @@ -62,11 +64,11 @@ def __init__(self, model_modules, engine): self.session = context self.scoped_registry = context.registry - model_classes = {} + model_classes: Dict[str, type] = {} for module in model_modules: - m_obs = getmembers(module, isclass) - m_obs = dict([m for m in m_obs if m[1].__module__ == module.__name__]) - model_classes.update(m_obs) + name_class_pairs = getmembers(module, isclass) + filtered_module_classes_dict = dict(m for m in name_class_pairs if m[1].__module__ == module.__name__) + model_classes.update(filtered_module_classes_dict) super().__init__(**model_classes) diff --git a/lib/galaxy/model/item_attrs.py b/lib/galaxy/model/item_attrs.py index 4e81dda02fea..d816ae832741 100644 --- a/lib/galaxy/model/item_attrs.py +++ b/lib/galaxy/model/item_attrs.py @@ -5,6 +5,7 @@ # Cannot import galaxy.model b/c it creates a circular import graph. import galaxy +from galaxy.util import unicodify log = logging.getLogger(__name__) @@ -147,7 +148,7 @@ def get_item_annotation_str(db_session, user, item): else: annotation_obj = get_item_annotation_obj(db_session, user, item) if annotation_obj: - return galaxy.util.unicodify(annotation_obj.annotation) + return unicodify(annotation_obj.annotation) return None diff --git a/lib/galaxy/model/mapping.py b/lib/galaxy/model/mapping.py index 707b20b7ca2f..c2a059c0582d 100644 --- a/lib/galaxy/model/mapping.py +++ b/lib/galaxy/model/mapping.py @@ -2,7 +2,6 @@ from threading import local from typing import ( Optional, - Type, TYPE_CHECKING, ) @@ -28,8 +27,6 @@ class GalaxyModelMapping(SharedModelMapping): security_agent: GalaxyRBACAgent thread_local_log: Optional[local] - User: Type - GalaxySession: Type def init( diff --git a/lib/galaxy/model/metadata.py b/lib/galaxy/model/metadata.py index 49e827416e70..1213b66bf09a 100644 --- a/lib/galaxy/model/metadata.py +++ b/lib/galaxy/model/metadata.py @@ -25,7 +25,6 @@ from sqlalchemy.orm.attributes import flag_modified import galaxy.model -from galaxy.model.scoped_session import galaxy_scoped_session from galaxy.security.object_wrapper import sanitize_lists_to_string from galaxy.util import ( form_builder, @@ -37,6 +36,8 @@ from galaxy.util.json import safe_dumps if TYPE_CHECKING: + from sqlalchemy.orm import scoped_session + from galaxy.model import DatasetInstance from galaxy.model.none_like import NoneDataset from galaxy.model.store import SessionlessContext @@ -84,7 +85,7 @@ class MetadataCollection(Mapping): def __init__( self, parent: Union["DatasetInstance", "NoneDataset"], - session: Optional[Union[galaxy_scoped_session, "SessionlessContext"]] = None, + session: Optional[Union["scoped_session", "SessionlessContext"]] = None, ) -> None: self.parent = parent self._session = session diff --git a/lib/galaxy/model/store/__init__.py b/lib/galaxy/model/store/__init__.py index a8b515cc5c29..9018dfcf03b0 100644 --- a/lib/galaxy/model/store/__init__.py +++ b/lib/galaxy/model/store/__init__.py @@ -233,7 +233,7 @@ def flush(self) -> None: def add(self, obj: model.RepresentById) -> None: self.objects[obj.__class__][obj.id] = obj - def query(self, model_class: model.RepresentById) -> Bunch: + def query(self, model_class: Type[model.RepresentById]) -> Bunch: def find(obj_id): return self.objects.get(model_class, {}).get(obj_id) or None @@ -243,7 +243,7 @@ def filter_by(*args, **kwargs): return Bunch(find=find, get=find, filter_by=filter_by) - def get(self, model_class: model.RepresentById, primary_key: Any): # patch for SQLAlchemy 2.0 compatibility + def get(self, model_class: Type[model.RepresentById], primary_key: Any): # patch for SQLAlchemy 2.0 compatibility return self.query(model_class).get(primary_key) @@ -265,6 +265,7 @@ def remap_objects(p, k, obj): class ModelImportStore(metaclass=abc.ABCMeta): app: Optional[StoreAppProtocol] archive_dir: str + sa_session: Union[scoped_session, SessionlessContext] def __init__( self, @@ -494,7 +495,8 @@ def handle_dataset_object_edit(dataset_instance, dataset_attrs): if "id" in dataset_attrs and self.import_options.allow_edit and not self.sessionless: model_class = getattr(model, dataset_attrs["model_class"]) - dataset_instance: model.DatasetInstance = self.sa_session.get(model_class, dataset_attrs["id"]) + dataset_instance = self.sa_session.get(model_class, dataset_attrs["id"]) + assert isinstance(dataset_instance, model.DatasetInstance) attributes = [ "name", "extension", @@ -876,6 +878,7 @@ def materialize_elements(dc): dc = import_collection(collection_attrs["collection"]) if "id" in collection_attrs and self.import_options.allow_edit and not self.sessionless: hdca = self.sa_session.get(model.HistoryDatasetCollectionAssociation, collection_attrs["id"]) + assert hdca is not None # TODO: edit attributes... else: hdca = model.HistoryDatasetCollectionAssociation( @@ -2209,7 +2212,8 @@ def export_history( ) datasets = sa_session.scalars(stmt_hda).unique() for dataset in datasets: - dataset.annotation = get_item_annotation_str(sa_session, history.user, dataset) + # Add a new "annotation" attribute so that the user annotation for the dataset can be serialized without needing the user + dataset.annotation = get_item_annotation_str(sa_session, history.user, dataset) # type: ignore[attr-defined] should_include_file = (dataset.visible or include_hidden) and (not dataset.deleted or include_deleted) if not dataset.deleted and dataset.id in self.collection_datasets: should_include_file = True diff --git a/lib/galaxy/model/tags.py b/lib/galaxy/model/tags.py index 577fbc5b048d..cc1753419c2e 100644 --- a/lib/galaxy/model/tags.py +++ b/lib/galaxy/model/tags.py @@ -9,24 +9,26 @@ ) from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import ( + scoped_session, + sessionmaker, +) from sqlalchemy.sql import select from sqlalchemy.sql.expression import func import galaxy.model from galaxy.exceptions import ItemOwnershipException -from galaxy.model.scoped_session import galaxy_scoped_session +from galaxy.model import ( + GalaxySession, + Tag, +) from galaxy.util import ( strip_control_characters, unicodify, ) if TYPE_CHECKING: - from galaxy.model import ( - GalaxySession, - Tag, - User, - ) + from galaxy.model import User log = logging.getLogger(__name__) @@ -44,7 +46,7 @@ class TagHandler: Manages CRUD operations related to tagging objects. """ - def __init__(self, sa_session: galaxy_scoped_session, galaxy_session=None) -> None: + def __init__(self, sa_session: scoped_session, galaxy_session: Optional[GalaxySession] = None) -> None: self.sa_session = sa_session # Minimum tag length. self.min_tag_len = 1 @@ -58,11 +60,9 @@ def __init__(self, sa_session: galaxy_scoped_session, galaxy_session=None) -> No self.key_value_separators = "=:" # Initialize with known classes - add to this in subclasses. self.item_tag_assoc_info: Dict[str, ItemTagAssocInfo] = {} - # Can't include type annotation in signature, because lagom will attempt to look up - # GalaxySession, but can't find it due to the circular import - self.galaxy_session: Optional[GalaxySession] = galaxy_session + self.galaxy_session = galaxy_session - def create_tag_handler_session(self, galaxy_session: Optional["GalaxySession"]): + def create_tag_handler_session(self, galaxy_session: Optional[GalaxySession]): # Creates a transient tag handler that avoids repeated flushes return GalaxyTagHandlerSession(self.sa_session, galaxy_session=galaxy_session) @@ -112,7 +112,7 @@ def get_community_tags(self, item=None, limit=None): if not item_tag_assoc_class: return [] # Build select statement. - from_obj = item_tag_assoc_class.table.join(item_class.table).join(galaxy.model.Tag.table) + from_obj = item_tag_assoc_class.table.join(item_class.table).join(Tag.table) where_clause = self.get_id_col_in_item_tag_assoc_table(item_class) == item.id group_by = item_tag_assoc_class.table.c.tag_id # Do query and get result set. @@ -196,7 +196,7 @@ def item_has_tag(self, user, item, tag): tag_name = None if isinstance(tag, str): tag_name = tag - elif isinstance(tag, galaxy.model.Tag): + elif isinstance(tag, Tag): tag_name = tag.name elif isinstance(tag, galaxy.model.ItemTagAssociation): tag_name = tag.user_tname @@ -285,12 +285,12 @@ def get_tags_str(self, tags): def get_tag_by_id(self, tag_id): """Get a Tag object from a tag id.""" - return self.sa_session.get(galaxy.model.Tag, tag_id) + return self.sa_session.get(Tag, tag_id) def get_tag_by_name(self, tag_name): """Get a Tag object from a tag name (string).""" if tag_name: - return self.sa_session.scalars(select(galaxy.model.Tag).filter_by(name=tag_name.lower()).limit(1)).first() + return self.sa_session.scalars(select(Tag).filter_by(name=tag_name.lower()).limit(1)).first() return None def _create_tag(self, tag_str: str): @@ -322,11 +322,11 @@ def _create_tag(self, tag_str: str): return tag def _get_tag(self, tag_name): - return self.sa_session.scalars(select(galaxy.model.Tag).filter_by(name=tag_name).limit(1)).first() + return self.sa_session.scalars(select(Tag).filter_by(name=tag_name).limit(1)).first() def _create_tag_instance(self, tag_name): # For good performance caller should first check if there's already an appropriate tag - tag = galaxy.model.Tag(type=0, name=tag_name) + tag = Tag(type=0, name=tag_name) if not self.sa_session: return tag Session = sessionmaker(self.sa_session.bind) @@ -449,8 +449,8 @@ def _get_name_value_pair(self, tag_str) -> List[Optional[str]]: class GalaxyTagHandler(TagHandler): _item_tag_assoc_info: Dict[str, ItemTagAssocInfo] = {} - def __init__(self, sa_session: galaxy_scoped_session, galaxy_session=None): - TagHandler.__init__(self, sa_session, galaxy_session=galaxy_session) + def __init__(self, sa_session: scoped_session, galaxy_session: Optional[GalaxySession] = None): + super().__init__(sa_session, galaxy_session=galaxy_session) if not GalaxyTagHandler._item_tag_assoc_info: GalaxyTagHandler.init_tag_associations() self.item_tag_assoc_info = GalaxyTagHandler._item_tag_assoc_info @@ -496,7 +496,7 @@ def init_tag_associations(cls): class GalaxyTagHandlerSession(GalaxyTagHandler): """Like GalaxyTagHandler, but avoids one flush per created tag.""" - def __init__(self, sa_session, galaxy_session: Optional["GalaxySession"]): + def __init__(self, sa_session: scoped_session, galaxy_session: Optional[GalaxySession]): super().__init__(sa_session, galaxy_session) self.created_tags: Dict[str, Tag] = {} @@ -527,5 +527,5 @@ def get_tag_by_name(self, tag_name): class CommunityTagHandler(TagHandler): - def __init__(self, sa_session): - TagHandler.__init__(self, sa_session) + def __init__(self, sa_session: scoped_session): + super().__init__(sa_session) diff --git a/lib/galaxy/tool_shed/galaxy_install/install_manager.py b/lib/galaxy/tool_shed/galaxy_install/install_manager.py index 1008044acee2..493b7cced5a7 100644 --- a/lib/galaxy/tool_shed/galaxy_install/install_manager.py +++ b/lib/galaxy/tool_shed/galaxy_install/install_manager.py @@ -800,7 +800,7 @@ def update_tool_shed_repository( ) return (None, None) - def order_components_for_installation(self, tsr_ids, repo_info_dicts, tool_panel_section_keys): + def order_components_for_installation(self, tsr_ids: List[str], repo_info_dicts, tool_panel_section_keys): """ Some repositories may have repository dependencies that are required to be installed before the dependent repository. This method will inspect the list of repositories @@ -820,7 +820,7 @@ def order_components_for_installation(self, tsr_ids, repo_info_dicts, tool_panel prior_install_required_dict = repository_util.get_prior_import_or_install_required_dict( self.app, tsr_ids, repo_info_dicts ) - processed_tsr_ids = [] + processed_tsr_ids: List[str] = [] while len(processed_tsr_ids) != len(prior_install_required_dict.keys()): tsr_id = suc.get_next_prior_import_or_install_required_dict_entry( prior_install_required_dict, processed_tsr_ids diff --git a/lib/galaxy/tool_shed/galaxy_install/repository_dependencies/repository_dependency_manager.py b/lib/galaxy/tool_shed/galaxy_install/repository_dependencies/repository_dependency_manager.py index 812fbb1d00b3..c7ee588626b4 100644 --- a/lib/galaxy/tool_shed/galaxy_install/repository_dependencies/repository_dependency_manager.py +++ b/lib/galaxy/tool_shed/galaxy_install/repository_dependencies/repository_dependency_manager.py @@ -481,9 +481,8 @@ def get_required_repo_info_dicts(self, tool_shed_url, repo_info_dicts): ) encoded_required_repository_str = encoding_util.encoding_sep2.join(encoded_required_repository_tups) encoded_required_repository_str = encoding_util.tool_shed_encode(encoded_required_repository_str) - if repository_util.is_tool_shed_client(self.app): - # Handle secure / insecure Tool Shed URL protocol changes and port changes. - tool_shed_url = common_util.get_tool_shed_url_from_tool_shed_registry(self.app, tool_shed_url) + # Handle secure / insecure Tool Shed URL protocol changes and port changes. + tool_shed_url = common_util.get_tool_shed_url_from_tool_shed_registry(self.app, tool_shed_url) pathspec = ["repository", "get_required_repo_info_dict"] url = build_url(tool_shed_url, pathspec=pathspec) # Fix for handling 307 redirect not being handled nicely by urlopen() when the Request() has data provided diff --git a/lib/galaxy/tool_shed/util/repository_util.py b/lib/galaxy/tool_shed/util/repository_util.py index 2c92c061e06f..4bc101eca504 100644 --- a/lib/galaxy/tool_shed/util/repository_util.py +++ b/lib/galaxy/tool_shed/util/repository_util.py @@ -9,6 +9,7 @@ List, Optional, Tuple, + TYPE_CHECKING, Union, ) from urllib.error import HTTPError @@ -18,7 +19,6 @@ and_, false, ) -from sqlalchemy.orm import joinedload from galaxy import util from galaxy.model.base import check_database_connection @@ -31,6 +31,9 @@ ) from galaxy.util.tool_shed.tool_shed_registry import Registry +if TYPE_CHECKING: + from galaxy.tool_shed.galaxy_install.client import InstallationTarget + log = logging.getLogger(__name__) VALID_REPOSITORYNAME_RE = re.compile(r"^[a-z0-9\_]+$") @@ -103,7 +106,7 @@ def _check_or_update_tool_shed_status_for_installed_repository( def create_or_update_tool_shed_repository( - app, + app: "InstallationTarget", name, description, installed_changeset_revision, @@ -113,7 +116,7 @@ def create_or_update_tool_shed_repository( metadata_dict=None, current_changeset_revision=None, owner="", - dist_to_shed=False, + dist_to_shed: bool = False, ): """ Update a tool shed repository record in the Galaxy database with the new information received. @@ -242,7 +245,7 @@ def get_absolute_path_to_file_in_repository(repo_files_dir, file_name): def get_installed_repository( - app, + app: "InstallationTarget", tool_shed=None, name=None, owner=None, @@ -261,7 +264,7 @@ def get_installed_repository( if from_cache: tsr_cache = getattr(app, "tool_shed_repository_cache", None) if tsr_cache: - return app.tool_shed_repository_cache.get_installed_repository( + return tsr_cache.get_installed_repository( tool_shed=tool_shed, name=name, owner=owner, @@ -287,7 +290,7 @@ def get_installed_repository( return query.filter(and_(*clause_list)).first() -def get_installed_tool_shed_repository(app, id): +def get_installed_tool_shed_repository(app: "InstallationTarget", id): """Get a tool shed repository record from the Galaxy database defined by the id.""" rval = [] if isinstance(id, list): @@ -302,7 +305,7 @@ def get_installed_tool_shed_repository(app, id): return rval[0] -def get_prior_import_or_install_required_dict(app, tsr_ids, repo_info_dicts): +def get_prior_import_or_install_required_dict(app: "InstallationTarget", tsr_ids: List[str], repo_info_dicts): """ This method is used in the Tool Shed when exporting a repository and its dependencies, and in Galaxy when a repository and its dependencies are being installed. Return a @@ -311,7 +314,7 @@ def get_prior_import_or_install_required_dict(app, tsr_ids, repo_info_dicts): must be imported or installed prior to the repository associated with the tsr_id key. """ # Initialize the dictionary. - prior_import_or_install_required_dict = {} + prior_import_or_install_required_dict: Dict[str, List[str]] = {} for tsr_id in tsr_ids: prior_import_or_install_required_dict[tsr_id] = [] # Inspect the repository dependencies for each repository about to be installed and populate the dictionary. @@ -376,7 +379,7 @@ def get_repository_admin_role_name(repository_name, repository_owner): return f"{repository_name}_{repository_owner}_admin" -def get_repository_and_repository_dependencies_from_repo_info_dict(app, repo_info_dict): +def get_repository_and_repository_dependencies_from_repo_info_dict(app: "InstallationTarget", repo_info_dict): """Return a tool_shed_repository or repository record defined by the information in the received repo_info_dict.""" repository_name = list(repo_info_dict.keys())[0] repo_info_tuple = repo_info_dict[repository_name] @@ -389,53 +392,11 @@ def get_repository_and_repository_dependencies_from_repo_info_dict(app, repo_inf repository_dependencies, tool_dependencies, ) = get_repo_info_tuple_contents(repo_info_tuple) - if hasattr(app, "install_model"): - # In a tool shed client (Galaxy, or something install repositories like Galaxy) - tool_shed = get_tool_shed_from_clone_url(repository_clone_url) - repository = get_repository_for_dependency_relationship( - app, tool_shed, repository_name, repository_owner, changeset_revision - ) - else: - # We're in the tool shed. - repository = get_repository_by_name_and_owner(app, repository_name, repository_owner) - return repository, repository_dependencies - - -def get_repository_by_id(app, id): - """Get a repository from the database via id.""" - if is_tool_shed_client(app): - return app.install_model.context.query(app.install_model.ToolShedRepository).get(app.security.decode_id(id)) - else: - sa_session = app.model.session - return sa_session.query(app.model.Repository).get(app.security.decode_id(id)) - - -def get_repository_by_name_and_owner(app, name, owner, eagerload_columns=None): - """Get a repository from the database via name and owner""" - repository_query = get_repository_query(app) - if is_tool_shed_client(app): - return repository_query.filter( - and_( - app.install_model.ToolShedRepository.name == name, - app.install_model.ToolShedRepository.owner == owner, - ) - ).first() - # We're in the tool shed. - q = repository_query.filter( - and_( - app.model.Repository.name == name, - app.model.User.username == owner, - app.model.Repository.user_id == app.model.User.id, - ) + tool_shed = get_tool_shed_from_clone_url(repository_clone_url) + repository = get_repository_for_dependency_relationship( + app, tool_shed, repository_name, repository_owner, changeset_revision ) - if eagerload_columns: - q = q.options(joinedload(*eagerload_columns)) - return q.first() - - -def get_repository_by_name(app, name): - """Get a repository from the database via name.""" - return get_repository_query(app).filter_by(name=name).first() + return repository, repository_dependencies def get_repository_dependency_types(repository_dependencies): @@ -477,7 +438,7 @@ def get_repository_dependency_types(repository_dependencies): return has_repository_dependencies, has_repository_dependencies_only_if_compiling_contained_td -def get_repository_for_dependency_relationship(app, tool_shed, name, owner, changeset_revision): +def get_repository_for_dependency_relationship(app: "InstallationTarget", tool_shed, name, owner, changeset_revision): """ Return an installed tool_shed_repository database record that is defined by either the current changeset revision or the installed_changeset_revision. @@ -491,32 +452,37 @@ def get_repository_for_dependency_relationship(app, tool_shed, name, owner, chan repository = get_installed_repository( app=app, tool_shed=tool_shed, name=name, owner=owner, installed_changeset_revision=changeset_revision ) - if not repository: + if repository: + return repository + repository = get_installed_repository( + app=app, tool_shed=tool_shed, name=name, owner=owner, changeset_revision=changeset_revision + ) + if repository: + return repository + tool_shed_url = common_util.get_tool_shed_url_from_tool_shed_registry(app, tool_shed) + assert tool_shed_url + repository_clone_url = os.path.join(tool_shed_url, "repos", owner, name) + repo_info_tuple = (None, repository_clone_url, changeset_revision, None, owner, None, None) + repository, pcr = repository_was_previously_installed(app, tool_shed_url, name, repo_info_tuple) + if repository: + return repository + # The received changeset_revision is no longer installable, so get the next changeset_revision + # in the repository's changelog in the tool shed that is associated with repository_metadata. + params = dict(name=name, owner=owner, changeset_revision=changeset_revision) + pathspec = ["repository", "next_installable_changeset_revision"] + text = util.url_get( + tool_shed_url, auth=app.tool_shed_registry.url_auth(tool_shed_url), pathspec=pathspec, params=params + ) + if text: repository = get_installed_repository( - app=app, tool_shed=tool_shed, name=name, owner=owner, changeset_revision=changeset_revision + app=app, tool_shed=tool_shed, name=name, owner=owner, changeset_revision=text ) - if not repository: - tool_shed_url = common_util.get_tool_shed_url_from_tool_shed_registry(app, tool_shed) - repository_clone_url = os.path.join(tool_shed_url, "repos", owner, name) - repo_info_tuple = (None, repository_clone_url, changeset_revision, None, owner, None, None) - repository, pcr = repository_was_previously_installed(app, tool_shed_url, name, repo_info_tuple) - if not repository: - # The received changeset_revision is no longer installable, so get the next changeset_revision - # in the repository's changelog in the tool shed that is associated with repository_metadata. - tool_shed_url = common_util.get_tool_shed_url_from_tool_shed_registry(app, tool_shed) - params = dict(name=name, owner=owner, changeset_revision=changeset_revision) - pathspec = ["repository", "next_installable_changeset_revision"] - text = util.url_get( - tool_shed_url, auth=app.tool_shed_registry.url_auth(tool_shed_url), pathspec=pathspec, params=params - ) - if text: - repository = get_installed_repository( - app=app, tool_shed=tool_shed, name=name, owner=owner, changeset_revision=text - ) return repository -def get_repository_ids_requiring_prior_import_or_install(app, tsr_ids, repository_dependencies): +def get_repository_ids_requiring_prior_import_or_install( + app: "InstallationTarget", tsr_ids: List[str], repository_dependencies +): """ This method is used in the Tool Shed when exporting a repository and its dependencies, and in Galaxy when a repository and its dependencies are being installed. Inspect the @@ -527,7 +493,7 @@ def get_repository_ids_requiring_prior_import_or_install(app, tsr_ids, repositor and whose associated repositories must be imported / installed prior to the dependent repository associated with the received repository_dependencies. """ - prior_tsr_ids = [] + prior_tsr_ids: List[str] = [] if repository_dependencies: for key, rd_tups in repository_dependencies.items(): if key in ["description", "root_key"]: @@ -552,14 +518,11 @@ def get_repository_ids_requiring_prior_import_or_install(app, tsr_ids, repositor # of the dependent repository's tool dependency. if not util.asbool(only_if_compiling_contained_td): if util.asbool(prior_installation_required): - if is_tool_shed_client(app): - # We store the port, if one exists, in the database. - tool_shed = common_util.remove_protocol_from_tool_shed_url(tool_shed) - repository = get_repository_for_dependency_relationship( - app, tool_shed, name, owner, changeset_revision - ) - else: - repository = get_repository_by_name_and_owner(app, name, owner) + # We store the port, if one exists, in the database. + tool_shed = common_util.remove_protocol_from_tool_shed_url(tool_shed) + repository = get_repository_for_dependency_relationship( + app, tool_shed, name, owner, changeset_revision + ) if repository: encoded_repository_id = app.security.encode_id(repository.id) if encoded_repository_id in tsr_ids: @@ -582,20 +545,6 @@ def get_repository_owner_from_clone_url(repository_clone_url): return get_repository_owner(tmp_url) -def get_repository_query(app): - if is_tool_shed_client(app): - query = app.install_model.context.query(app.install_model.ToolShedRepository) - else: - query = app.model.context.query(app.model.Repository) - return query - - -def get_role_by_id(app, role_id): - """Get a Role from the database by id.""" - sa_session = app.model.session - return sa_session.query(app.model.Role).get(app.security.decode_id(role_id)) - - def get_tool_shed_from_clone_url(repository_clone_url): tmp_url = common_util.remove_protocol_and_user_from_clone_url(repository_clone_url) return tmp_url.split("/repos/")[0].rstrip("/") @@ -665,7 +614,9 @@ def is_tool_shed_client(app): return hasattr(app, "install_model") -def repository_was_previously_installed(app, tool_shed_url, repository_name, repo_info_tuple, from_tip=False): +def repository_was_previously_installed( + app: "InstallationTarget", tool_shed_url: str, repository_name, repo_info_tuple, from_tip: bool = False +): """ Find out if a repository is already installed into Galaxy - there are several scenarios where this is necessary. For example, this method will handle the case where the repository was previously @@ -674,7 +625,8 @@ def repository_was_previously_installed(app, tool_shed_url, repository_name, rep updating the one that was previously installed. We'll look in the database instead of on disk since the repository may be currently uninstalled. """ - tool_shed_url = common_util.get_tool_shed_url_from_tool_shed_registry(app, tool_shed_url) + base_url = common_util.get_tool_shed_url_from_tool_shed_registry(app, tool_shed_url) + assert base_url ( description, repository_clone_url, @@ -704,9 +656,7 @@ def repository_was_previously_installed(app, tool_shed_url, repository_name, rep from_tip=str(from_tip), ) pathspec = ["repository", "previous_changeset_revisions"] - text = util.url_get( - tool_shed_url, auth=app.tool_shed_registry.url_auth(tool_shed_url), pathspec=pathspec, params=params - ) + text = util.url_get(base_url, auth=app.tool_shed_registry.url_auth(base_url), pathspec=pathspec, params=params) if text: changeset_revisions = util.listify(text) for previous_changeset_revision in changeset_revisions: @@ -754,16 +704,11 @@ def set_repository_attributes(app, repository, status, error_message, deleted, u "get_repo_info_tuple_contents", "get_repository_admin_role_name", "get_repository_and_repository_dependencies_from_repo_info_dict", - "get_repository_by_id", - "get_repository_by_name", - "get_repository_by_name_and_owner", "get_repository_dependency_types", "get_repository_for_dependency_relationship", "get_repository_ids_requiring_prior_import_or_install", "get_repository_owner", "get_repository_owner_from_clone_url", - "get_repository_query", - "get_role_by_id", "get_tool_shed_from_clone_url", "get_tool_shed_repository_by_id", "get_tool_shed_status_for_installed_repository", diff --git a/lib/galaxy/tool_shed/util/shed_util_common.py b/lib/galaxy/tool_shed/util/shed_util_common.py index 72d72a6eab91..6f358183742b 100644 --- a/lib/galaxy/tool_shed/util/shed_util_common.py +++ b/lib/galaxy/tool_shed/util/shed_util_common.py @@ -1,5 +1,9 @@ import logging import re +from typing import ( + Dict, + List, +) from galaxy import util from galaxy.tool_shed.util import repository_util @@ -72,7 +76,9 @@ def get_ctx_rev(app, tool_shed_url, name, owner, changeset_revision): return ctx_rev -def get_next_prior_import_or_install_required_dict_entry(prior_required_dict, processed_tsr_ids): +def get_next_prior_import_or_install_required_dict_entry( + prior_required_dict: Dict[str, List[str]], processed_tsr_ids: List[str] +): """ This method is used in the Tool Shed when exporting a repository and its dependencies, and in Galaxy when a repository and its dependencies are being installed. The order in which the prior_required_dict diff --git a/lib/galaxy/tools/parameters/wrapped.py b/lib/galaxy/tools/parameters/wrapped.py index d23e9be5edf9..b611a61056ae 100644 --- a/lib/galaxy/tools/parameters/wrapped.py +++ b/lib/galaxy/tools/parameters/wrapped.py @@ -39,9 +39,9 @@ class LegacyUnprefixedDict(UserDict): # It used to be valid to access members of conditionals without specifying the conditional. # This dict provides a fallback when dict lookup fails using those old rules - def __init__(self, dict=None, **kwargs): + def __init__(self, initialdata=None, **kwargs): self._legacy_mapping: Dict[str, str] = {} - super().__init__(dict, **kwargs) + super().__init__(initialdata, **kwargs) def set_legacy_alias(self, new_key: str, old_key: str): self._legacy_mapping[old_key] = new_key diff --git a/lib/galaxy/tools/remote_tool_eval.py b/lib/galaxy/tools/remote_tool_eval.py index 6f2e273ffbe4..c4ac7e8e1e5a 100644 --- a/lib/galaxy/tools/remote_tool_eval.py +++ b/lib/galaxy/tools/remote_tool_eval.py @@ -80,6 +80,7 @@ def main(TMPDIR, WORKING_DIRECTORY, IMPORT_STORE_DIRECTORY) -> None: datatypes_registry = validate_and_load_datatypes_config(datatypes_config) object_store = get_object_store(WORKING_DIRECTORY) import_store = store.imported_store_for_metadata(IMPORT_STORE_DIRECTORY) + assert isinstance(import_store.sa_session, SessionlessContext) # TODO: clean up random places from which we read files in the working directory job_io = JobIO.from_json(os.path.join(IMPORT_STORE_DIRECTORY, "job_io.json"), sa_session=import_store.sa_session) tool_app_config = ToolAppConfig( diff --git a/lib/galaxy/webapps/galaxy/api/users.py b/lib/galaxy/webapps/galaxy/api/users.py index cbf9b670ad93..9e5a405c0178 100644 --- a/lib/galaxy/webapps/galaxy/api/users.py +++ b/lib/galaxy/webapps/galaxy/api/users.py @@ -692,6 +692,7 @@ def delete( payload: Optional[UserDeletionPayload] = None, ) -> DetailedUserModel: user_to_update = self.service.user_manager.by_id(user_id) + assert user_to_update is not None purge = payload and payload.purge or purge if trans.user_is_admin: if purge: diff --git a/lib/galaxy/webapps/galaxy/services/histories.py b/lib/galaxy/webapps/galaxy/services/histories.py index ff63b42af83d..c5c05c69ca5e 100644 --- a/lib/galaxy/webapps/galaxy/services/histories.py +++ b/lib/galaxy/webapps/galaxy/services/histories.py @@ -218,7 +218,7 @@ def index_query( payload: HistoryIndexQueryPayload, serialization_params: SerializationParams, include_total_count: bool = False, - ) -> Tuple[List[AnyHistoryView], int]: + ) -> Tuple[List[AnyHistoryView], Union[int, None]]: """Return a list of History accessible by the user :rtype: list diff --git a/lib/galaxy/webapps/galaxy/services/pages.py b/lib/galaxy/webapps/galaxy/services/pages.py index 197b77566e2d..9236e4403495 100644 --- a/lib/galaxy/webapps/galaxy/services/pages.py +++ b/lib/galaxy/webapps/galaxy/services/pages.py @@ -1,5 +1,8 @@ import logging -from typing import Tuple +from typing import ( + Tuple, + Union, +) from galaxy import exceptions from galaxy.celery.tasks import prepare_pdf_download @@ -60,7 +63,7 @@ def __init__( def index( self, trans, payload: PageIndexQueryPayload, include_total_count: bool = False - ) -> Tuple[PageSummaryList, int]: + ) -> Tuple[PageSummaryList, Union[int, None]]: """Return a list of Pages viewable by the user :rtype: list diff --git a/lib/galaxy/webapps/galaxy/services/sharable.py b/lib/galaxy/webapps/galaxy/services/sharable.py index de318bbfad6d..1ac68d7ae2c4 100644 --- a/lib/galaxy/webapps/galaxy/services/sharable.py +++ b/lib/galaxy/webapps/galaxy/services/sharable.py @@ -155,7 +155,7 @@ def _get_users(self, trans, emails_or_ids: List[UserIdentifier]) -> Tuple[Set[Us send_to_user = None if isinstance(email_or_id, int): send_to_user = self.manager.user_manager.by_id(email_or_id) - if send_to_user.deleted: + if send_to_user and send_to_user.deleted: send_to_user = None else: email_address = email_or_id.strip() diff --git a/lib/galaxy/webapps/galaxy/services/visualizations.py b/lib/galaxy/webapps/galaxy/services/visualizations.py index d9721d979207..e44a2e17218b 100644 --- a/lib/galaxy/webapps/galaxy/services/visualizations.py +++ b/lib/galaxy/webapps/galaxy/services/visualizations.py @@ -76,7 +76,7 @@ def index( trans: ProvidesUserContext, payload: VisualizationIndexQueryPayload, include_total_count: bool = False, - ) -> Tuple[VisualizationSummaryList, int]: + ) -> Tuple[VisualizationSummaryList, Union[int, None]]: """Return a list of Visualizations viewable by the user :rtype: list diff --git a/lib/galaxy/webapps/reports/app.py b/lib/galaxy/webapps/reports/app.py index 7e31c52725b7..6a933ae0b41b 100644 --- a/lib/galaxy/webapps/reports/app.py +++ b/lib/galaxy/webapps/reports/app.py @@ -2,7 +2,7 @@ import sys import time -import galaxy.model +import galaxy.model.mapping from galaxy.config import configure_logging from galaxy.model.base import SharedModelMapping from galaxy.security import idencoding diff --git a/lib/tool_shed/dependencies/attribute_handlers.py b/lib/tool_shed/dependencies/attribute_handlers.py index 2d72abd48e17..c9f91794a147 100644 --- a/lib/tool_shed/dependencies/attribute_handlers.py +++ b/lib/tool_shed/dependencies/attribute_handlers.py @@ -21,9 +21,9 @@ from tool_shed.util import ( hg_util, metadata_util, - repository_util, xml_util, ) +from tool_shed.webapp.model.db import get_repository_by_name_and_owner if TYPE_CHECKING: from tool_shed.context import ProvidesRepositoriesContext @@ -128,7 +128,7 @@ def handle_elem(self, elem): # Populate the changeset_revision attribute with the latest installable metadata revision for # the defined repository. We use the latest installable revision instead of the latest metadata # revision to ensure that the contents of the revision are valid. - repository = repository_util.get_repository_by_name_and_owner(self.app, name, owner) + repository = get_repository_by_name_and_owner(self.app.model.context, name, owner) if repository: lastest_installable_changeset_revision = metadata_util.get_latest_downloadable_changeset_revision( self.app, repository diff --git a/lib/tool_shed/dependencies/repository/relation_builder.py b/lib/tool_shed/dependencies/repository/relation_builder.py index 416282f78055..61b876568457 100644 --- a/lib/tool_shed/dependencies/repository/relation_builder.py +++ b/lib/tool_shed/dependencies/repository/relation_builder.py @@ -1,4 +1,12 @@ import logging +from typing import ( + Any, + Dict, + List, + Optional, + Tuple, + TYPE_CHECKING, +) import tool_shed.util.repository_util from galaxy.util import ( @@ -11,23 +19,27 @@ metadata_util, shed_util_common as suc, ) +from tool_shed.webapp.model.db import get_repository_by_name_and_owner + +if TYPE_CHECKING: + from tool_shed.structured_app import ToolShedApp log = logging.getLogger(__name__) class RelationBuilder: - def __init__(self, app, repository, repository_metadata, tool_shed_url, trans=None): - self.all_repository_dependencies = {} + def __init__(self, app: "ToolShedApp", repository, repository_metadata, tool_shed_url, trans=None): + self.all_repository_dependencies: Dict[str, Any] = {} self.app = app - self.circular_repository_dependencies = [] + self.circular_repository_dependencies: List[Tuple] = [] self.repository = repository self.repository_metadata = repository_metadata - self.handled_key_rd_dicts = [] - self.key_rd_dicts_to_be_processed = [] + self.handled_key_rd_dicts: List[Dict[str, List[str]]] = [] + self.key_rd_dicts_to_be_processed: List[Dict[str, List[str]]] = [] self.tool_shed_url = tool_shed_url self.trans = trans - def can_add_to_key_rd_dicts(self, key_rd_dict, key_rd_dicts): + def can_add_to_key_rd_dicts(self, key_rd_dict, key_rd_dicts: List[Dict[str, List[str]]]): """Handle the case where an update to the changeset revision was done.""" k = next(iter(key_rd_dict)) rd = key_rd_dict[k] @@ -40,13 +52,13 @@ def can_add_to_key_rd_dicts(self, key_rd_dict, key_rd_dicts): return False return True - def filter_only_if_compiling_contained_td(self, key_rd_dict): + def filter_only_if_compiling_contained_td(self, key_rd_dict: Dict[str, Any]): """ Return a copy of the received key_rd_dict with repository dependencies that are needed only_if_compiling_contained_td filtered out of the list of repository dependencies for each rd_key. """ - filtered_key_rd_dict = {} + filtered_key_rd_dict: Dict[str, Any] = {} for rd_key, required_rd_tup in key_rd_dict.items(): ( tool_shed, @@ -155,7 +167,7 @@ def get_repository_dependencies_for_changeset_revision(self): This method ensures that all required repositories to the nth degree are returned. """ # Assume the current repository does not have repository dependencies defined for it. - current_repository_key = None + current_repository_key: Optional[str] = None if metadata := self.repository_metadata.metadata: # The value of self.tool_shed_url must include the port, but doesn't have to include # the protocol. @@ -199,8 +211,8 @@ def get_repository_dependency_as_key(self, repository_dependency): tool_shed, name, owner, changeset_revision, prior_installation_required, only_if_compiling_contained_td ) - def get_updated_changeset_revisions_for_repository_dependencies(self, key_rd_dicts): - updated_key_rd_dicts = [] + def get_updated_changeset_revisions_for_repository_dependencies(self, key_rd_dicts: List[Dict[str, Any]]): + updated_key_rd_dicts: List[Dict[str, Any]] = [] for key_rd_dict in key_rd_dicts: key = next(iter(key_rd_dict)) repository_dependency = key_rd_dict[key] @@ -214,9 +226,7 @@ def get_updated_changeset_revisions_for_repository_dependencies(self, key_rd_dic ) = common_util.parse_repository_dependency_tuple(repository_dependency) tool_shed_is_this_tool_shed = suc.tool_shed_is_this_tool_shed(rd_toolshed, trans=self.trans) if tool_shed_is_this_tool_shed: - repository = tool_shed.util.repository_util.get_repository_by_name_and_owner( - self.app, rd_name, rd_owner - ) + repository = get_repository_by_name_and_owner(self.app.model.context, rd_name, rd_owner) if repository: repository_id = self.app.security.encode_id(repository.id) repository_metadata = metadata_util.get_repository_metadata_by_repository_id_changeset_revision( @@ -224,8 +234,7 @@ def get_updated_changeset_revisions_for_repository_dependencies(self, key_rd_dic ) if repository_metadata: # The repository changeset_revision is installable, so no updates are available. - new_key_rd_dict = {} - new_key_rd_dict[key] = repository_dependency + new_key_rd_dict = {key: repository_dependency} updated_key_rd_dicts.append(key_rd_dict) else: # The repository changeset_revision is no longer installable, so see if there's been an update. @@ -239,15 +248,16 @@ def get_updated_changeset_revisions_for_repository_dependencies(self, key_rd_dic ) ) if repository_metadata: - new_key_rd_dict = {} - new_key_rd_dict[key] = [ - rd_toolshed, - rd_name, - rd_owner, - repository_metadata.changeset_revision, - rd_prior_installation_required, - rd_only_if_compiling_contained_td, - ] + new_key_rd_dict = { + key: [ + rd_toolshed, + rd_name, + rd_owner, + repository_metadata.changeset_revision, + rd_prior_installation_required, + rd_only_if_compiling_contained_td, + ] + } # We have the updated changeset revision. updated_key_rd_dicts.append(new_key_rd_dict) else: @@ -288,7 +298,7 @@ def get_updated_changeset_revisions_for_repository_dependencies(self, key_rd_dic ) return updated_key_rd_dicts - def handle_circular_repository_dependency(self, repository_key, repository_dependency): + def handle_circular_repository_dependency(self, repository_key: str, repository_dependency): all_repository_dependencies_root_key = self.all_repository_dependencies["root_key"] repository_dependency_as_key = self.get_repository_dependency_as_key(repository_dependency) self.update_circular_repository_dependencies( @@ -297,18 +307,19 @@ def handle_circular_repository_dependency(self, repository_key, repository_depen if all_repository_dependencies_root_key != repository_dependency_as_key: self.all_repository_dependencies[repository_key] = [repository_dependency] - def handle_current_repository_dependency(self, current_repository_key): + def handle_current_repository_dependency(self, current_repository_key: str): current_repository_key_rd_dicts = [] for rd in self.all_repository_dependencies[current_repository_key]: rd_copy = [str(item) for item in rd] - new_key_rd_dict = {} - new_key_rd_dict[current_repository_key] = rd_copy + new_key_rd_dict = {current_repository_key: rd_copy} current_repository_key_rd_dicts.append(new_key_rd_dict) if current_repository_key_rd_dicts: self.handle_key_rd_dicts_for_repository(current_repository_key, current_repository_key_rd_dicts) return self.get_repository_dependencies_for_changeset_revision() - def handle_key_rd_dicts_for_repository(self, current_repository_key, repository_key_rd_dicts): + def handle_key_rd_dicts_for_repository( + self, current_repository_key, repository_key_rd_dicts: List[Dict[str, List[str]]] + ): key_rd_dict = repository_key_rd_dicts.pop(0) repository_dependency = key_rd_dict[current_repository_key] ( @@ -320,7 +331,7 @@ def handle_key_rd_dicts_for_repository(self, current_repository_key, repository_ only_if_compiling_contained_td, ) = common_util.parse_repository_dependency_tuple(repository_dependency) if suc.tool_shed_is_this_tool_shed(toolshed, trans=self.trans): - required_repository = tool_shed.util.repository_util.get_repository_by_name_and_owner(self.app, name, owner) + required_repository = get_repository_by_name_and_owner(self.app.model.context, name, owner) self.repository = required_repository repository_id = self.app.security.encode_id(required_repository.id) required_repository_metadata = metadata_util.get_repository_metadata_by_repository_id_changeset_revision( @@ -384,7 +395,7 @@ def in_circular_repository_dependencies(self, repository_key_rd_dict): return True return False - def in_key_rd_dicts(self, key_rd_dict, key_rd_dicts): + def in_key_rd_dicts(self, key_rd_dict: Dict[str, List[str]], key_rd_dicts: List[Dict[str, List[str]]]): """Return True if key_rd_dict is contained in the list of key_rd_dicts.""" k = next(iter(key_rd_dict)) v = key_rd_dict[k] @@ -394,7 +405,7 @@ def in_key_rd_dicts(self, key_rd_dict, key_rd_dicts): return True return False - def initialize_all_repository_dependencies(self, current_repository_key, repository_dependencies_dict): + def initialize_all_repository_dependencies(self, current_repository_key: str, repository_dependencies_dict): """Initialize the self.all_repository_dependencies dictionary.""" # It's safe to assume that current_repository_key in this case will have a value. self.all_repository_dependencies["root_key"] = current_repository_key @@ -418,7 +429,7 @@ def is_circular_repository_dependency(self, repository_key, repository_dependenc return False def populate_repository_dependency_objects_for_processing( - self, current_repository_key, repository_dependencies_dict + self, current_repository_key: str, repository_dependencies_dict ): """ The process that discovers all repository dependencies for a specified repository's changeset @@ -428,11 +439,10 @@ def populate_repository_dependency_objects_for_processing( more repository dependencies, so this method is repeatedly called until all repository dependencies have been discovered. """ - current_repository_key_rd_dicts = [] + current_repository_key_rd_dicts: List[Dict[str, Any]] = [] filtered_current_repository_key_rd_dicts = [] for rd_tup in repository_dependencies_dict["repository_dependencies"]: - new_key_rd_dict = {} - new_key_rd_dict[current_repository_key] = rd_tup + new_key_rd_dict = {current_repository_key: rd_tup} current_repository_key_rd_dicts.append(new_key_rd_dict) if current_repository_key_rd_dicts and current_repository_key: # Remove all repository dependencies that point to a revision within its own repository. @@ -470,8 +480,7 @@ def populate_repository_dependency_objects_for_processing( else: self.all_repository_dependencies[current_repository_key] = [repository_dependency] if not is_circular and self.can_add_to_key_rd_dicts(key_rd_dict, self.key_rd_dicts_to_be_processed): - new_key_rd_dict = {} - new_key_rd_dict[current_repository_key] = repository_dependency + new_key_rd_dict = {current_repository_key: repository_dependency} self.key_rd_dicts_to_be_processed.append(new_key_rd_dict) return filtered_current_repository_key_rd_dicts @@ -497,11 +506,11 @@ def prune_invalid_repository_dependencies(self, repository_dependencies): valid_repository_dependencies["root_key"] = root_key return valid_repository_dependencies - def remove_from_key_rd_dicts(self, key_rd_dict, key_rd_dicts): + def remove_from_key_rd_dicts(self, key_rd_dict: Dict[str, List[str]], key_rd_dicts: List[Dict[str, List[str]]]): """Eliminate the key_rd_dict from the list of key_rd_dicts if it is contained in the list.""" k = next(iter(key_rd_dict)) v = key_rd_dict[k] - clean_key_rd_dicts = [] + clean_key_rd_dicts: List[Dict[str, List[str]]] = [] for krd_dict in key_rd_dicts: key = next(iter(krd_dict)) val = krd_dict[key] @@ -510,9 +519,9 @@ def remove_from_key_rd_dicts(self, key_rd_dict, key_rd_dicts): clean_key_rd_dicts.append(krd_dict) return clean_key_rd_dicts - def remove_repository_dependency_reference_to_self(self, key_rd_dicts): + def remove_repository_dependency_reference_to_self(self, key_rd_dicts: List[Dict[str, Any]]): """Remove all repository dependencies that point to a revision within its own repository.""" - clean_key_rd_dicts = [] + clean_key_rd_dicts: List[Dict[str, Any]] = [] key = next(iter(key_rd_dicts[0])) repository_tup = key.split(container_util.STRSEP) ( @@ -541,12 +550,13 @@ def remove_repository_dependency_reference_to_self(self, key_rd_dicts): debug_msg += "since it refers to a revision within itself." log.debug(debug_msg) else: - new_key_rd_dict = {} - new_key_rd_dict[key] = repository_dependency + new_key_rd_dict = {key: repository_dependency} clean_key_rd_dicts.append(new_key_rd_dict) return clean_key_rd_dicts - def update_circular_repository_dependencies(self, repository_key, repository_dependency, repository_dependencies): + def update_circular_repository_dependencies( + self, repository_key: str, repository_dependency, repository_dependencies + ): repository_key_as_repository_dependency = repository_key.split(container_util.STRSEP) if repository_key_as_repository_dependency in repository_dependencies: found = False @@ -555,5 +565,5 @@ def update_circular_repository_dependencies(self, repository_key, repository_dep # The circular dependency has already been included. found = True if not found: - new_circular_tup = [repository_dependency, repository_key_as_repository_dependency] + new_circular_tup = (repository_dependency, repository_key_as_repository_dependency) self.circular_repository_dependencies.append(new_circular_tup) diff --git a/lib/tool_shed/grids/repository_grids.py b/lib/tool_shed/grids/repository_grids.py index 34631ded8816..350278827c6c 100644 --- a/lib/tool_shed/grids/repository_grids.py +++ b/lib/tool_shed/grids/repository_grids.py @@ -16,9 +16,9 @@ from tool_shed.util import ( hg_util, metadata_util, - repository_util, ) from tool_shed.webapp import model +from tool_shed.webapp.model.db import get_repository_by_name_and_owner log = logging.getLogger(__name__) @@ -1079,9 +1079,7 @@ def get_value(self, trans, grid, repository_metadata): for rd_tup in sorted_rd_tups: name, owner, changeset_revision = rd_tup[1:4] rd_line = "" - required_repository = repository_util.get_repository_by_name_and_owner( - trans.app, name, owner - ) + required_repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) if required_repository and not required_repository.deleted: required_repository_id = trans.security.encode_id(required_repository.id) required_repository_metadata = ( diff --git a/lib/tool_shed/managers/repositories.py b/lib/tool_shed/managers/repositories.py index 5e21ba862674..031d1ebe849d 100644 --- a/lib/tool_shed/managers/repositories.py +++ b/lib/tool_shed/managers/repositories.py @@ -56,7 +56,6 @@ create_repository as low_level_create_repository, get_repo_info_dict, get_repositories_by_category, - get_repository_by_name_and_owner, get_repository_in_tool_shed, validate_repository_name, ) @@ -69,6 +68,7 @@ Repository, RepositoryMetadata, ) +from tool_shed.webapp.model.db import get_repository_by_name_and_owner from tool_shed.webapp.search.repo_search import RepoSearch from tool_shed_client.schema import ( CreateRepositoryRequest, @@ -147,7 +147,7 @@ def check_updates(app: ToolShedApp, request: UpdatesRequest) -> Union[str, Dict[ changeset_revision = request.changeset_revision hexlify_this = request.hexlify repository = get_repository_by_name_and_owner( - app, name, owner, eagerload_columns=[Repository.downloadable_revisions] + app.model.context, name, owner, eagerload_columns=[Repository.downloadable_revisions] ) if repository and repository.downloadable_revisions: repository_metadata = get_repository_metadata_by_changeset_revision( @@ -289,7 +289,7 @@ def get_install_info(trans: ProvidesRepositoriesContext, name, owner, changeset_ if name and owner and changeset_revision: # Get the repository information. repository = get_repository_by_name_and_owner( - app, name, owner, eagerload_columns=[Repository.downloadable_revisions] + app.model.context, name, owner, eagerload_columns=[Repository.downloadable_revisions] ) if repository is None: log.debug(f"Cannot locate repository {name} owned by {owner}") @@ -360,7 +360,9 @@ def get_ordered_installable_revisions( eagerload_columns = [Repository.downloadable_revisions] if None not in [name, owner]: # Get the repository information. - repository = get_repository_by_name_and_owner(app, name, owner, eagerload_columns=eagerload_columns) + repository = get_repository_by_name_and_owner( + app.model.context, name, owner, eagerload_columns=eagerload_columns + ) if repository is None: raise ObjectNotFound(f"No repository named {name} found with owner {owner}") elif tsr_id is not None: diff --git a/lib/tool_shed/metadata/repository_metadata_manager.py b/lib/tool_shed/metadata/repository_metadata_manager.py index 26716442d944..b750e2250ceb 100644 --- a/lib/tool_shed/metadata/repository_metadata_manager.py +++ b/lib/tool_shed/metadata/repository_metadata_manager.py @@ -39,6 +39,7 @@ RepositoryMetadata, User, ) +from tool_shed.webapp.model.db import get_repository_by_name_and_owner log = logging.getLogger(__name__) @@ -551,7 +552,7 @@ def different_revision_defines_tip_only_repository_dependency(self, rd_tup, repo cleaned_tool_shed = common_util.remove_protocol_from_tool_shed_url(tool_shed) if cleaned_rd_tool_shed == cleaned_tool_shed and rd_name == name and rd_owner == owner: # Determine if the repository represented by the dependency tuple is an instance of the repository type TipOnly. - required_repository = repository_util.get_repository_by_name_and_owner(self.app, name, owner) + required_repository = get_repository_by_name_and_owner(self.app.model.context, name, owner) repository_type_class = self.app.repository_types_registry.get_class_by_label(required_repository.type) return isinstance(repository_type_class, TipOnly) return False diff --git a/lib/tool_shed/tools/tool_validator.py b/lib/tool_shed/tools/tool_validator.py index 2a99c7eb2bd5..a4d441977d26 100644 --- a/lib/tool_shed/tools/tool_validator.py +++ b/lib/tool_shed/tools/tool_validator.py @@ -2,6 +2,7 @@ import logging import os import tempfile +from typing import TYPE_CHECKING from galaxy.tool_shed.tools.tool_validator import ToolValidator as GalaxyToolValidator from galaxy.tools import Tool @@ -14,10 +15,18 @@ tool_util, ) +if TYPE_CHECKING: + from tool_shed.structured_app import ToolShedApp + log = logging.getLogger(__name__) class ToolValidator(GalaxyToolValidator): + app: "ToolShedApp" + + def __init__(self, app: "ToolShedApp"): + super().__init__(app) + def can_use_tool_config_disk_file(self, repository, repo, file_path, changeset_revision): """ Determine if repository's tool config file on disk can be used. This method diff --git a/lib/tool_shed/tools/tool_version_manager.py b/lib/tool_shed/tools/tool_version_manager.py index adb15537dc83..627cd07c324c 100644 --- a/lib/tool_shed/tools/tool_version_manager.py +++ b/lib/tool_shed/tools/tool_version_manager.py @@ -1,4 +1,5 @@ import logging +from typing import TYPE_CHECKING from tool_shed.util import ( hg_util, @@ -6,11 +7,14 @@ repository_util, ) +if TYPE_CHECKING: + from tool_shed.structured_app import ToolShedApp + log = logging.getLogger(__name__) class ToolVersionManager: - def __init__(self, app): + def __init__(self, app: "ToolShedApp"): self.app = app def get_version_lineage_for_tool(self, repository_id, repository_metadata, guid): diff --git a/lib/tool_shed/util/metadata_util.py b/lib/tool_shed/util/metadata_util.py index b7d18f2eaec7..3d12b7a80f8a 100644 --- a/lib/tool_shed/util/metadata_util.py +++ b/lib/tool_shed/util/metadata_util.py @@ -11,9 +11,9 @@ INITIAL_CHANGELOG_HASH, reversed_lower_upper_bounded_changelog, ) -from galaxy.tool_shed.util.repository_util import get_repository_by_name_and_owner from galaxy.util.tool_shed.common_util import parse_repository_dependency_tuple from tool_shed.util.hg_util import changeset2rev +from tool_shed.webapp.model.db import get_repository_by_name_and_owner if TYPE_CHECKING: from tool_shed.structured_app import ToolShedApp @@ -24,7 +24,7 @@ log = logging.getLogger(__name__) -def get_all_dependencies(app, metadata_entry, processed_dependency_links=None): +def get_all_dependencies(app: "ToolShedApp", metadata_entry, processed_dependency_links=None): processed_dependency_links = processed_dependency_links or [] encoder = app.security.encode_id value_mapper = {"repository_id": encoder, "id": encoder, "user_id": encoder} @@ -42,6 +42,7 @@ def get_all_dependencies(app, metadata_entry, processed_dependency_links=None): repository = app.model.session.get( app.model.Repository, app.security.decode_id(dependency_dict["repository_id"]) ) + assert repository dependency_dict["repository"] = repository.to_dict(value_mapper=value_mapper) if dependency_metadata.includes_tools: dependency_dict["tools"] = dependency_metadata.metadata["tools"] @@ -80,10 +81,10 @@ def get_current_repository_metadata_for_changeset_revision(app, repository, chan return None -def get_dependencies_for_metadata_revision(app, metadata): +def get_dependencies_for_metadata_revision(app: "ToolShedApp", metadata): dependencies = [] for _shed, name, owner, changeset, _prior, _ in metadata["repository_dependencies"]: - required_repository = get_repository_by_name_and_owner(app, name, owner) + required_repository = get_repository_by_name_and_owner(app.model.context, name, owner) updated_changeset = get_next_downloadable_changeset_revision(app, required_repository, changeset) if updated_changeset is None: continue @@ -208,7 +209,9 @@ def get_previous_metadata_changeset_revision(app, repository, before_changeset_r previous_changeset_revision = changeset_revision -def get_repository_dependency_tups_from_repository_metadata(app, repository_metadata, deprecated_only=False): +def get_repository_dependency_tups_from_repository_metadata( + app: "ToolShedApp", repository_metadata, deprecated_only=False +): """ Return a list of of tuples defining repository objects required by the received repository. The returned list defines the entire repository dependency tree. This method is called only from the Tool Shed. @@ -227,7 +230,7 @@ def get_repository_dependency_tups_from_repository_metadata(app, repository_meta toolshed, name, owner, changeset_revision, pir, oicct = parse_repository_dependency_tuple( repository_dependency_tup ) - repository = get_repository_by_name_and_owner(app, name, owner) + repository = get_repository_by_name_and_owner(app.model.context, name, owner) if repository: if deprecated_only: if repository.deprecated: @@ -290,12 +293,12 @@ def get_repository_metadata_by_repository_id_changeset_revision(app, id, changes return get_repository_metadata_by_changeset_revision(app, id, changeset_revision) -def get_updated_changeset_revisions(app, name, owner, changeset_revision): +def get_updated_changeset_revisions(app: "ToolShedApp", name, owner, changeset_revision): """ Return a string of comma-separated changeset revision hashes for all available updates to the received changeset revision for the repository defined by the received name and owner. """ - repository = get_repository_by_name_and_owner(app, name, owner) + repository = get_repository_by_name_and_owner(app.model.context, name, owner) # Get the upper bound changeset revision. upper_bound_changeset_revision = get_next_downloadable_changeset_revision(app, repository, changeset_revision) # Build the list of changeset revision hashes defining each available update up to, but excluding diff --git a/lib/tool_shed/util/repository_util.py b/lib/tool_shed/util/repository_util.py index 1ae90d7e488c..39766f71e54f 100644 --- a/lib/tool_shed/util/repository_util.py +++ b/lib/tool_shed/util/repository_util.py @@ -34,16 +34,11 @@ get_repo_info_tuple_contents, get_repository_admin_role_name, get_repository_and_repository_dependencies_from_repo_info_dict, - get_repository_by_id, - get_repository_by_name, - get_repository_by_name_and_owner, get_repository_dependency_types, get_repository_for_dependency_relationship, get_repository_ids_requiring_prior_import_or_install, get_repository_owner, get_repository_owner_from_clone_url, - get_repository_query, - get_role_by_id, get_tool_shed_from_clone_url, get_tool_shed_repository_by_id, get_tool_shed_status_for_installed_repository, @@ -63,14 +58,20 @@ get_repository_metadata_by_changeset_revision, repository_metadata_by_changeset_revision, ) +from tool_shed.webapp import model +from tool_shed.webapp.model.db import ( + get_repository_by_name_and_owner, + get_repository_query, +) if TYPE_CHECKING: + from sqlalchemy.orm import scoped_session + from tool_shed.context import ( ProvidesRepositoriesContext, ProvidesUserContext, ) from tool_shed.structured_app import ToolShedApp - from tool_shed.webapp.model import Repository log = logging.getLogger(__name__) @@ -78,6 +79,18 @@ VALID_REPOSITORYNAME_RE = re.compile(r"^[a-z0-9\_]+$") +def get_repository_by_id(app: "ToolShedApp", id): + """Get a repository from the database via id.""" + sa_session = app.model.session + return sa_session.query(model.Repository).get(app.security.decode_id(id)) + + +def get_role_by_id(app: "ToolShedApp", role_id): + """Get a Role from the database by id.""" + sa_session = app.model.session + return sa_session.query(model.Role).get(app.security.decode_id(role_id)) + + def create_repo_info_dict( app: "ToolShedApp", repository_clone_url, @@ -114,7 +127,7 @@ def create_repo_info_dict( repository_dependencies will be None. """ repo_info_dict = {} - repository = get_repository_by_name_and_owner(app, repository_name, repository_owner) + repository = get_repository_by_name_and_owner(app.model.context, repository_name, repository_owner) if app.name == "tool_shed": # We're in the tool shed. repository_metadata = repository_metadata_by_changeset_revision(app.model, repository.id, changeset_revision) @@ -160,7 +173,7 @@ def create_repo_info_dict( return repo_info_dict -def create_repository_admin_role(app: "ToolShedApp", repository: "Repository"): +def create_repository_admin_role(app: "ToolShedApp", repository: model.Repository): """ Create a new role with name-spaced name based on the repository name and its owner's public user name. This will ensure that the role name is unique. @@ -168,12 +181,12 @@ def create_repository_admin_role(app: "ToolShedApp", repository: "Repository"): sa_session = app.model.session name = get_repository_admin_role_name(str(repository.name), str(repository.user.username)) description = "A user or group member with this role can administer this repository." - role = app.model.Role(name=name, description=description, type=app.model.Role.types.SYSTEM) + role = model.Role(name=name, description=description, type=model.Role.types.SYSTEM) sa_session.add(role) # Associate the role with the repository owner. - app.model.UserRoleAssociation(repository.user, role) + model.UserRoleAssociation(repository.user, role) # Associate the role with the repository. - rra = app.model.RepositoryRoleAssociation(repository, role) + rra = model.RepositoryRoleAssociation(repository, role) sa_session.add(rra) return role @@ -188,12 +201,12 @@ def create_repository( category_ids: Optional[List[str]] = None, remote_repository_url=None, homepage_url=None, -) -> Tuple["Repository", str]: +) -> Tuple[model.Repository, str]: """Create a new ToolShed repository""" category_ids = category_ids or [] sa_session = app.model.session # Add the repository record to the database. - repository = app.model.Repository( + repository = model.Repository( name=name, type=type, remote_repository_url=remote_repository_url, @@ -206,8 +219,8 @@ def create_repository( if category_ids: # Create category associations for category_id in category_ids: - category = sa_session.get(app.model.Category, app.security.decode_id(category_id)) - rca = app.model.RepositoryCategoryAssociation(repository, category) + category = sa_session.get(model.Category, app.security.decode_id(category_id)) + rca = model.RepositoryCategoryAssociation(repository, category) sa_session.add(rca) # Create an admin role for the repository. create_repository_admin_role(app, repository) @@ -238,7 +251,7 @@ def create_repository( def generate_sharable_link_for_repository_in_tool_shed( - repository: "Repository", changeset_revision: Optional[str] = None + repository: model.Repository, changeset_revision: Optional[str] = None ) -> str: """Generate the URL for sharing a repository that is in the tool shed.""" base_url = web.url_for("/", qualified=True).rstrip("/") @@ -248,9 +261,9 @@ def generate_sharable_link_for_repository_in_tool_shed( return sharable_url -def get_repository_in_tool_shed(app, id, eagerload_columns=None): +def get_repository_in_tool_shed(app: "ToolShedApp", id, eagerload_columns=None): """Get a repository on the tool shed side from the database via id.""" - q = get_repository_query(app) + q = get_repository_query(app.model.context) if eagerload_columns: q = q.options(joinedload(*eagerload_columns)) return q.get(app.security.decode_id(id)) @@ -325,15 +338,17 @@ def get_repo_info_dict(trans: "ProvidesRepositoriesContext", repository_id, chan def get_repositories_by_category( - app: "ToolShedApp", category_id, installable=False, sort_order="asc", sort_key="name", page=None, per_page=25 + app: "ToolShedApp", + category_id, + installable: bool = False, + sort_order="asc", + sort_key="name", + page: Optional[int] = None, + per_page: int = 25, ): repositories = [] for repository in get_repositories( app.model.session, - app.model.Repository, - app.model.RepositoryCategoryAssociation, - app.model.User, - app.model.RepositoryMetadata, category_id, installable, sort_order, @@ -369,35 +384,35 @@ def handle_role_associations(app: "ToolShedApp", role, repository, **kwd): repository_owner = repository.user if kwd.get("manage_role_associations_button", False): in_users_list = util.listify(kwd.get("in_users", [])) - in_users = [sa_session.get(app.model.User, x) for x in in_users_list] + users = [y for y in (sa_session.get(model.User, x) for x in in_users_list) if y is not None] # Make sure the repository owner is always associated with the repostory's admin role. owner_associated = False - for user in in_users: + for user in users: if user.id == repository_owner.id: owner_associated = True break if not owner_associated: - in_users.append(repository_owner) + users.append(repository_owner) message += "The repository owner must always be associated with the repository's administrator role. " status = "error" in_groups_list = util.listify(kwd.get("in_groups", [])) - in_groups = [sa_session.get(app.model.Group, x) for x in in_groups_list] + groups = [sa_session.get(model.Group, x) for x in in_groups_list] in_repositories = [repository] app.security_agent.set_entity_role_associations( - roles=[role], users=in_users, groups=in_groups, repositories=in_repositories + roles=[role], users=users, groups=groups, repositories=in_repositories ) sa_session.refresh(role) - message += f"Role {escape(str(role.name))} has been associated with {len(in_users)} users, {len(in_groups)} groups and {len(in_repositories)} repositories. " + message += f"Role {escape(str(role.name))} has been associated with {len(users)} users, {len(groups)} groups and {len(in_repositories)} repositories. " in_users = [] out_users = [] in_groups = [] out_groups = [] - for user in get_current_users(sa_session, app.model.User): + for user in get_current_users(sa_session): if user in [x.user for x in role.users]: in_users.append((user.id, user.email)) else: out_users.append((user.id, user.email)) - for group in get_current_groups(sa_session, app.model.Group): + for group in get_current_groups(sa_session): if group in [x.group for x in role.groups]: in_groups.append((group.id, group.name)) else: @@ -421,11 +436,13 @@ def change_repository_name_in_hgrc_file(hgrc_file: str, new_name: str) -> None: config.write(fh) -def update_repository(trans: "ProvidesUserContext", id: str, **kwds) -> Tuple[Optional["Repository"], Optional[str]]: +def update_repository( + trans: "ProvidesUserContext", id: str, **kwds +) -> Tuple[Optional[model.Repository], Optional[str]]: """Update an existing ToolShed repository""" app = trans.app sa_session = app.model.session - repository = sa_session.get(app.model.Repository, app.security.decode_id(id)) + repository = sa_session.get(model.Repository, app.security.decode_id(id)) if repository is None: return None, "Unknown repository ID" @@ -437,8 +454,8 @@ def update_repository(trans: "ProvidesUserContext", id: str, **kwds) -> Tuple[Op def update_validated_repository( - trans: "ProvidesUserContext", repository: "Repository", **kwds -) -> Tuple[Optional["Repository"], Optional[str]]: + trans: "ProvidesUserContext", repository: model.Repository, **kwds +) -> Tuple[Optional[model.Repository], Optional[str]]: """Update an existing ToolShed repository metadata once permissions have been checked.""" app = trans.app sa_session = app.model.session @@ -455,13 +472,13 @@ def update_validated_repository( if "category_ids" in kwds and isinstance(kwds["category_ids"], list): # Remove existing category associations - delete_repository_category_associations(sa_session, app.model.RepositoryCategoryAssociation, repository.id) + delete_repository_category_associations(sa_session, model.RepositoryCategoryAssociation, repository.id) # Then (re)create category associations for category_id in kwds["category_ids"]: - category = sa_session.get(app.model.Category, app.security.decode_id(category_id)) + category = sa_session.get(model.Category, app.security.decode_id(category_id)) if category: - rca = app.model.RepositoryCategoryAssociation(repository, category) + rca = model.RepositoryCategoryAssociation(repository, category) sa_session.add(rca) else: pass @@ -512,7 +529,7 @@ def validate_repository_name(app: "ToolShedApp", name, user): return "Enter the required repository name." if name in ["repos"]: return f"The term '{name}' is a reserved word in the Tool Shed, so it cannot be used as a repository name." - check_existing = get_repository_by_name_and_owner(app, name, user.username) + check_existing = get_repository_by_name_and_owner(app.model.context, name, user.username) if check_existing is not None: if check_existing.deleted: return f"You own a deleted repository named {escape(name)}, please choose a different name." @@ -528,42 +545,32 @@ def validate_repository_name(app: "ToolShedApp", name, user): def get_repositories( - session, - repository_model, - repository_category_assoc_model, - user_model, - repository_metadata_model, + session: "scoped_session", category_id, - installable, + installable: bool, sort_order, sort_key, - page, - per_page, + page: Optional[int], + per_page: int, ): - Repository = repository_model - RepositoryCategoryAssociation = repository_category_assoc_model - User = user_model - RepositoryMetadata = repository_metadata_model - stmt = ( - select(Repository) + select(model.Repository) .join( - RepositoryCategoryAssociation, - Repository.id == RepositoryCategoryAssociation.repository_id, + model.RepositoryCategoryAssociation, + model.Repository.id == model.RepositoryCategoryAssociation.repository_id, ) - .join(User, User.id == Repository.user_id) - .where(RepositoryCategoryAssociation.category_id == category_id) + .join(model.User, model.User.id == model.Repository.user_id) + .where(model.RepositoryCategoryAssociation.category_id == category_id) ) if installable: - stmt1 = select(RepositoryMetadata.repository_id) - stmt = stmt.where(Repository.id.in_(stmt1)) + stmt1 = select(model.RepositoryMetadata.repository_id) + stmt = stmt.where(model.Repository.id.in_(stmt1)) if sort_key == "owner": - sort_by = User.username + sort_col = model.User.username else: - sort_by = Repository.name - if sort_order == "desc": - sort_by = sort_by.desc() + sort_col = model.Repository.name + sort_by = sort_col.desc() if sort_order == "desc" else sort_col stmt = stmt.order_by(sort_by) if page is not None: @@ -575,13 +582,13 @@ def get_repositories( return session.scalars(stmt) -def get_current_users(session, user_model): - stmt = select(user_model).where(user_model.deleted == false()).order_by(user_model.email) +def get_current_users(session: "scoped_session"): + stmt = select(model.User).where(model.User.deleted == false()).order_by(model.User.email) return session.scalars(stmt) -def get_current_groups(session, group_model): - stmt = select(group_model).where(group_model.deleted == false()).order_by(group_model.name) +def get_current_groups(session: "scoped_session"): + stmt = select(model.Group).where(model.Group.deleted == false()).order_by(model.Group.name) return session.scalars(stmt) @@ -609,15 +616,12 @@ def delete_repository_category_associations(session, repository_category_assoc_m "get_repository_admin_role_name", "get_repository_and_repository_dependencies_from_repo_info_dict", "get_repository_by_id", - "get_repository_by_name", - "get_repository_by_name_and_owner", "get_repository_dependency_types", "get_repository_for_dependency_relationship", "get_repository_ids_requiring_prior_import_or_install", "get_repository_in_tool_shed", "get_repository_owner", "get_repository_owner_from_clone_url", - "get_repository_query", "get_role_by_id", "get_tool_shed_from_clone_url", "get_tool_shed_repository_by_id", diff --git a/lib/tool_shed/util/shed_util_common.py b/lib/tool_shed/util/shed_util_common.py index b3fbbb477532..68cf6bd4f881 100644 --- a/lib/tool_shed/util/shed_util_common.py +++ b/lib/tool_shed/util/shed_util_common.py @@ -89,7 +89,9 @@ def count_repositories_in_category(app: "ToolShedApp", category_id: str) -> int: .select_from(app.model.RepositoryCategoryAssociation) .where(app.model.RepositoryCategoryAssociation.category_id == app.security.decode_id(category_id)) ) - return app.model.session.scalar(stmt) + count = app.model.session.scalar(stmt) + assert count is not None + return count def get_categories(app: "ToolShedApp"): diff --git a/lib/tool_shed/webapp/api/repositories.py b/lib/tool_shed/webapp/api/repositories.py index 08a89df60088..057c5cab269a 100644 --- a/lib/tool_shed/webapp/api/repositories.py +++ b/lib/tool_shed/webapp/api/repositories.py @@ -52,6 +52,7 @@ tool_util, ) from tool_shed.webapp import model +from tool_shed.webapp.model.db import get_repository_by_name_and_owner from tool_shed_client.schema import ( CreateRepositoryRequest, LegacyInstallInfoTuple, @@ -92,7 +93,7 @@ def add_repository_registry_entry(self, trans, payload, **kwd): owner = payload.get("owner", "") if not owner: raise HTTPBadRequest(detail="Missing required parameter 'owner'.") - repository = repository_util.get_repository_by_name_and_owner(self.app, name, owner) + repository = get_repository_by_name_and_owner(self.app.model.context, name, owner) if repository is None: error_message = f"Cannot locate repository with name {name} and owner {owner}," log.debug(error_message) @@ -331,7 +332,7 @@ def remove_repository_registry_entry(self, trans, payload, **kwd): owner = payload.get("owner", "") if not owner: raise HTTPBadRequest(detail="Missing required parameter 'owner'.") - repository = repository_util.get_repository_by_name_and_owner(self.app, name, owner) + repository = get_repository_by_name_and_owner(self.app.model.context, name, owner) if repository is None: error_message = f"Cannot locate repository with name {name} and owner {owner}," log.debug(error_message) diff --git a/lib/tool_shed/webapp/api/repository_revisions.py b/lib/tool_shed/webapp/api/repository_revisions.py index 1e051aa43232..3b7404c0e472 100644 --- a/lib/tool_shed/webapp/api/repository_revisions.py +++ b/lib/tool_shed/webapp/api/repository_revisions.py @@ -11,11 +11,9 @@ web, ) from galaxy.webapps.base.controller import HTTPBadRequest -from tool_shed.util import ( - metadata_util, - repository_util, -) +from tool_shed.util import metadata_util from tool_shed.webapp.model import RepositoryMetadata +from tool_shed.webapp.model.db import get_repository_by_name_and_owner from . import BaseShedAPIController log = logging.getLogger(__name__) @@ -84,7 +82,7 @@ def repository_dependencies(self, trans, id, **kwd): rd_tups = metadata["repository_dependencies"]["repository_dependencies"] for rd_tup in rd_tups: tool_shed, name, owner, changeset_revision = rd_tup[0:4] - repository_dependency = repository_util.get_repository_by_name_and_owner(trans.app, name, owner) + repository_dependency = get_repository_by_name_and_owner(trans.sa_session, name, owner) if repository_dependency is None: log.debug(f"Cannot locate repository dependency {name} owned by {owner}.") continue diff --git a/lib/tool_shed/webapp/app.py b/lib/tool_shed/webapp/app.py index 1a10551ed0f7..1c2d99b6dd3b 100644 --- a/lib/tool_shed/webapp/app.py +++ b/lib/tool_shed/webapp/app.py @@ -85,7 +85,7 @@ def __init__(self, **kwd) -> None: self.user_manager = self._register_singleton(UserManager, UserManager(self, app_type="tool_shed")) self.api_keys_manager = self._register_singleton(ApiKeyManager) # initialize the Tool Shed tag handler. - self.tag_handler = CommunityTagHandler(self) + self.tag_handler = CommunityTagHandler(self.model.context) # Initialize the Tool Shed tool data tables. Never pass a configuration file here # because the Tool Shed should always have an empty dictionary! self.tool_data_tables = galaxy.tools.data.ToolDataTableManager(self.config.tool_data_path) diff --git a/lib/tool_shed/webapp/controllers/hg.py b/lib/tool_shed/webapp/controllers/hg.py index 7bb8b7bee214..a2b4cf8a3235 100644 --- a/lib/tool_shed/webapp/controllers/hg.py +++ b/lib/tool_shed/webapp/controllers/hg.py @@ -5,7 +5,7 @@ from galaxy import web from galaxy.exceptions import ObjectNotFound from galaxy.webapps.base.controller import BaseUIController -from tool_shed.util.repository_util import get_repository_by_name_and_owner +from tool_shed.webapp.model.db import get_repository_by_name_and_owner log = logging.getLogger(__name__) @@ -35,7 +35,7 @@ def make_web_app(): path_info = kwd.get("path_info", None) if path_info and len(path_info.split("/")) == 2: owner, name = path_info.split("/") - repository = get_repository_by_name_and_owner(trans.app, name, owner) + repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) if repository: if repository.deprecated: raise ObjectNotFound("Requested repository not found or deprecated.") diff --git a/lib/tool_shed/webapp/controllers/repository.py b/lib/tool_shed/webapp/controllers/repository.py index 5dffafadc011..e329526cf7d5 100644 --- a/lib/tool_shed/webapp/controllers/repository.py +++ b/lib/tool_shed/webapp/controllers/repository.py @@ -61,6 +61,10 @@ RepositoryCategoryAssociation, RepositoryMetadata, ) +from tool_shed.webapp.model.db import ( + get_repository_by_name, + get_repository_by_name_and_owner, +) from tool_shed.webapp.util import ratings_util log = logging.getLogger(__name__) @@ -136,7 +140,7 @@ def browse_categories(self, trans, **kwd): # We'll try to get the desired encoded repository id to pass on. try: repository_name = kwd["id"] - repository = repository_util.get_repository_by_name(trans.app, repository_name) + repository = get_repository_by_name(trans.sa_session, repository_name) kwd["id"] = trans.security.encode_id(repository.id) except Exception: pass @@ -561,7 +565,7 @@ def browse_valid_categories(self, trans, **kwd): # We'll try to get the desired encoded repository id to pass on. try: name = kwd["id"] - repository = repository_util.get_repository_by_name(trans.app, name) + repository = get_repository_by_name(trans.sa_session, name) kwd["id"] = trans.security.encode_id(repository.id) except Exception: pass @@ -1034,7 +1038,7 @@ def has_galaxy_utilities(repository_metadata): name = kwd.get("name", None) owner = kwd.get("owner", None) changeset_revision = kwd.get("changeset_revision", None) - repository = repository_util.get_repository_by_name_and_owner(trans.app, name, owner) + repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) repository_metadata = metadata_util.get_repository_metadata_by_changeset_revision( trans.app, trans.security.encode_id(repository.id), changeset_revision ) @@ -1132,7 +1136,7 @@ def get_ctx_rev(self, trans, **kwd): repository_name = kwd["name"] repository_owner = kwd["owner"] changeset_revision = kwd["changeset_revision"] - repository = repository_util.get_repository_by_name_and_owner(trans.app, repository_name, repository_owner) + repository = get_repository_by_name_and_owner(trans.sa_session, repository_name, repository_owner) repo = repository.hg_repo if ctx := hg_util.get_changectx_for_changeset(repo, changeset_revision): return str(ctx.rev()) @@ -1154,7 +1158,7 @@ def get_latest_downloadable_changeset_revision(self, trans, **kwd): repository_name = kwd.get("name", None) repository_owner = kwd.get("owner", None) if repository_name is not None and repository_owner is not None: - repository = repository_util.get_repository_by_name_and_owner(trans.app, repository_name, repository_owner) + repository = get_repository_by_name_and_owner(trans.sa_session, repository_name, repository_owner) if repository: return metadata_util.get_latest_downloadable_changeset_revision(trans.app, repository) return hg_util.INITIAL_CHANGELOG_HASH @@ -1169,7 +1173,7 @@ def get_readme_files(self, trans, **kwd): repository_owner = kwd.get("owner", None) changeset_revision = kwd.get("changeset_revision", None) if repository_name is not None and repository_owner is not None and changeset_revision is not None: - repository = repository_util.get_repository_by_name_and_owner(trans.app, repository_name, repository_owner) + repository = get_repository_by_name_and_owner(trans.sa_session, repository_name, repository_owner) return readmes(trans.app, repository, changeset_revision) return {} @@ -1182,7 +1186,7 @@ def get_repository_dependencies(self, trans, **kwd): name = kwd.get("name", None) owner = kwd.get("owner", None) changeset_revision = kwd.get("changeset_revision", None) - repository = repository_util.get_repository_by_name_and_owner(trans.app, name, owner) + repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) # get_repository_dependencies( self, app, changeset, toolshed_url ) dependencies = repository.get_repository_dependencies( trans.app, changeset_revision, web.url_for("/", qualified=True) @@ -1196,7 +1200,7 @@ def get_repository_id(self, trans, **kwd): """Given a repository name and owner, return the encoded repository id.""" repository_name = kwd["name"] repository_owner = kwd["owner"] - repository = repository_util.get_repository_by_name_and_owner(trans.app, repository_name, repository_owner) + repository = get_repository_by_name_and_owner(trans.sa_session, repository_name, repository_owner) if repository: return trans.security.encode_id(repository.id) return "" @@ -1251,7 +1255,7 @@ def get_repository_type(self, trans, **kwd): """Given a repository name and owner, return the type.""" repository_name = kwd["name"] repository_owner = kwd["owner"] - repository = repository_util.get_repository_by_name_and_owner(trans.app, repository_name, repository_owner) + repository = get_repository_by_name_and_owner(trans.sa_session, repository_name, repository_owner) return str(repository.type) @web.json @@ -1280,7 +1284,7 @@ def get_required_repo_info_dict(self, trans, encoded_str=None): prior_installation_required, only_if_compiling_contained_td, ) = common_util.parse_repository_dependency_tuple(required_repository_tup) - repository = repository_util.get_repository_by_name_and_owner(trans.app, name, owner) + repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) encoded_repository_ids.append(trans.security.encode_id(repository.id)) changeset_revisions.append(changeset_revision) if encoded_repository_ids and changeset_revisions: @@ -1298,7 +1302,7 @@ def get_tool_dependencies(self, trans, **kwd): name = kwd.get("name", None) owner = kwd.get("owner", None) changeset_revision = kwd.get("changeset_revision", None) - repository = repository_util.get_repository_by_name_and_owner(trans.app, name, owner) + repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) dependencies = repository.get_tool_dependencies(trans.app, changeset_revision) if len(dependencies) > 0: return encoding_util.tool_shed_encode(dependencies) @@ -1312,7 +1316,7 @@ def get_tool_dependencies_config_contents(self, trans, **kwd): """ name = kwd.get("name", None) owner = kwd.get("owner", None) - repository = repository_util.get_repository_by_name_and_owner(trans.app, name, owner) + repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) # TODO: We're currently returning the tool_dependencies.xml file that is available on disk. We need # to enhance this process to retrieve older versions of the tool-dependencies.xml file from the repository # manafest. @@ -1335,7 +1339,7 @@ def get_tool_dependency_definition_metadata(self, trans, **kwd): """ repository_name = kwd["name"] repository_owner = kwd["owner"] - repository = repository_util.get_repository_by_name_and_owner(trans.app, repository_name, repository_owner) + repository = get_repository_by_name_and_owner(trans.sa_session, repository_name, repository_owner) encoded_id = trans.app.security.encode_id(repository.id) repository_tip = repository.tip() repository_metadata = metadata_util.get_repository_metadata_by_changeset_revision( @@ -1352,7 +1356,7 @@ def get_tool_versions(self, trans, **kwd): name = kwd["name"] owner = kwd["owner"] changeset_revision = kwd["changeset_revision"] - repository = repository_util.get_repository_by_name_and_owner(trans.app, name, owner) + repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) repo = repository.hg_repo tool_version_dicts = [] for changeset in repo.changelog: @@ -1374,7 +1378,7 @@ def get_updated_repository_information(self, trans, name, owner, changeset_revis Generate a dictionary that contains the information about a repository that is necessary for installing it into a local Galaxy instance. """ - repository = repository_util.get_repository_by_name_and_owner(trans.app, name, owner) + repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) repository_id = trans.security.encode_id(repository.id) repository_clone_url = common_util.generate_clone_url_for_repository_in_tool_shed(trans.user, repository) repository_metadata = metadata_util.get_repository_metadata_by_changeset_revision( @@ -1909,7 +1913,7 @@ def next_installable_changeset_revision(self, trans, **kwd): name = kwd.get("name", None) owner = kwd.get("owner", None) changeset_revision = kwd.get("changeset_revision", None) - repository = repository_util.get_repository_by_name_and_owner(trans.app, name, owner) + repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) # Get the next installable changeset_revision beyond the received changeset_revision. next_changeset_revision = metadata_util.get_next_downloadable_changeset_revision( trans.app, repository, changeset_revision @@ -2008,7 +2012,7 @@ def previous_changeset_revisions(self, trans, from_tip=False, **kwd): name = kwd.get("name", None) owner = kwd.get("owner", None) if name is not None and owner is not None: - repository = repository_util.get_repository_by_name_and_owner(trans.app, name, owner) + repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) from_tip = util.string_as_bool(from_tip) if from_tip: changeset_revision = repository.tip() @@ -2207,7 +2211,7 @@ def sharable_owner(self, trans, owner): def sharable_repository(self, trans, owner, name): """Support for sharable URL for a specified repository, e.g. http://example.org/view/owner/name.""" try: - repository = repository_util.get_repository_by_name_and_owner(trans.app, name, owner) + repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) except Exception: repository = None if repository: @@ -2242,7 +2246,7 @@ def sharable_repository(self, trans, owner, name): def sharable_repository_revision(self, trans, owner, name, changeset_revision): """Support for sharable URL for a specified repository revision, e.g. http://example.org/view/owner/name/changeset_revision.""" try: - repository = repository_util.get_repository_by_name_and_owner(trans.app, name, owner) + repository = get_repository_by_name_and_owner(trans.sa_session, name, owner) except Exception: repository = None if repository: diff --git a/lib/tool_shed/webapp/model/__init__.py b/lib/tool_shed/webapp/model/__init__.py index b950938b869d..afe9d245f302 100644 --- a/lib/tool_shed/webapp/model/__init__.py +++ b/lib/tool_shed/webapp/model/__init__.py @@ -54,11 +54,7 @@ from galaxy.util.bunch import Bunch from galaxy.util.dictifiable import Dictifiable from galaxy.util.hash_util import new_insecure_hash -from tool_shed.dependencies.repository import relation_builder -from tool_shed.util import ( - hg_util, - metadata_util, -) +from tool_shed.util import hg_util from tool_shed.util.hgweb_config import hgweb_config_manager log = logging.getLogger(__name__) @@ -493,11 +489,14 @@ def get_repository_dependencies(self, app, changeset, toolshed_url): # have repository dependencies. However, if a readme file is uploaded, or some other change # is made that does not create a new downloadable changeset revision but updates the existing # one, we still want to be able to get repository dependencies. - repository_metadata = metadata_util.get_current_repository_metadata_for_changeset_revision(app, self, changeset) + from tool_shed.dependencies.repository.relation_builder import RelationBuilder + from tool_shed.util.metadata_util import get_current_repository_metadata_for_changeset_revision + + repository_metadata = get_current_repository_metadata_for_changeset_revision(app, self, changeset) if repository_metadata: metadata = repository_metadata.metadata if metadata: - rb = relation_builder.RelationBuilder(app, self, repository_metadata, toolshed_url) + rb = RelationBuilder(app, self, repository_metadata, toolshed_url) repository_dependencies = rb.get_repository_dependencies_for_changeset_revision() if repository_dependencies: return repository_dependencies @@ -507,14 +506,18 @@ def get_type_class(self, app): return app.repository_types_registry.get_class_by_label(self.type) def get_tool_dependencies(self, app, changeset_revision): - changeset_revision = metadata_util.get_next_downloadable_changeset_revision(app, self, changeset_revision) + from tool_shed.util.metadata_util import get_next_downloadable_changeset_revision + + changeset_revision = get_next_downloadable_changeset_revision(app, self, changeset_revision) for downloadable_revision in self.downloadable_revisions: if downloadable_revision.changeset_revision == changeset_revision: return downloadable_revision.metadata.get("tool_dependencies", {}) return {} def installable_revisions(self, app, sort_revisions=True): - return metadata_util.get_metadata_revisions(app, self, sort_revisions=sort_revisions) + from tool_shed.util.metadata_util import get_metadata_revisions + + return get_metadata_revisions(app, self, sort_revisions=sort_revisions) def is_new(self): tip_rev = self.hg_repo.changelog.tiprev() diff --git a/lib/tool_shed/webapp/model/db/__init__.py b/lib/tool_shed/webapp/model/db/__init__.py new file mode 100644 index 000000000000..990226999ce4 --- /dev/null +++ b/lib/tool_shed/webapp/model/db/__init__.py @@ -0,0 +1,36 @@ +from typing import TYPE_CHECKING + +from sqlalchemy import and_ +from sqlalchemy.orm import joinedload + +from tool_shed.webapp.model import ( + Repository, + User, +) + +if TYPE_CHECKING: + from sqlalchemy.orm import scoped_session + + +def get_repository_query(session: "scoped_session"): + return session.query(Repository) + + +def get_repository_by_name(session: "scoped_session", name): + """Get a repository from the database via name.""" + return get_repository_query(session).filter_by(name=name).first() + + +def get_repository_by_name_and_owner(session: "scoped_session", name, owner, eagerload_columns=None): + """Get a repository from the database via name and owner""" + repository_query = get_repository_query(session) + q = repository_query.filter( + and_( + Repository.name == name, + User.username == owner, + Repository.user_id == User.id, + ) + ) + if eagerload_columns: + q = q.options(joinedload(*eagerload_columns)) + return q.first() diff --git a/lib/tool_shed/webapp/templates/webapps/tool_shed/repository/common.mako b/lib/tool_shed/webapp/templates/webapps/tool_shed/repository/common.mako index e775470d39bb..ef01febb00de 100644 --- a/lib/tool_shed/webapp/templates/webapps/tool_shed/repository/common.mako +++ b/lib/tool_shed/webapp/templates/webapps/tool_shed/repository/common.mako @@ -534,7 +534,7 @@ <%def name="render_repository_dependency( repository_dependency, pad, parent, row_counter, row_is_header=False, render_repository_actions_for='tool_shed' )"> <% from galaxy.util import asbool - from tool_shed.util.repository_util import get_repository_by_name_and_owner + from tool_shed.webapp.model.db import get_repository_by_name_and_owner encoded_id = trans.security.encode_id( repository_dependency.id ) if trans.webapp.name == 'galaxy': if repository_dependency.tool_shed_repository_id: @@ -561,7 +561,7 @@ else: # We're in the tool shed. cell_type = 'td' - rd = get_repository_by_name_and_owner( trans.app, repository_name, repository_owner ) + rd = get_repository_by_name_and_owner(trans.sa_session, repository_name, repository_owner) %> <% from galaxy.util import asbool - from tool_shed.util.repository_util import get_repository_by_name_and_owner + from tool_shed.webapp.model.db import get_repository_by_name_and_owner encoded_id = trans.security.encode_id( repository_dependency.id ) if trans.webapp.name == 'galaxy': if repository_dependency.tool_shed_repository_id: @@ -590,7 +590,7 @@ else: # We're in the tool shed. cell_type = 'td' - rd = get_repository_by_name_and_owner( trans.app, repository_name, repository_owner ) + rd = get_repository_by_name_and_owner(trans.sa_session, repository_name, repository_owner) %> None: app = self._app sa_session = app.model.session job = sa_session.get(app.model.Job, app.security.decode_id(job_dict["id"])) - + assert job self._wait_for_external_state(sa_session, job, app.model.Job.states.RUNNING) assert not job.finished @@ -308,7 +308,7 @@ def test_external_job_delete(self) -> None: app = self._app sa_session = app.model.session job = sa_session.get(app.model.Job, app.security.decode_id(job_dict["id"])) - + assert job self._wait_for_external_state(sa_session, job, app.model.Job.states.RUNNING) external_id = job.job_runner_external_id @@ -342,6 +342,7 @@ def test_exit_code_127(self, history_id: str) -> None: job_id = app.security.decode_id(running_response.json()["jobs"][0]["id"]) sa_session = app.model.session job = sa_session.get(app.model.Job, job_id) + assert job self._wait_for_external_state(sa_session=sa_session, job=job, expected=app.model.Job.states.RUNNING) external_id = job.job_runner_external_id diff --git a/test/integration/test_page_revision_json_encoding.py b/test/integration/test_page_revision_json_encoding.py index f112b282ae2e..596291e9062d 100644 --- a/test/integration/test_page_revision_json_encoding.py +++ b/test/integration/test_page_revision_json_encoding.py @@ -33,6 +33,7 @@ def test_page_encoding(self, history_id: str): api_asserts.assert_status_code_is_ok(page_response) sa_session = self._app.model.session page_revision = sa_session.scalars(select(model.PageRevision).filter_by(content_format="html")).all()[0] + assert page_revision.content is not None assert history_num_re.search(page_revision.content), page_revision.content assert f'''id="History-{history_id}"''' not in page_revision.content, page_revision.content @@ -59,6 +60,7 @@ def test_page_encoding_markdown(self, history_id: str): api_asserts.assert_status_code_is_ok(page_response) sa_session = self._app.model.session page_revision = sa_session.scalars(select(model.PageRevision).filter_by(content_format="markdown")).all()[0] + assert page_revision.content is not None assert ( """```galaxy history_dataset_display(history_dataset_id=1) diff --git a/test/integration/test_remote_files_posix.py b/test/integration/test_remote_files_posix.py index 75a449ff2e45..9da0ae04f5cf 100644 --- a/test/integration/test_remote_files_posix.py +++ b/test/integration/test_remote_files_posix.py @@ -144,6 +144,7 @@ def test_links_by_default(self): assert content == "a\n", content stmt = select(Dataset).order_by(Dataset.create_time.desc()).limit(1) dataset = self._app.model.session.execute(stmt).unique().scalar_one() + assert dataset.external_filename is not None assert dataset.external_filename.endswith("/root/a") assert os.path.exists(dataset.external_filename) assert open(dataset.external_filename).read() == "a\n" diff --git a/test/integration/test_workflow_handler_configuration.py b/test/integration/test_workflow_handler_configuration.py index 27cfeb5375d8..ed8df0937e0c 100644 --- a/test/integration/test_workflow_handler_configuration.py +++ b/test/integration/test_workflow_handler_configuration.py @@ -7,6 +7,7 @@ import time from json import dumps +from galaxy import model from galaxy_test.base.populators import ( DatasetPopulator, WorkflowPopulator, @@ -127,7 +128,8 @@ def _get_workflow_invocations(self, history_id: str): # into Galaxy's internal state. app = self._app history_id = app.security.decode_id(history_id) - history = app.model.session.get(app.model.History, history_id) + history = app.model.session.get(model.History, history_id) + assert history is not None workflow_invocations = history.workflow_invocations return workflow_invocations diff --git a/test/unit/app/authnz/test_custos_authnz.py b/test/unit/app/authnz/test_custos_authnz.py index 70398ec71ec1..a0ba9a688987 100644 --- a/test/unit/app/authnz/test_custos_authnz.py +++ b/test/unit/app/authnz/test_custos_authnz.py @@ -6,7 +6,11 @@ datetime, timedelta, ) -from typing import Optional +from typing import ( + cast, + Optional, + TYPE_CHECKING, +) from unittest import SkipTest from urllib.parse import ( parse_qs, @@ -28,6 +32,9 @@ ) from galaxy.util.unittest import TestCase +if TYPE_CHECKING: + from sqlalchemy.orm import scoped_session + class TestCustosAuthnz(TestCase): _create_oauth2_session_called = False @@ -233,7 +240,7 @@ def __init__(self, app=None, user=None, history=None, **kwargs): self.cookies = {} self.cookies_args = {} self.request = Request() - self.sa_session = Session() + self.sa_session = cast("scoped_session", Session()) self.user = None def set_cookie(self, value, name=None, **kwargs): diff --git a/test/unit/app/jobs/test_job_wrapper.py b/test/unit/app/jobs/test_job_wrapper.py index 0a2976c0d2d5..4e9a754d507c 100644 --- a/test/unit/app/jobs/test_job_wrapper.py +++ b/test/unit/app/jobs/test_job_wrapper.py @@ -5,6 +5,7 @@ cast, Dict, Type, + TYPE_CHECKING, ) from galaxy.app_unittest_utils.tools_support import ( @@ -26,6 +27,9 @@ from galaxy.util.bunch import Bunch from galaxy.util.unittest import TestCase +if TYPE_CHECKING: + from sqlalchemy.orm import scoped_session + TEST_TOOL_ID = "cufftest" TEST_VERSION_COMMAND = "bwa --version" TEST_DEPENDENCIES_COMMANDS = ". /galaxy/modules/bwa/0.5.9/env.sh" @@ -51,7 +55,7 @@ def setUp(self): job.user = User() job.object_store_id = "foo" self.model_objects: Dict[Type[Base], Dict[int, Base]] = {Job: {345: job}} - self.app.model.session = MockContext(self.model_objects) + self.app.model.session = cast("scoped_session", MockContext(self.model_objects)) self.app._toolbox = cast(ToolBox, MockToolbox(MockTool(self))) self.working_directory = os.path.join(self.test_directory, "working") diff --git a/test/unit/app/jobs/test_runner_local.py b/test/unit/app/jobs/test_runner_local.py index ed897c9484b5..3358854eedc0 100644 --- a/test/unit/app/jobs/test_runner_local.py +++ b/test/unit/app/jobs/test_runner_local.py @@ -1,7 +1,11 @@ import os import threading import time -from typing import Optional +from typing import ( + cast, + Optional, + TYPE_CHECKING, +) import psutil @@ -14,6 +18,9 @@ from galaxy.util import bunch from galaxy.util.unittest import TestCase +if TYPE_CHECKING: + from sqlalchemy.orm import scoped_session + class TestLocalJobRunner(TestCase, UsesTools): def setUp(self): @@ -97,7 +104,7 @@ def test_shutdown_no_jobs(self): def test_stopping_job_at_shutdown(self): self.job_wrapper.command_line = '''python -c "import time; time.sleep(15)"''' - self.app.model.session = bunch.Bunch(add=lambda x: None, flush=lambda: None) + self.app.model.session = cast("scoped_session", bunch.Bunch(add=lambda x: None, flush=lambda: None)) runner = local.LocalJobRunner(self.app, 1) runner.start() self.app.config.monitor_thread_join_timeout = 15 diff --git a/test/unit/app/managers/test_JobConnectionsManager.py b/test/unit/app/managers/test_JobConnectionsManager.py index a3e12d91d89b..ba5033e595bd 100644 --- a/test/unit/app/managers/test_JobConnectionsManager.py +++ b/test/unit/app/managers/test_JobConnectionsManager.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import pytest from sqlalchemy import union @@ -7,9 +9,11 @@ HistoryDatasetCollectionAssociation, Job, ) -from galaxy.model.scoped_session import galaxy_scoped_session from galaxy.model.unittest_utils import GalaxyDataTestApp +if TYPE_CHECKING: + from sqlalchemy.orm import scoped_session + @pytest.fixture def sa_session(): @@ -23,7 +27,7 @@ def job_connections_manager(sa_session) -> JobConnectionsManager: # ============================================================================= -def setup_connected_dataset(sa_session: galaxy_scoped_session): +def setup_connected_dataset(sa_session: "scoped_session"): center_hda = HistoryDatasetAssociation(sa_session=sa_session, create_dataset=True) input_hda = HistoryDatasetAssociation(sa_session=sa_session, create_dataset=True) input_hdca = HistoryDatasetCollectionAssociation() @@ -52,7 +56,7 @@ def setup_connected_dataset(sa_session: galaxy_scoped_session): return center_hda, expected_graph -def setup_connected_dataset_collection(sa_session: galaxy_scoped_session): +def setup_connected_dataset_collection(sa_session: "scoped_session"): center_hdca = HistoryDatasetCollectionAssociation() input_hda1 = HistoryDatasetAssociation(sa_session=sa_session, create_dataset=True) input_hda2 = HistoryDatasetAssociation(sa_session=sa_session, create_dataset=True) diff --git a/test/unit/app/managers/test_user_file_sources.py b/test/unit/app/managers/test_user_file_sources.py index 5654a4ef14ec..841f79ce0838 100644 --- a/test/unit/app/managers/test_user_file_sources.py +++ b/test/unit/app/managers/test_user_file_sources.py @@ -824,7 +824,7 @@ def _assert_secret_absent(self, user_file_source: UserFileSourceModel, secret_na assert sec_val in ["", None] def _assert_modify_throws_exception( - self, user_file_source: UserFileSourceModel, modify: ModifyInstancePayload, exception_type: Type + self, user_file_source: UserFileSourceModel, modify: ModifyInstancePayload, exception_type: Type[Exception] ): exception_thrown = False try: diff --git a/test/unit/app/managers/test_user_object_stores.py b/test/unit/app/managers/test_user_object_stores.py index 1ad20cc9af91..5485df25a0f6 100644 --- a/test/unit/app/managers/test_user_object_stores.py +++ b/test/unit/app/managers/test_user_object_stores.py @@ -504,7 +504,10 @@ def _init_managers(self, tmp_path, config_dict=None): self.manager = manager def _assert_modify_throws_exception( - self, user_object_store: UserConcreteObjectStoreModel, modify: ModifyInstancePayload, exception_type: Type + self, + user_object_store: UserConcreteObjectStoreModel, + modify: ModifyInstancePayload, + exception_type: Type[Exception], ): exception_thrown = False try: diff --git a/test/unit/data/model/test_model_store.py b/test/unit/data/model/test_model_store.py index 1ccec1ee245a..220c010a0877 100644 --- a/test/unit/data/model/test_model_store.py +++ b/test/unit/data/model/test_model_store.py @@ -23,6 +23,7 @@ from galaxy import model from galaxy.model import store from galaxy.model.metadata import MetadataTempFile +from galaxy.model.store import SessionlessContext from galaxy.model.unittest_utils import GalaxyDataTestApp from galaxy.model.unittest_utils.store_fixtures import ( deferred_hda_model_store_dict, @@ -922,6 +923,7 @@ def test_sessionless_import_edit_datasets(): import_model_store.perform_import() # Not using app.sa_session but a session mock that has a query/find pattern emulating usage # of real sa_session. + assert isinstance(import_model_store.sa_session, SessionlessContext) d1 = import_model_store.sa_session.query(model.HistoryDatasetAssociation).find(h.datasets[0].id) d2 = import_model_store.sa_session.query(model.HistoryDatasetAssociation).find(h.datasets[1].id) assert d1 is not None diff --git a/test/unit/data/test_galaxy_mapping.py b/test/unit/data/test_galaxy_mapping.py index 67313f5e1ecf..ad709bccb548 100644 --- a/test/unit/data/test_galaxy_mapping.py +++ b/test/unit/data/test_galaxy_mapping.py @@ -11,9 +11,8 @@ ) import galaxy.datatypes.registry -import galaxy.model -import galaxy.model.mapping as mapping from galaxy import model +from galaxy.model import mapping from galaxy.model.database_utils import create_database from galaxy.model.metadata import MetadataTempFile from galaxy.model.orm.util import ( @@ -27,7 +26,7 @@ datatypes_registry = galaxy.datatypes.registry.Registry() datatypes_registry.load_datatypes() -galaxy.model.set_datatypes_registry(datatypes_registry) +model.set_datatypes_registry(datatypes_registry) DB_URI = "sqlite:///:memory:" # docker run -e POSTGRES_USER=galaxy -p 5432:5432 -d postgres @@ -269,7 +268,6 @@ def test_flush_refreshes(self): # Normally I don't believe in unit testing library code, but the behaviors around attribute # states and flushing in SQL Alchemy is very subtle and it is good to have a executable # reference for how it behaves in the context of Galaxy objects. - model = self.model user = model.User(email=random_email(), password="password") galaxy_session = model.GalaxySession() galaxy_session_other = model.GalaxySession() @@ -305,7 +303,7 @@ def test_flush_refreshes(self): self._non_empty_flush() if session().in_transaction(): session.commit() - assert expected_id == galaxy.model.cached_id(galaxy_model_object) + assert expected_id == model.cached_id(galaxy_model_object) assert "id" in inspect(galaxy_model_object).unloaded # Keeping the following failed experiments here for future reference, @@ -333,7 +331,7 @@ def test_flush_refreshes(self): session.flush() if session().in_transaction(): session.commit() - assert galaxy.model.cached_id(galaxy_model_object_new) + assert model.cached_id(galaxy_model_object_new) assert "id" in inspect(galaxy_model_object_new).unloaded # Verify a targeted flush prevent expiring unrelated objects. @@ -580,14 +578,14 @@ def test_cannot_make_private_objectstore_dataset_public(self): with pytest.raises(Exception) as exec_info: self._make_owned(security_agent, u_from, d1) - assert galaxy.model.CANNOT_SHARE_PRIVATE_DATASET_MESSAGE in str(exec_info.value) + assert model.CANNOT_SHARE_PRIVATE_DATASET_MESSAGE in str(exec_info.value) def test_cannot_make_private_objectstore_dataset_shared(self): security_agent = GalaxyRBACAgent(self.model.session) u_from, u_to, _ = self._three_users("cannot_make_private_shared") - h = self.model.History(name="History for Prevent Sharing", user=u_from) - d1 = self.model.HistoryDatasetAssociation( + h = model.History(name="History for Prevent Sharing", user=u_from) + d1 = model.HistoryDatasetAssociation( extension="txt", history=h, create_dataset=True, sa_session=self.model.session ) self.persist(h, d1) @@ -597,14 +595,14 @@ def test_cannot_make_private_objectstore_dataset_shared(self): with pytest.raises(Exception) as exec_info: security_agent.privately_share_dataset(d1.dataset, [u_to]) - assert galaxy.model.CANNOT_SHARE_PRIVATE_DATASET_MESSAGE in str(exec_info.value) + assert model.CANNOT_SHARE_PRIVATE_DATASET_MESSAGE in str(exec_info.value) def test_cannot_set_dataset_permisson_on_private(self): security_agent = GalaxyRBACAgent(self.model.session) u_from, u_to, _ = self._three_users("cannot_set_permissions_on_private") - h = self.model.History(name="History for Prevent Sharing", user=u_from) - d1 = self.model.HistoryDatasetAssociation( + h = model.History(name="History for Prevent Sharing", user=u_from) + d1 = model.HistoryDatasetAssociation( extension="txt", history=h, create_dataset=True, sa_session=self.model.session ) self.persist(h, d1) @@ -617,14 +615,14 @@ def test_cannot_set_dataset_permisson_on_private(self): with pytest.raises(Exception) as exec_info: security_agent.set_dataset_permission(d1.dataset, {access_action: [role]}) - assert galaxy.model.CANNOT_SHARE_PRIVATE_DATASET_MESSAGE in str(exec_info.value) + assert model.CANNOT_SHARE_PRIVATE_DATASET_MESSAGE in str(exec_info.value) def test_cannot_make_private_dataset_public(self): security_agent = GalaxyRBACAgent(self.model.session) u_from, u_to, u_other = self._three_users("cannot_make_private_dataset_public") - h = self.model.History(name="History for Annotation", user=u_from) - d1 = self.model.HistoryDatasetAssociation( + h = model.History(name="History for Annotation", user=u_from) + d1 = model.HistoryDatasetAssociation( extension="txt", history=h, create_dataset=True, sa_session=self.model.session ) self.persist(h, d1) @@ -634,7 +632,7 @@ def test_cannot_make_private_dataset_public(self): with pytest.raises(Exception) as exec_info: security_agent.make_dataset_public(d1.dataset) - assert galaxy.model.CANNOT_SHARE_PRIVATE_DATASET_MESSAGE in str(exec_info.value) + assert model.CANNOT_SHARE_PRIVATE_DATASET_MESSAGE in str(exec_info.value) def _three_users(self, suffix): email_from = f"user_{suffix}e1@example.com" @@ -667,7 +665,7 @@ def _set_permissions(self, security_agent, dataset, permissions): def new_hda(self, history, **kwds): object_store_id = kwds.pop("object_store_id", None) - hda = self.model.HistoryDatasetAssociation(create_dataset=True, sa_session=self.model.session, **kwds) + hda = model.HistoryDatasetAssociation(create_dataset=True, sa_session=self.model.session, **kwds) if object_store_id is not None: hda.dataset.object_store_id = object_store_id return history.add_dataset(hda) @@ -686,8 +684,8 @@ def _db_uri(cls): def _invocation_for_workflow(user, workflow): - h1 = galaxy.model.History(name="WorkflowHistory1", user=user) - workflow_invocation = galaxy.model.WorkflowInvocation() + h1 = model.History(name="WorkflowHistory1", user=user) + workflow_invocation = model.WorkflowInvocation() workflow_invocation.workflow = workflow workflow_invocation.history = h1 workflow_invocation.state = "new" @@ -695,10 +693,10 @@ def _invocation_for_workflow(user, workflow): def _workflow_from_steps(user, steps): - stored_workflow = galaxy.model.StoredWorkflow() + stored_workflow = model.StoredWorkflow() add_object_to_object_session(stored_workflow, user) stored_workflow.user = user - workflow = galaxy.model.Workflow() + workflow = model.Workflow() if steps: for step in steps: if get_object_session(step): diff --git a/test/unit/data/test_quota.py b/test/unit/data/test_quota.py index f06984e66113..bd8e2352684a 100644 --- a/test/unit/data/test_quota.py +++ b/test/unit/data/test_quota.py @@ -1,4 +1,5 @@ import uuid +from decimal import Decimal from galaxy import model from galaxy.model.unittest_utils.utils import random_email @@ -16,9 +17,8 @@ class TestPurgeUsage(BaseModelTestCase): def setUp(self): super().setUp() - model = self.model u = model.User(email=random_email(), password="password") - u.disk_usage = 25 + u.disk_usage = Decimal(25) self.persist(u) h = model.History(name="History for Purging", user=u) @@ -27,7 +27,7 @@ def setUp(self): self.h = h def _setup_dataset(self): - d1 = self.model.HistoryDatasetAssociation( + d1 = model.HistoryDatasetAssociation( extension="txt", history=self.h, create_dataset=True, sa_session=self.model.session ) d1.dataset.total_size = 10 @@ -39,7 +39,7 @@ def test_calculate_usage(self): quota_source_info = QuotaSourceInfo(None, True) d1.purge_usage_from_quota(self.u, quota_source_info) self.persist(self.u) - assert int(self.u.disk_usage) == 15 + assert self.u.disk_usage == Decimal(15) def test_calculate_usage_untracked(self): # test quota tracking off on the objectstore @@ -47,7 +47,7 @@ def test_calculate_usage_untracked(self): quota_source_info = QuotaSourceInfo(None, False) d1.purge_usage_from_quota(self.u, quota_source_info) self.persist(self.u) - assert int(self.u.disk_usage) == 25 + assert self.u.disk_usage == Decimal(25) def test_calculate_usage_per_source(self): self.u.adjust_total_disk_usage(124, "myquotalabel") @@ -57,7 +57,7 @@ def test_calculate_usage_per_source(self): quota_source_info = QuotaSourceInfo("myquotalabel", True) d1.purge_usage_from_quota(self.u, quota_source_info) self.persist(self.u) - assert int(self.u.disk_usage) == 25 + assert self.u.disk_usage == Decimal(25) usages = self.u.dictify_usage() assert len(usages) == 2 @@ -67,7 +67,6 @@ def test_calculate_usage_per_source(self): class TestCalculateUsage(BaseModelTestCase): def setUp(self): - model = self.model u = model.User(email=f"calc_usage{uuid.uuid1()}@example.com", password="password") self.persist(u) h = model.History(name="History for Calculated Usage", user=u) @@ -76,7 +75,6 @@ def setUp(self): self.h = h def _add_dataset(self, total_size, object_store_id=None): - model = self.model d1 = model.HistoryDatasetAssociation( extension="txt", history=self.h, create_dataset=True, sa_session=self.model.session ) @@ -86,7 +84,6 @@ def _add_dataset(self, total_size, object_store_id=None): return d1 def test_calculate_usage(self): - model = self.model u = self.u h = self.h @@ -141,7 +138,7 @@ def test_calculate_usage_readjusts_incorrect_quota(self): self._refresh_user_and_assert_disk_usage_is(10) # lets break this to simulate the actual bugs we observe in Galaxy. - u.disk_usage = -10 + u.disk_usage = Decimal(-10) self.persist(u) self._refresh_user_and_assert_disk_usage_is(-10) @@ -150,7 +147,7 @@ def test_calculate_usage_readjusts_incorrect_quota(self): self._refresh_user_and_assert_disk_usage_is(10) # break it again - u.disk_usage = 1000 + u.disk_usage = Decimal(1000) self.persist(u) self._refresh_user_and_assert_disk_usage_is(1000) @@ -188,7 +185,6 @@ def test_calculate_usage_disabled_quota(self): assert u.calculate_disk_usage_default_source(object_store) == 15 def test_calculate_usage_alt_quota(self): - model = self.model u = self.u self._add_dataset(10) @@ -202,7 +198,7 @@ def test_calculate_usage_alt_quota(self): object_store = MockObjectStore(quota_source_map) u.calculate_and_set_disk_usage(object_store) - model.context.refresh(u) + self.model.context.refresh(u) usages = u.dictify_usage(object_store) assert len(usages) == 2 assert usages[0].quota_source_label is None @@ -224,7 +220,6 @@ def test_calculate_usage_alt_quota(self): assert usage.total_disk_usage == 0 def test_calculate_usage_removes_unused_quota_labels(self): - model = self.model u = self.u d = self._add_dataset(10) @@ -238,7 +233,7 @@ def test_calculate_usage_removes_unused_quota_labels(self): object_store = MockObjectStore(quota_source_map) u.calculate_and_set_disk_usage(object_store) - model.context.refresh(u) + self.model.context.refresh(u) usages = u.dictify_usage() assert len(usages) == 2 assert usages[0].quota_source_label is None @@ -249,7 +244,7 @@ def test_calculate_usage_removes_unused_quota_labels(self): alt_source.default_quota_source = "new_alt_source" u.calculate_and_set_disk_usage(object_store) - model.context.refresh(u) + self.model.context.refresh(u) usages = u.dictify_usage() assert len(usages) == 2 assert usages[0].quota_source_label is None @@ -262,7 +257,7 @@ def test_calculate_usage_removes_unused_quota_labels(self): d.purge_usage_from_quota(u, quota_source_map.info) self.model.session.add(d) self.model.session.flush() - model.context.refresh(u) + self.model.context.refresh(u) usages = u.dictify_usage() assert len(usages) == 2 @@ -270,7 +265,6 @@ def test_calculate_usage_removes_unused_quota_labels(self): assert usages[0].total_disk_usage == 0 def test_dictify_usage_unused_quota_labels(self): - model = self.model u = self.u self._add_dataset(10) @@ -287,12 +281,11 @@ def test_dictify_usage_unused_quota_labels(self): object_store = MockObjectStore(quota_source_map) u.calculate_and_set_disk_usage(object_store) - model.context.refresh(u) + self.model.context.refresh(u) usages = u.dictify_usage(object_store) assert len(usages) == 3 def test_calculate_usage_default_storage_disabled(self): - model = self.model u = self.u self._add_dataset(10) @@ -305,7 +298,7 @@ def test_calculate_usage_default_storage_disabled(self): object_store = MockObjectStore(quota_source_map) u.calculate_and_set_disk_usage(object_store) - model.context.refresh(u) + self.model.context.refresh(u) usages = u.dictify_usage(object_store) assert len(usages) == 2 assert usages[0].quota_source_label is None @@ -315,8 +308,7 @@ def test_calculate_usage_default_storage_disabled(self): assert usages[1].total_disk_usage == 15 def test_update_usage_from_labeled_to_unlabeled(self): - model = self.model - quota_agent = DatabaseQuotaAgent(model) + quota_agent = DatabaseQuotaAgent(self.model) u = self.u self._add_dataset(10) @@ -336,8 +328,7 @@ def test_update_usage_from_labeled_to_unlabeled(self): self._refresh_user_and_assert_disk_usage_is(0, "alt_source") def test_update_usage_from_unlabeled_to_labeled(self): - model = self.model - quota_agent = DatabaseQuotaAgent(model) + quota_agent = DatabaseQuotaAgent(self.model) u = self.u d = self._add_dataset(10) @@ -363,16 +354,15 @@ def _refresh_user_and_assert_disk_usage_is(self, usage, label=None): assert u.disk_usage == usage else: usages = u.dictify_usage() - for u in usages: - if u.quota_source_label == label: - assert int(u.total_disk_usage) == int(usage) + for uqbu in usages: + if uqbu.quota_source_label == label: + assert int(uqbu.total_disk_usage) == int(usage) class TestQuota(BaseModelTestCase): def setUp(self): super().setUp() - model = self.model - self.quota_agent = DatabaseQuotaAgent(model) + self.quota_agent = DatabaseQuotaAgent(self.model) def test_quota(self): u = model.User(email="quota@example.com", password="password") @@ -420,7 +410,6 @@ def test_quota(self): self._assert_user_quota_is(u, None) def test_labeled_quota(self): - model = self.model u = model.User(email="labeled_quota@example.com", password="password") self.persist(u) @@ -457,11 +446,11 @@ def _assert_user_quota_is(self, user, amount, quota_source_label=None): if quota_source_label is None: if amount is None: user.total_disk_usage = 1000 - job = self.model.Job() + job = model.Job() job.user = user assert not self.quota_agent.is_over_quota(None, job, None) else: - job = self.model.Job() + job = model.Job() job.user = user user.total_disk_usage = amount - 1 assert not self.quota_agent.is_over_quota(None, job, None) @@ -471,7 +460,6 @@ def _assert_user_quota_is(self, user, amount, quota_source_label=None): class TestUsage(BaseModelTestCase): def test_usage(self): - model = self.model u = model.User(email="usage@example.com", password="password") self.persist(u) @@ -481,7 +469,6 @@ def test_usage(self): assert u.get_disk_usage() == 123 def test_labeled_usage(self): - model = self.model u = model.User(email="labeled.usage@example.com", password="password") self.persist(u) assert len(u.quota_source_usages) == 0 diff --git a/test/unit/tool_shed/test_tool_panel_manager.py b/test/unit/tool_shed/test_tool_panel_manager.py index 6907e81ddb8c..c43005d37157 100644 --- a/test/unit/tool_shed/test_tool_panel_manager.py +++ b/test/unit/tool_shed/test_tool_panel_manager.py @@ -8,11 +8,16 @@ from galaxy.tool_shed.galaxy_install.tools import tool_panel_manager from galaxy.util import parse_xml from tool_shed.tools import tool_version_manager +from ._util import TestToolShedApp DEFAULT_GUID = "123456" class TestToolPanelManager(BaseToolBoxTestCase): + def setUp(self): + super().setUp() + self.ts_app = TestToolShedApp() + def get_new_toolbox(self): return SimplifiedToolBox(self) @@ -203,4 +208,4 @@ def tpm(self): @property def tvm(self): - return tool_version_manager.ToolVersionManager(self.app) + return tool_version_manager.ToolVersionManager(self.ts_app) From 7b94bb0a4e8d139e32a9230d28ae008a084bbb1a Mon Sep 17 00:00:00 2001 From: Nicola Soranzo Date: Wed, 12 Feb 2025 10:33:25 +0000 Subject: [PATCH 4/5] More specific ``RequireAppT`` type --- .../tool_shed/tools/data_table_manager.py | 75 ++++++++++--------- lib/galaxy/tool_shed/tools/tool_validator.py | 13 ++-- lib/galaxy/tool_shed/util/repository_util.py | 4 +- lib/tool_shed/structured_app.py | 11 ++- lib/tool_shed/util/commit_util.py | 4 +- 5 files changed, 61 insertions(+), 46 deletions(-) diff --git a/lib/galaxy/tool_shed/tools/data_table_manager.py b/lib/galaxy/tool_shed/tools/data_table_manager.py index f5d0a5669a7d..56358c965c4f 100644 --- a/lib/galaxy/tool_shed/tools/data_table_manager.py +++ b/lib/galaxy/tool_shed/tools/data_table_manager.py @@ -4,7 +4,6 @@ from typing import ( List, TYPE_CHECKING, - Union, ) from galaxy.tool_shed.galaxy_install.client import InstallationTarget @@ -16,18 +15,49 @@ from galaxy.util.tool_shed import xml_util if TYPE_CHECKING: - from galaxy.structured_app import BasicSharedApp + from galaxy.model.tool_shed_install import ToolShedRepository + from galaxy.util.path import StrPath + from tool_shed.structured_app import RequiredAppT log = logging.getLogger(__name__) -RequiredAppT = Union["BasicSharedApp", InstallationTarget] +class BaseShedToolDataTableManager: + def __init__(self, app: "RequiredAppT"): + self.app = app + + def handle_sample_tool_data_table_conf_file(self, filename: "StrPath", persist: bool = False): + """ + Parse the incoming filename and add new entries to the in-memory + self.app.tool_data_tables dictionary. If persist is True (should + only occur if call is from the Galaxy side, not the tool shed), the + new entries will be appended to Galaxy's shed_tool_data_table_conf.xml + file on disk. + """ + error = False + try: + new_table_elems, message = self.app.tool_data_tables.add_new_entries_from_config_file( + config_filename=filename, + tool_data_path=self.app.config.shed_tool_data_path, + shed_tool_data_table_config=self.app.config.shed_tool_data_table_config, + persist=persist, + ) + if message: + error = True + except Exception as e: + message = str(e) + error = True + return error, message + + def reset_tool_data_tables(self): + # Reset the tool_data_tables to an empty dictionary. + self.app.tool_data_tables.data_tables = {} -class ShedToolDataTableManager: - app: RequiredAppT +class ShedToolDataTableManager(BaseShedToolDataTableManager): + app: InstallationTarget - def __init__(self, app: RequiredAppT): + def __init__(self, app: InstallationTarget): self.app = app def generate_repository_info_elem( @@ -105,30 +135,7 @@ def handle_missing_data_table_entry(self, relative_install_dir, tool_path, repos self.reset_tool_data_tables() return repository_tools_tups - def handle_sample_tool_data_table_conf_file(self, filename, persist=False): - """ - Parse the incoming filename and add new entries to the in-memory - self.app.tool_data_tables dictionary. If persist is True (should - only occur if call is from the Galaxy side, not the tool shed), the - new entries will be appended to Galaxy's shed_tool_data_table_conf.xml - file on disk. - """ - error = False - try: - new_table_elems, message = self.app.tool_data_tables.add_new_entries_from_config_file( - config_filename=filename, - tool_data_path=self.app.config.shed_tool_data_path, - shed_tool_data_table_config=self.app.config.shed_tool_data_table_config, - persist=persist, - ) - if message: - error = True - except Exception as e: - message = str(e) - error = True - return error, message - - def get_target_install_dir(self, tool_shed_repository): + def get_target_install_dir(self, tool_shed_repository: "ToolShedRepository"): tool_path, relative_target_dir = tool_shed_repository.get_tool_relative_path(self.app) # This is where index files will reside on a per repo/installed version basis. target_dir = os.path.join(self.app.config.shed_tool_data_path, relative_target_dir) @@ -136,7 +143,7 @@ def get_target_install_dir(self, tool_shed_repository): os.makedirs(target_dir) return target_dir, tool_path, relative_target_dir - def install_tool_data_tables(self, tool_shed_repository, tool_index_sample_files): + def install_tool_data_tables(self, tool_shed_repository: "ToolShedRepository", tool_index_sample_files): TOOL_DATA_TABLE_FILE_NAME = "tool_data_table_conf.xml" TOOL_DATA_TABLE_FILE_SAMPLE_NAME = f"{TOOL_DATA_TABLE_FILE_NAME}.sample" SAMPLE_SUFFIX = ".sample" @@ -168,7 +175,7 @@ def install_tool_data_tables(self, tool_shed_repository, tool_index_sample_files if tree: root = tree.getroot() if root.tag == "tables": - elems = list(root) + elems = list(iter(root)) else: log.warning( "The '%s' data table file has '%s' instead of as root element, skipping.", @@ -196,10 +203,6 @@ def install_tool_data_tables(self, tool_shed_repository, tool_index_sample_files self.app.tool_data_tables.to_xml_file(tool_data_table_conf_filename, elems) return tool_data_table_conf_filename, elems - def reset_tool_data_tables(self): - # Reset the tool_data_tables to an empty dictionary. - self.app.tool_data_tables.data_tables = {} - # For backwards compatibility with exisiting data managers ToolDataTableManager = ShedToolDataTableManager diff --git a/lib/galaxy/tool_shed/tools/tool_validator.py b/lib/galaxy/tool_shed/tools/tool_validator.py index 428601c65113..9f42cf009963 100644 --- a/lib/galaxy/tool_shed/tools/tool_validator.py +++ b/lib/galaxy/tool_shed/tools/tool_validator.py @@ -1,9 +1,7 @@ import logging +from typing import TYPE_CHECKING -from galaxy.tool_shed.tools.data_table_manager import ( - RequiredAppT, - ShedToolDataTableManager, -) +from galaxy.tool_shed.tools.data_table_manager import BaseShedToolDataTableManager from galaxy.tool_shed.util import ( basic_util, hg_util, @@ -16,13 +14,16 @@ ) from galaxy.tools.parameters import dynamic_options +if TYPE_CHECKING: + from tool_shed.structured_app import RequiredAppT + log = logging.getLogger(__name__) class ToolValidator: - def __init__(self, app: RequiredAppT): + def __init__(self, app: "RequiredAppT"): self.app = app - self.stdtm = ShedToolDataTableManager(self.app) + self.stdtm = BaseShedToolDataTableManager(self.app) def check_tool_input_params(self, repo_dir, tool_config_name, tool, sample_files): """ diff --git a/lib/galaxy/tool_shed/util/repository_util.py b/lib/galaxy/tool_shed/util/repository_util.py index 4bc101eca504..34d6f512686b 100644 --- a/lib/galaxy/tool_shed/util/repository_util.py +++ b/lib/galaxy/tool_shed/util/repository_util.py @@ -19,6 +19,7 @@ and_, false, ) +from typing_extensions import TypeIs from galaxy import util from galaxy.model.base import check_database_connection @@ -33,6 +34,7 @@ if TYPE_CHECKING: from galaxy.tool_shed.galaxy_install.client import InstallationTarget + from tool_shed.structured_app import RequiredAppT log = logging.getLogger(__name__) @@ -604,7 +606,7 @@ def get_tool_shed_status_for_installed_repository(app, repository: ToolShedRepos return get_tool_shed_status_for(tool_shed_registry, repository) -def is_tool_shed_client(app): +def is_tool_shed_client(app: "RequiredAppT") -> TypeIs["InstallationTarget"]: """ The tool shed and clients to the tool (i.e. Galaxy) require a lot of similar functionality in this file but with small differences. This diff --git a/lib/tool_shed/structured_app.py b/lib/tool_shed/structured_app.py index 8a38e008fbbe..e384682223fd 100644 --- a/lib/tool_shed/structured_app.py +++ b/lib/tool_shed/structured_app.py @@ -1,8 +1,13 @@ -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Union, +) from galaxy.structured_app import BasicSharedApp if TYPE_CHECKING: + from galaxy.tool_shed.galaxy_install.client import InstallationTarget + from galaxy.tools.data import ToolDataTableManager from tool_shed.managers.model_cache import ModelCache from tool_shed.repository_registry import RegistryInterface from tool_shed.repository_types.registry import Registry as RepositoryTypesRegistry @@ -18,3 +23,7 @@ class ToolShedApp(BasicSharedApp): hgweb_config_manager: "HgWebConfigManager" security_agent: "CommunityRBACAgent" model_cache: "ModelCache" + tool_data_tables: "ToolDataTableManager" + + +RequiredAppT = Union[ToolShedApp, "InstallationTarget"] diff --git a/lib/tool_shed/util/commit_util.py b/lib/tool_shed/util/commit_util.py index ab3864f8e98d..afbaa3f88039 100644 --- a/lib/tool_shed/util/commit_util.py +++ b/lib/tool_shed/util/commit_util.py @@ -18,9 +18,9 @@ from sqlalchemy.sql.expression import null import tool_shed.repository_types.util as rt_util +from galaxy.tool_shed.tools.data_table_manager import BaseShedToolDataTableManager from galaxy.util import checkers from galaxy.util.path import safe_relpath -from tool_shed.tools.data_table_manager import ShedToolDataTableManager from tool_shed.util import ( basic_util, hg_util, @@ -216,7 +216,7 @@ def handle_directory_changes( # Handle the special case where a tool_data_table_conf.xml.sample file is being uploaded # by parsing the file and adding new entries to the in-memory app.tool_data_tables # dictionary. - stdtm = ShedToolDataTableManager(app) + stdtm = BaseShedToolDataTableManager(app) error, message = stdtm.handle_sample_tool_data_table_conf_file(filename_in_archive, persist=False) if error: return ( From 0accf3ff6c33cd5b9f70ce941b68c4409f07b5e9 Mon Sep 17 00:00:00 2001 From: Nicola Soranzo Date: Fri, 14 Feb 2025 17:12:15 +0000 Subject: [PATCH 5/5] Improve type annotation of ``SharedModelMapping`` --- lib/galaxy/model/base.py | 23 +++++++++++++++++++---- lib/galaxy/model/mapping.py | 3 +++ lib/tool_shed/webapp/model/mapping.py | 6 ++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/lib/galaxy/model/base.py b/lib/galaxy/model/base.py index 06448c76ee5e..88557be25be5 100644 --- a/lib/galaxy/model/base.py +++ b/lib/galaxy/model/base.py @@ -16,6 +16,7 @@ Dict, List, Type, + TYPE_CHECKING, Union, ) @@ -28,6 +29,20 @@ from galaxy.util.bunch import Bunch +if TYPE_CHECKING: + from galaxy.model import ( + APIKeys as GalaxyAPIKeys, + GalaxySession as GalaxyGalaxySession, + PasswordResetToken as GalaxyPasswordResetToken, + User as GalaxyUser, + ) + from tool_shed.webapp.model import ( + APIKeys as ToolShedAPIKeys, + GalaxySession as ToolShedGalaxySession, + PasswordResetToken as ToolShedPasswordResetToken, + User as ToolShedUser, + ) + log = logging.getLogger(__name__) # Create a ContextVar with mutable state, this allows sync tasks in the context @@ -128,10 +143,10 @@ class SharedModelMapping(ModelMapping): a way to do app.model. for common code shared by the tool shed and Galaxy. """ - User: Type - GalaxySession: Type - APIKeys: Type - PasswordResetToken: Type + User: Union[Type["GalaxyUser"], Type["ToolShedUser"]] + GalaxySession: Union[Type["GalaxyGalaxySession"], Type["ToolShedGalaxySession"]] + APIKeys: Union[Type["GalaxyAPIKeys"], Type["ToolShedAPIKeys"]] + PasswordResetToken: Union[Type["GalaxyPasswordResetToken"], Type["ToolShedPasswordResetToken"]] def versioned_objects(iter): diff --git a/lib/galaxy/model/mapping.py b/lib/galaxy/model/mapping.py index c2a059c0582d..a57fdc9be849 100644 --- a/lib/galaxy/model/mapping.py +++ b/lib/galaxy/model/mapping.py @@ -2,6 +2,7 @@ from threading import local from typing import ( Optional, + Type, TYPE_CHECKING, ) @@ -17,6 +18,7 @@ from galaxy.model.triggers.update_audit_table import install as install_timestamp_triggers if TYPE_CHECKING: + from galaxy.model import User as GalaxyUser from galaxy.objectstore import BaseObjectStore log = logging.getLogger(__name__) @@ -25,6 +27,7 @@ class GalaxyModelMapping(SharedModelMapping): + User: Type["GalaxyUser"] security_agent: GalaxyRBACAgent thread_local_log: Optional[local] diff --git a/lib/tool_shed/webapp/model/mapping.py b/lib/tool_shed/webapp/model/mapping.py index 94ff0317f3c4..8f60f908b65b 100644 --- a/lib/tool_shed/webapp/model/mapping.py +++ b/lib/tool_shed/webapp/model/mapping.py @@ -8,6 +8,8 @@ Any, Dict, Optional, + Type, + TYPE_CHECKING, ) import tool_shed.webapp.model @@ -17,12 +19,16 @@ from tool_shed.webapp.model import mapper_registry from tool_shed.webapp.security import CommunityRBACAgent +if TYPE_CHECKING: + from tool_shed.webapp.model import User as ToolShedUser + log = logging.getLogger(__name__) metadata = mapper_registry.metadata class ToolShedModelMapping(SharedModelMapping): + User: Type["ToolShedUser"] security_agent: CommunityRBACAgent shed_counter: shed_statistics.ShedCounter create_tables: bool