Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotation improvements #19642

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/galaxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,8 @@ def __init__(self, configure_logging=True, use_converters=True, use_display_appl
self._register_singleton(ShortTermStorageMonitor, short_term_storage_manager) # type: ignore[type-abstract]

# Tag handler
self.tag_handler = self._register_singleton(GalaxyTagHandler)
tag_handler = GalaxyTagHandler(self.model.context)
self.tag_handler = self._register_singleton(GalaxyTagHandler, tag_handler)
self.user_manager = self._register_singleton(UserManager)
self._register_singleton(GalaxySessionManager)
self.hda_manager = self._register_singleton(HDAManager)
Expand Down
1 change: 0 additions & 1 deletion lib/galaxy/app_unittest_utils/tools_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ def __init__(self, model_objects=None):
self.flushed = False
self.model_objects = model_objects or defaultdict(dict)
self.created_objects = []
self.current = self

def expunge_all(self):
self.expunged_all = True
Expand Down
10 changes: 5 additions & 5 deletions lib/galaxy/celery/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
)
from sqlalchemy.dialects.postgresql import insert as ps_insert
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import scoped_session

from galaxy.model import CeleryUserRateLimit
from galaxy.model.scoped_session import galaxy_scoped_session


class GalaxyTaskBeforeStart:
Expand Down Expand Up @@ -45,7 +45,7 @@ class GalaxyTaskBeforeStartUserRateLimit(GalaxyTaskBeforeStart):
def __init__(
self,
tasks_per_user_per_sec: float,
ga_scoped_session: galaxy_scoped_session,
ga_scoped_session: scoped_session,
):
try:
self.task_exec_countdown_secs = 1 / tasks_per_user_per_sec
Expand All @@ -68,7 +68,7 @@ def __call__(self, task: Task, task_id, args, kwargs):

@abstractmethod
def calculate_task_start_time(
self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime
self, user_id: int, sa_session: scoped_session, task_interval_secs: float, now: datetime.datetime
) -> datetime.datetime:
return now

Expand All @@ -80,7 +80,7 @@ class GalaxyTaskBeforeStartUserRateLimitPostgres(GalaxyTaskBeforeStartUserRateLi
"""

def calculate_task_start_time(
self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime
self, user_id: int, sa_session: scoped_session, task_interval_secs: float, now: datetime.datetime
) -> datetime.datetime:
update_stmt = (
update(CeleryUserRateLimit)
Expand Down Expand Up @@ -125,7 +125,7 @@ class GalaxyTaskBeforeStartUserRateLimitStandard(GalaxyTaskBeforeStartUserRateLi
)

def calculate_task_start_time(
self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime
self, user_id: int, sa_session: scoped_session, task_interval_secs: float, now: datetime.datetime
) -> datetime.datetime:
last_scheduled_time = None
last_scheduled_time = sa_session.scalars(self._select_stmt, {"userid": user_id}).first()
Expand Down
5 changes: 2 additions & 3 deletions lib/galaxy/managers/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
User,
)
from galaxy.model.base import ModelMapping
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.model.tags import GalaxyTagHandlerSession
from galaxy.schema.tasks import RequestUser
from galaxy.security.idencoding import IdEncodingHelper
Expand Down Expand Up @@ -155,10 +154,10 @@ def log_event(self, message, tool_id=None, **kwargs):
self.sa_session.commit()

@property
def sa_session(self) -> galaxy_scoped_session:
def sa_session(self):
"""Provide access to Galaxy's SQLAlchemy session.

:rtype: galaxy.model.scoped_session.galaxy_scoped_session
:rtype: sqlalchemy.orm.scoped_session
"""
return self.app.model.session

Expand Down
10 changes: 7 additions & 3 deletions lib/galaxy/managers/histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)

Expand Down Expand Up @@ -85,6 +86,9 @@
RawTextTerm,
)

if TYPE_CHECKING:
from sqlalchemy.engine import ScalarResult

log = logging.getLogger(__name__)

INDEX_SEARCH_FILTERS = {
Expand All @@ -95,7 +99,7 @@
}


class HistoryManager(sharable.SharableModelManager, deletable.PurgableManagerMixin, SortableManager):
class HistoryManager(sharable.SharableModelManager[model.History], deletable.PurgableManagerMixin, SortableManager):
model_class = model.History
foreign_key_name = "history"
user_share_model = model.HistoryUserShareAssociation
Expand All @@ -120,7 +124,7 @@ def __init__(

def index_query(
self, trans: ProvidesUserContext, payload: HistoryIndexQueryPayload, include_total_count: bool = False
) -> Tuple[List[model.History], int]:
) -> Tuple["ScalarResult[model.History]", Union[int, None]]:
show_deleted = False
show_own = payload.show_own
show_published = payload.show_published
Expand Down Expand Up @@ -234,7 +238,7 @@ def p_tag_filter(term_text: str, quoted: bool):
stmt = stmt.limit(payload.limit)
if payload.offset is not None:
stmt = stmt.offset(payload.offset)
return trans.sa_session.scalars(stmt), total_matches # type:ignore[return-value]
return trans.sa_session.scalars(stmt), total_matches

# .... sharable
# overriding to handle anonymous users' current histories in both cases
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/history_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def state_counts(self, history):
statement: Select = (
select(sql.column("state"), func.count()).select_from(contents_subquery).group_by(sql.column("state"))
)
counts = self.app.model.session().execute(statement).fetchall()
return dict(counts)
counts = self.app.model.session.execute(statement).fetchall()
return dict(counts) # type:ignore[arg-type]

def active_counts(self, history):
"""
Expand Down
8 changes: 6 additions & 2 deletions lib/galaxy/managers/job_connections.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING

from sqlalchemy import (
literal,
union,
Expand All @@ -9,13 +11,15 @@

from galaxy import model
from galaxy.managers.base import get_class
from galaxy.model.scoped_session import galaxy_scoped_session

if TYPE_CHECKING:
from sqlalchemy.orm import scoped_session


class JobConnectionsManager:
"""Get connections graph of inputs and outputs for given item"""

def __init__(self, sa_session: galaxy_scoped_session):
def __init__(self, sa_session: "scoped_session"):
self.sa_session = sa_session

def get_connections_graph(self, id: int, src: str):
Expand Down
1 change: 1 addition & 0 deletions lib/galaxy/managers/model_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def import_model_store(self, request: ImportModelStoreTaskRequest):

def _build_user_context(self, user_id: int):
user = self._user_manager.by_id(user_id)
assert user is not None
user_context = ModelStoreUserContext(self._app, user)
return user_context

Expand Down
18 changes: 11 additions & 7 deletions lib/galaxy/managers/pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
Callable,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)

import sqlalchemy
from sqlalchemy import (
desc,
false,
Expand Down Expand Up @@ -76,6 +77,9 @@
RawTextTerm,
)

if TYPE_CHECKING:
from sqlalchemy.engine import ScalarResult

log = logging.getLogger(__name__)

# Copied from https://github.com/kurtmckee/feedparser
Expand Down Expand Up @@ -121,7 +125,7 @@
}


class PageManager(sharable.SharableModelManager, UsesAnnotations):
class PageManager(sharable.SharableModelManager[model.Page], UsesAnnotations):
"""Provides operations for managing a Page."""

model_class = model.Page
Expand All @@ -139,7 +143,7 @@ def __init__(self, app: MinimalManagerApp):

def index_query(
self, trans: ProvidesUserContext, payload: PageIndexQueryPayload, include_total_count: bool = False
) -> Tuple[sqlalchemy.engine.Result, int]:
) -> Tuple["ScalarResult[model.Page]", Union[int, None]]:
show_deleted = payload.deleted
show_own = payload.show_own
show_published = payload.show_published
Expand Down Expand Up @@ -242,7 +246,7 @@ def p_tag_filter(term_text: str, quoted: bool):
stmt = stmt.limit(payload.limit)
if payload.offset is not None:
stmt = stmt.offset(payload.offset)
return trans.sa_session.scalars(stmt), total_matches # type:ignore[return-value]
return trans.sa_session.scalars(stmt), total_matches

def create_page(self, trans, payload: CreatePagePayload):
user = trans.get_user()
Expand All @@ -269,7 +273,7 @@ def create_page(self, trans, payload: CreatePagePayload):
content = self.rewrite_content_for_import(trans, content, content_format)

# Create the new stored page
page = trans.app.model.Page()
page = model.Page()
page.title = payload.title
page.slug = payload.slug
if (page_annotation := payload.annotation) is not None:
Expand All @@ -278,7 +282,7 @@ def create_page(self, trans, payload: CreatePagePayload):

page.user = user
# And the first (empty) page revision
page_revision = trans.app.model.PageRevision()
page_revision = model.PageRevision()
page_revision.title = payload.title
page_revision.page = page
page.latest_revision = page_revision
Expand Down Expand Up @@ -310,7 +314,7 @@ def save_new_revision(self, trans, page, payload):
content_format = page.latest_revision.content_format
content = self.rewrite_content_for_import(trans, content, content_format=content_format)

page_revision = trans.app.model.PageRevision()
page_revision = model.PageRevision()
page_revision.title = title
page_revision.page = page
page.latest_revision = page_revision
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def purge(self, trans: ProvidesUserContext, role: model.Role) -> model.Role:
raise RequestParameterInvalidException(f"Role '{role.name}' has not been deleted, so it cannot be purged.")
# Delete UserRoleAssociations
for ura in role.users:
user = sa_session.get(trans.app.model.User, ura.user_id)
user = sa_session.get(model.User, ura.user_id)
assert user
# Delete DefaultUserPermissions for associated users
for dup in user.default_permissions:
Expand Down
6 changes: 3 additions & 3 deletions lib/galaxy/managers/secured.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ class AccessibleManagerMixin:
def by_id(self, id: int): ...

# don't want to override by_id since consumers will also want to fetch w/o any security checks
def is_accessible(self, item, user: model.User, **kwargs: Any) -> bool:
def is_accessible(self, item, user: Optional[model.User], **kwargs: Any) -> bool:
"""
Return True if the item accessible to user.
"""
# override in subclasses
raise exceptions.NotImplemented("Abstract interface Method")

def get_accessible(self, id: int, user: model.User, **kwargs: Any):
def get_accessible(self, id: int, user: Optional[model.User], **kwargs: Any):
"""
Return the item with the given id if it's accessible to user,
otherwise raise an error.
Expand All @@ -50,7 +50,7 @@ def get_accessible(self, id: int, user: model.User, **kwargs: Any):
item = self.by_id(id)
return self.error_unless_accessible(item, user, **kwargs)

def error_unless_accessible(self, item: "Query", user, **kwargs):
def error_unless_accessible(self, item: "Query", user: Optional[model.User], **kwargs):
"""
Raise an error if the item is NOT accessible to user, otherwise return the item.

Expand Down
5 changes: 4 additions & 1 deletion lib/galaxy/managers/sharable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Optional,
Set,
Type,
TypeVar,
)

from slugify import slugify
Expand Down Expand Up @@ -54,10 +55,12 @@
from galaxy.util.hash_util import md5_hash_str

log = logging.getLogger(__name__)
# Only model classes that have `users_shared_with` field
U = TypeVar("U", model.History, model.Page, model.StoredWorkflow, model.Visualization)


class SharableModelManager(
base.ModelManager,
base.ModelManager[U],
secured.OwnableManagerMixin,
secured.AccessibleManagerMixin,
annotatable.AnnotatableManagerMixin,
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def _error_on_duplicate_email(self, email: str) -> None:
if self.by_email(email) is not None:
raise exceptions.Conflict("Email must be unique", email=email)

def by_id(self, user_id: int) -> model.User:
def by_id(self, user_id: int) -> Optional[model.User]:
return self.app.model.session.get(self.model_class, user_id)

# ---- filters
Expand Down
12 changes: 8 additions & 4 deletions lib/galaxy/managers/visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import logging
from typing import (
Dict,
List,
Tuple,
TYPE_CHECKING,
Union,
)

from sqlalchemy import (
Expand Down Expand Up @@ -44,6 +45,9 @@
RawTextTerm,
)

if TYPE_CHECKING:
from sqlalchemy.engine import ScalarResult

log = logging.getLogger(__name__)


Expand All @@ -59,7 +63,7 @@
}


class VisualizationManager(sharable.SharableModelManager):
class VisualizationManager(sharable.SharableModelManager[model.Visualization]):
"""
Handle operations outside and between visualizations and other models.
"""
Expand All @@ -76,7 +80,7 @@ class VisualizationManager(sharable.SharableModelManager):

def index_query(
self, trans: ProvidesUserContext, payload: VisualizationIndexQueryPayload, include_total_count: bool = False
) -> Tuple[List[model.Visualization], int]:
) -> Tuple["ScalarResult[model.Visualization]", Union[int, None]]:
show_deleted = payload.deleted
show_own = payload.show_own
show_published = payload.show_published
Expand Down Expand Up @@ -171,7 +175,7 @@ def p_tag_filter(term_text: str, quoted: bool):
stmt = stmt.limit(payload.limit)
if payload.offset is not None:
stmt = stmt.offset(payload.offset)
return trans.sa_session.scalars(stmt), total_matches # type:ignore[return-value]
return trans.sa_session.scalars(stmt), total_matches


class VisualizationSerializer(sharable.SharableModelSerializer):
Expand Down
Loading
Loading