From 449b046dfa3c7cbfc9e12eab292b3b09aa5b3920 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 21 Sep 2023 07:24:25 +0200 Subject: [PATCH 001/103] WIP RBAC --- src/zenml/cli/__init__.py | 1 - src/zenml/cli/role.py | 313 -------- src/zenml/cli/user_management.py | 156 +--- src/zenml/client.py | 553 -------------- src/zenml/models/__init__.py | 53 -- src/zenml/models/pipeline_models.py | 8 +- src/zenml/models/role_models.py | 83 -- src/zenml/models/server_models.py | 3 + src/zenml/models/team_models.py | 118 --- .../models/team_role_assignment_models.py | 95 --- src/zenml/models/user_models.py | 12 +- .../models/user_role_assignment_models.py | 94 --- src/zenml/zen_server/auth.py | 79 +- src/zenml/zen_server/rbac_interface.py | 62 ++ .../zen_server/routers/auth_endpoints.py | 17 +- .../zen_server/routers/pipelines_endpoints.py | 32 +- .../routers/role_assignments_endpoints.py | 14 +- .../zen_server/routers/roles_endpoints.py | 16 +- .../zen_server/routers/server_endpoints.py | 8 +- .../zen_server/routers/stacks_endpoints.py | 21 +- .../team_role_assignments_endpoints.py | 238 +++--- src/zenml/zen_server/utils.py | 25 + src/zenml/zen_server/zen_server_api.py | 15 +- src/zenml/zen_stores/base_zen_store.py | 95 +-- src/zenml/zen_stores/rest_zen_store.py | 319 -------- src/zenml/zen_stores/schemas/__init__.py | 16 - src/zenml/zen_stores/schemas/role_schemas.py | 262 ------- src/zenml/zen_stores/schemas/team_schemas.py | 111 --- src/zenml/zen_stores/schemas/user_schemas.py | 11 - .../zen_stores/schemas/workspace_schemas.py | 10 - src/zenml/zen_stores/sql_zen_store.py | 714 ++---------------- src/zenml/zen_stores/zen_store_interface.py | 271 ------- 32 files changed, 412 insertions(+), 3413 deletions(-) delete mode 100644 src/zenml/cli/role.py delete mode 100644 src/zenml/models/role_models.py delete mode 100644 src/zenml/models/team_models.py delete mode 100644 src/zenml/models/team_role_assignment_models.py delete mode 100644 src/zenml/models/user_role_assignment_models.py create mode 100644 src/zenml/zen_server/rbac_interface.py delete mode 100644 src/zenml/zen_stores/schemas/role_schemas.py delete mode 100644 src/zenml/zen_stores/schemas/team_schemas.py diff --git a/src/zenml/cli/__init__.py b/src/zenml/cli/__init__.py index fe741647760..6e3102eace3 100644 --- a/src/zenml/cli/__init__.py +++ b/src/zenml/cli/__init__.py @@ -1548,7 +1548,6 @@ def my_pipeline(...): from zenml.cli.model import * # noqa from zenml.cli.pipeline import * # noqa from zenml.cli.workspace import * # noqa -from zenml.cli.role import * # noqa from zenml.cli.secret import * # noqa from zenml.cli.server import * # noqa from zenml.cli.stack import * # noqa diff --git a/src/zenml/cli/role.py b/src/zenml/cli/role.py deleted file mode 100644 index 4f4439c5305..00000000000 --- a/src/zenml/cli/role.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Functionality to administer roles of the ZenML CLI and server.""" - -from typing import Any, List, Optional, Tuple - -import click - -from zenml.cli import utils as cli_utils -from zenml.cli.cli import TagGroup, cli -from zenml.cli.utils import list_options -from zenml.client import Client -from zenml.console import console -from zenml.enums import CliCategories, PermissionType -from zenml.exceptions import EntityExistsError, IllegalOperationError -from zenml.models import RoleFilterModel, UserRoleAssignmentFilterModel - - -@cli.group(cls=TagGroup, tag=CliCategories.IDENTITY_AND_SECURITY) -def role() -> None: - """Commands for role management.""" - - -@role.command("list") -@list_options(RoleFilterModel) -def list_roles(**kwargs: Any) -> None: - """List all roles that fulfill the filter requirements. - - Args: - **kwargs: Keyword arguments to filter the list of roles. - """ - client = Client() - with console.status("Listing roles...\n"): - roles = client.list_roles(**kwargs) - if not roles.items: - cli_utils.declare("No roles found for the given filters.") - return - - cli_utils.print_pydantic_models( - roles, - exclude_columns=["created", "updated"], - ) - - -@role.command("create", help="Create a new role.") -@click.argument("role_name", type=str, required=True) -@click.option( - "--permissions", - "-p", - "permissions", - type=click.Choice(choices=PermissionType.values()), - multiple=True, - help="Name of permission to attach to this role.", -) -def create_role(role_name: str, permissions: List[str]) -> None: - """Create a new role. - - Args: - role_name: Name of the role to create. - permissions: Permissions to assign - """ - try: - Client().create_role(name=role_name, permissions_list=permissions) - cli_utils.declare(f"Created role '{role_name}'.") - except EntityExistsError as e: - cli_utils.error(str(e)) - - -@role.command("update", help="Update an existing role.") -@click.argument("role_name", type=str, required=True) -@click.option( - "--name", "-n", "new_name", type=str, required=False, help="New role name." -) -@click.option( - "--remove-permission", - "-r", - type=click.Choice(choices=PermissionType.values()), - multiple=True, - help="Name of permission to remove.", -) -@click.option( - "--add-permission", - "-a", - type=click.Choice(choices=PermissionType.values()), - multiple=True, - help="Name of permission to add.", -) -def update_role( - role_name: str, - new_name: Optional[str] = None, - remove_permission: Optional[List[str]] = None, - add_permission: Optional[List[str]] = None, -) -> None: - """Update an existing role. - - Args: - role_name: The name of the role. - new_name: The new name of the role. - remove_permission: Name of permission to remove from role - add_permission: Name of permission to add to role - """ - try: - Client().update_role( - name_id_or_prefix=role_name, - new_name=new_name, - remove_permission=remove_permission, - add_permission=add_permission, - ) - except ( - EntityExistsError, - KeyError, - RuntimeError, - IllegalOperationError, - ) as err: - cli_utils.error(str(err)) - cli_utils.declare(f"Updated role '{role_name}'.") - - -@role.command("delete", help="Delete a role.") -@click.argument("role_name_or_id", type=str, required=True) -def delete_role(role_name_or_id: str) -> None: - """Delete a role. - - Args: - role_name_or_id: Name or ID of the role to delete. - """ - try: - Client().delete_role(name_id_or_prefix=role_name_or_id) - except (KeyError, IllegalOperationError) as err: - cli_utils.error(str(err)) - cli_utils.declare(f"Deleted role '{role_name_or_id}'.") - - -@role.command("assign", help="Assign a role.") -@click.argument("role_name_or_id", type=str, required=True) -@click.option("--workspace", "workspace_name_or_id", type=str, required=False) -@click.option( - "--user", "user_names_or_ids", type=str, required=False, multiple=True -) -@click.option( - "--team", "team_names_or_ids", type=str, required=False, multiple=True -) -def assign_role( - role_name_or_id: str, - user_names_or_ids: Tuple[str], - team_names_or_ids: Tuple[str], - workspace_name_or_id: Optional[str] = None, -) -> None: - """Assign a role. - - Args: - role_name_or_id: Name or IDs of the role to assign. - user_names_or_ids : Names or IDs of users to assign the role to. - team_names_or_ids: Names or IDs of teams to assign the role to. - workspace_name_or_id: Name or IDs of a workspace in which to assign the - role. If this is not provided, the role will be assigned globally. - """ - # Assign the role to users - for user_name_or_id in user_names_or_ids: - try: - Client().create_user_role_assignment( - role_name_or_id=role_name_or_id, - user_name_or_id=user_name_or_id, - workspace_name_or_id=workspace_name_or_id, - ) - except KeyError as err: - cli_utils.error(str(err)) - except EntityExistsError as err: - cli_utils.error(str(err)) - else: - cli_utils.declare( - f"Assigned role '{role_name_or_id}' to user '{user_name_or_id}'." - ) - - # Assign the role to teams - for team_name_or_id in team_names_or_ids: - try: - Client().create_team_role_assignment( - role_name_or_id=role_name_or_id, - team_name_or_id=team_name_or_id, - workspace_name_or_id=workspace_name_or_id, - ) - except KeyError as err: - cli_utils.error(str(err)) - except EntityExistsError as err: - cli_utils.warning(str(err)) - else: - cli_utils.declare( - f"Assigned role '{role_name_or_id}' to team '{team_name_or_id}'." - ) - - -@role.command("revoke", help="Revoke a role.") -@click.argument("role_name_or_id", type=str, required=True) -@click.option("--workspace", "workspace_name_or_id", type=str, required=False) -@click.option( - "--user", "user_names_or_ids", type=str, required=False, multiple=True -) -@click.option( - "--team", "team_names_or_ids", type=str, required=False, multiple=True -) -def revoke_role( - role_name_or_id: str, - user_names_or_ids: Tuple[str], - team_names_or_ids: Tuple[str], - workspace_name_or_id: Optional[str] = None, -) -> None: - """Revoke a role. - - Args: - role_name_or_id: Name or IDs of the role to revoke. - user_names_or_ids: Names or IDs of users from which to revoke the role. - team_names_or_ids: Names or IDs of teams from which to revoke the role. - workspace_name_or_id: Name or IDs of a workspace in which to revoke the - role. If this is not provided, the role will be revoked globally. - """ - client = Client() - - role = client.get_role(name_id_or_prefix=role_name_or_id) - workspace_id = None - if workspace_name_or_id: - workspace_id = client.get_workspace( - name_id_or_prefix=workspace_name_or_id - ).id - - # Revoke the role from users - for user_name_or_id in user_names_or_ids: - user = client.get_user(name_id_or_prefix=user_name_or_id) - try: - user_role_assignments = client.list_user_role_assignment( - role_id=role.id, - user_id=user.id, - workspace_id=workspace_id, - ) - for user_role_assignment in user_role_assignments.items: - Client().delete_user_role_assignment(user_role_assignment.id) - except KeyError as err: - cli_utils.warning(str(err)) - else: - cli_utils.declare( - f"Revoked role '{role_name_or_id}' from user " - f"'{user_name_or_id}'." - ) - - # Revoke the role from teams - for team_name_or_id in team_names_or_ids: - team = client.get_team(name_id_or_prefix=team_name_or_id) - try: - team_role_assignments = client.list_team_role_assignment( - role_id=role.id, - team_id=team.id, - workspace_id=workspace_id, - ) - for team_role_assignment in team_role_assignments.items: - Client().delete_user_role_assignment(team_role_assignment.id) - except KeyError as err: - cli_utils.warning(str(err)) - else: - cli_utils.declare( - f"Revoked role '{role_name_or_id}' from team " - f"'{team_name_or_id}'." - ) - - -@role.group() -def assignment() -> None: - """Commands for role management.""" - - -@assignment.command("list") -@list_options(UserRoleAssignmentFilterModel) -def list_role_assignments(**kwargs: Any) -> None: - """List all user role assignments that fulfill the filter requirements. - - Args: - kwargs: Keyword arguments. - """ - client = Client() - with console.status("Listing roles...\n"): - role_assignments = client.list_user_role_assignment(**kwargs) - if not role_assignments.items: - cli_utils.declare( - "No roles assignments found for the given filters." - ) - return - cli_utils.print_pydantic_models( - role_assignments, exclude_columns=["id", "created", "updated"] - ) - - -@cli.group(cls=TagGroup, tag=CliCategories.IDENTITY_AND_SECURITY) -def permission() -> None: - """Commands for role management.""" - - -@permission.command("list") -def list_permissions() -> None: - """List all role assignments.""" - permissions = [i.value for i in PermissionType] - cli_utils.declare( - f"The following permissions are currently supported: " f"{permissions}" - ) diff --git a/src/zenml/cli/user_management.py b/src/zenml/cli/user_management.py index 73b250fb82c..cf4c9c8f4ab 100644 --- a/src/zenml/cli/user_management.py +++ b/src/zenml/cli/user_management.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Functionality to administer users of the ZenML CLI and server.""" -from typing import Any, List, Optional +from typing import Any, Optional import click @@ -24,7 +24,7 @@ from zenml.console import console from zenml.enums import CliCategories, StoreType from zenml.exceptions import EntityExistsError, IllegalOperationError -from zenml.models import TeamFilterModel, UserFilterModel +from zenml.models import UserFilterModel @cli.group(cls=TagGroup, tag=CliCategories.IDENTITY_AND_SECURITY) @@ -120,18 +120,8 @@ def list_users(ctx: click.Context, **kwargs: Any) -> None: required=False, type=str, ) -@click.option( - "--role", - "-r", - "initial_role", - help="Give the user an initial role.", - required=False, - type=str, - default="admin", -) def create_user( user_name: str, - initial_role: str = "admin", password: Optional[str] = None, ) -> None: """Create a new user. @@ -139,7 +129,6 @@ def create_user( Args: user_name: The name of the user to create. password: The password of the user to create. - initial_role: Give the user an initial role """ client = Client() if not password: @@ -157,9 +146,7 @@ def create_user( ) try: - new_user = client.create_user( - name=user_name, password=password, initial_role=initial_role - ) + new_user = client.create_user(name=user_name, password=password) cli_utils.declare(f"Created user '{new_user.name}'.") except EntityExistsError as err: @@ -242,140 +229,3 @@ def delete_user(user_name_or_id: str) -> None: except (KeyError, IllegalOperationError) as err: cli_utils.error(str(err)) cli_utils.declare(f"Deleted user '{user_name_or_id}'.") - - -@cli.group(cls=TagGroup, tag=CliCategories.IDENTITY_AND_SECURITY) -def team() -> None: - """Commands for team management.""" - - -@team.command("list") -@list_options(TeamFilterModel) -def list_teams(**kwargs: Any) -> None: - """List all teams that fulfill the filter requirements. - - Args: - kwargs: The filter options. - """ - client = Client() - - with console.status("Listing teams...\n"): - teams = client.list_teams(**kwargs) - - if not teams: - cli_utils.declare("No teams found with the given filter.") - return - - cli_utils.print_pydantic_models( - teams, - exclude_columns=["id", "created", "updated"], - ) - - -@team.command("describe", help="List all users in a team.") -@click.argument("team_name_or_id", type=str, required=True) -def describe_team(team_name_or_id: str) -> None: - """List all users in a team. - - Args: - team_name_or_id: The name or ID of the team to describe. - """ - try: - team_ = Client().get_team(name_id_or_prefix=team_name_or_id) - except KeyError as err: - cli_utils.error(str(err)) - else: - cli_utils.print_pydantic_models( - [team_], - exclude_columns=[ - "created", - "updated", - ], - ) - - -@team.command("create", help="Create a new team.") -@click.argument("team_name", type=str, required=True) -@click.option( - "--user", - "-u", - "users", - type=str, - multiple=True, - help="Name of users to add to this team.", -) -def create_team(team_name: str, users: Optional[List[str]] = None) -> None: - """Create a new team. - - Args: - team_name: Name of the team to create. - users: Users to add to this team - """ - try: - Client().create_team(name=team_name, users=users) - except EntityExistsError as err: - cli_utils.error(str(err)) - cli_utils.declare(f"Created team '{team_name}'.") - - -@team.command("update", help="Update an existing team.") -@click.argument("team_name", type=str, required=True) -@click.option( - "--name", "-n", "new_name", type=str, required=False, help="New team name." -) -@click.option( - "--remove-user", - "-r", - "remove_users", - type=str, - multiple=True, - help="Name or Id of users to remove.", -) -@click.option( - "--add-user", - "-a", - "add_users", - type=str, - multiple=True, - help="Name or Id of users to add.", -) -def update_team( - team_name: str, - new_name: Optional[str] = None, - remove_users: Optional[List[str]] = None, - add_users: Optional[List[str]] = None, -) -> None: - """Update an existing team. - - Args: - team_name: The name of the team. - new_name: The new name of the team. - remove_users: Users to remove from the team - add_users: Users to add to the team. - """ - try: - team_ = Client().update_team( - name_id_or_prefix=team_name, - new_name=new_name, - remove_users=remove_users, - add_users=add_users, - ) - except (EntityExistsError, KeyError) as err: - cli_utils.error(str(err)) - else: - cli_utils.declare(f"Updated team '{team_.name}'.") - - -@team.command("delete", help="Delete a team.") -@click.argument("team_name_or_id", type=str, required=True) -def delete_team(team_name_or_id: str) -> None: - """Delete a team. - - Args: - team_name_or_id: The name or ID of the team to delete. - """ - try: - Client().delete_team(team_name_or_id) - except KeyError as err: - cli_utils.error(str(err)) - cli_utils.declare(f"Deleted team '{team_name_or_id}'.") diff --git a/src/zenml/client.py b/src/zenml/client.py index 858f8d44fb3..25a69c873f1 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -26,7 +26,6 @@ List, Mapping, Optional, - Set, Tuple, Type, TypeVar, @@ -52,7 +51,6 @@ from zenml.enums import ( ArtifactType, LogicalOperators, - PermissionType, SecretScope, StackComponentType, StoreType, @@ -88,10 +86,6 @@ PipelineResponseModel, PipelineRunFilterModel, PipelineRunResponseModel, - RoleFilterModel, - RoleRequestModel, - RoleResponseModel, - RoleUpdateModel, RunMetadataRequestModel, RunMetadataResponseModel, SecretFilterModel, @@ -110,19 +104,9 @@ StackUpdateModel, StepRunFilterModel, StepRunResponseModel, - TeamFilterModel, - TeamRequestModel, - TeamResponseModel, - TeamRoleAssignmentFilterModel, - TeamRoleAssignmentRequestModel, - TeamRoleAssignmentResponseModel, - TeamUpdateModel, UserFilterModel, UserRequestModel, UserResponseModel, - UserRoleAssignmentFilterModel, - UserRoleAssignmentRequestModel, - UserRoleAssignmentResponseModel, UserUpdateModel, WorkspaceFilterModel, WorkspaceRequestModel, @@ -642,14 +626,12 @@ def active_user(self) -> "UserResponseModel": def create_user( self, name: str, - initial_role: Optional[str] = None, password: Optional[str] = None, ) -> UserResponseModel: """Create a new user. Args: name: The name of the user. - initial_role: Optionally, an initial role to assign to the user. password: The password of the user. If not provided, the user will be created with empty password. @@ -662,13 +644,6 @@ def create_user( ) created_user = self.zen_store.create_user(user=user) - if initial_role: - self.create_user_role_assignment( - role_name_or_id=initial_role, - user_name_or_id=created_user.id, - workspace_name_or_id=None, - ) - return created_user def get_user( @@ -794,534 +769,6 @@ def update_user( user_id=user.id, user_update=user_update ) - # ---- # - # TEAM # - # ---- # - - def get_team( - self, - name_id_or_prefix: Union[str, UUID], - allow_name_prefix_match: bool = True, - ) -> TeamResponseModel: - """Gets a team. - - Args: - name_id_or_prefix: The name or ID of the team. - allow_name_prefix_match: If True, allow matching by name prefix. - - Returns: - The Team - """ - return self._get_entity_by_id_or_name_or_prefix( - get_method=self.zen_store.get_team, - list_method=self.list_teams, - name_id_or_prefix=name_id_or_prefix, - allow_name_prefix_match=allow_name_prefix_match, - ) - - def list_teams( - self, - sort_by: str = "created", - page: int = PAGINATION_STARTING_PAGE, - size: int = PAGE_SIZE_DEFAULT, - logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - ) -> Page[TeamResponseModel]: - """List all teams. - - Args: - sort_by: The column to sort by - page: The page of items - size: The maximum size of all pages - logical_operator: Which logical operator to use [and, or] - id: Use the id of teams to filter by. - created: Use to filter by time of creation - updated: Use the last updated date for filtering - name: Use the team name for filtering - - Returns: - The Team - """ - return self.zen_store.list_teams( - TeamFilterModel( - sort_by=sort_by, - page=page, - size=size, - logical_operator=logical_operator, - id=id, - created=created, - updated=updated, - name=name, - ) - ) - - def create_team( - self, name: str, users: Optional[List[str]] = None - ) -> TeamResponseModel: - """Create a team. - - Args: - name: Name of the team. - users: Users to add to the team. - - Returns: - The created team. - """ - user_list: List[UUID] = [] - if users: - user_list.extend( - self.get_user(name_id_or_prefix=user_name_or_id).id - for user_name_or_id in users - ) - - team = TeamRequestModel(name=name, users=user_list) - - return self.zen_store.create_team(team=team) - - def delete_team(self, name_id_or_prefix: str) -> None: - """Delete a team. - - Args: - name_id_or_prefix: The name or ID of the team to delete. - """ - team = self.get_team(name_id_or_prefix, allow_name_prefix_match=False) - self.zen_store.delete_team(team_name_or_id=team.id) - - def update_team( - self, - name_id_or_prefix: str, - new_name: Optional[str] = None, - remove_users: Optional[List[str]] = None, - add_users: Optional[List[str]] = None, - ) -> TeamResponseModel: - """Update a team. - - Args: - name_id_or_prefix: The name or ID of the team to update. - new_name: The new name of the team. - remove_users: The users to remove from the team. - add_users: The users to add to the team. - - Returns: - The updated team. - - Raises: - RuntimeError: If the same user is in both `remove_users` and - `add_users`. - """ - team = self.get_team(name_id_or_prefix, allow_name_prefix_match=False) - - team_update = TeamUpdateModel(name=new_name or team.name) - if remove_users is not None and add_users is not None: - if union_add_rm := set(remove_users) & set(add_users): - raise RuntimeError( - f"The `remove_user` and `add_user` " - f"options both contain the same value(s): " - f"`{union_add_rm}`. Please rerun command and make sure " - f"that the same user does not show up for " - f"`remove_user` and `add_user`." - ) - - # Only if permissions are being added or removed will they need to be - # set for the update model - team_users = ( - [u.id for u in team.users] if remove_users or add_users else [] - ) - if remove_users: - for rm_p in remove_users: - user = self.get_user(rm_p) - try: - team_users.remove(user.id) - except KeyError: - logger.warning( - f"Role {remove_users} was already not " - f"part of the '{team.name}' Team." - ) - if add_users: - team_users.extend(self.get_user(add_u).id for add_u in add_users) - if team_users: - team_update.users = team_users - - return self.zen_store.update_team( - team_id=team.id, team_update=team_update - ) - - # ----- # - # ROLES # - # ----- # - - def get_role( - self, - name_id_or_prefix: Union[str, UUID], - allow_name_prefix_match: bool = True, - ) -> RoleResponseModel: - """Gets a role. - - Args: - name_id_or_prefix: The name or ID of the role. - allow_name_prefix_match: If True, allow matching by name prefix. - - Returns: - The fetched role. - """ - return self._get_entity_by_id_or_name_or_prefix( - get_method=self.zen_store.get_role, - list_method=self.list_roles, - name_id_or_prefix=name_id_or_prefix, - allow_name_prefix_match=allow_name_prefix_match, - ) - - def list_roles( - self, - sort_by: str = "created", - page: int = PAGINATION_STARTING_PAGE, - size: int = PAGE_SIZE_DEFAULT, - logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - ) -> Page[RoleResponseModel]: - """List all roles. - - Args: - sort_by: The column to sort by - page: The page of items - size: The maximum size of all pages - logical_operator: The logical operator to use between column filters - id: Use the id of roles to filter by. - created: Use to filter by time of creation - updated: Use the last updated date for filtering - name: Use the role name for filtering - - Returns: - The Role - """ - return self.zen_store.list_roles( - RoleFilterModel( - sort_by=sort_by, - page=page, - size=size, - logical_operator=logical_operator, - id=id, - created=created, - updated=updated, - name=name, - ) - ) - - def create_role( - self, name: str, permissions_list: List[str] - ) -> RoleResponseModel: - """Creates a role. - - Args: - name: The name for the new role. - permissions_list: The permissions to attach to this role. - - Returns: - The newly created role. - """ - permissions: Set[PermissionType] = { - PermissionType(permission) - for permission in permissions_list - if permission in PermissionType.values() - } - new_role = RoleRequestModel(name=name, permissions=permissions) - return self.zen_store.create_role(new_role) - - def update_role( - self, - name_id_or_prefix: str, - new_name: Optional[str] = None, - remove_permission: Optional[List[str]] = None, - add_permission: Optional[List[str]] = None, - ) -> RoleResponseModel: - """Updates a role. - - Args: - name_id_or_prefix: The name or ID of the role. - new_name: The new name for the role - remove_permission: Permissions to remove from this role. - add_permission: Permissions to add to this role. - - Returns: - The updated role. - - Raises: - RuntimeError: If the same permission is in both the - `remove_permission` and `add_permission` lists. - """ - role = self.get_role( - name_id_or_prefix=name_id_or_prefix, allow_name_prefix_match=False - ) - - role_update = RoleUpdateModel(name=new_name or role.name) # type: ignore[call-arg] - - if remove_permission is not None and add_permission is not None: - if union_add_rm := set(remove_permission) & set(add_permission): - raise RuntimeError( - f"The `remove_permission` and `add_permission` " - f"options both contain the same value(s): " - f"`{union_add_rm}`. Please rerun command and make sure " - f"that the same role does not show up for " - f"`remove_permission` and `add_permission`." - ) - - # Only if permissions are being added or removed will they need to be - # set for the update model - if remove_permission or add_permission: - role_permissions = role.permissions - - if remove_permission: - for rm_p in remove_permission: - if rm_p in PermissionType: - try: - role_permissions.remove(PermissionType(rm_p)) - except KeyError: - logger.warning( - f"Role {remove_permission} was already not " - f"part of the {role} Role." - ) - if add_permission: - for add_p in add_permission: - if add_p in PermissionType.values(): - # Set won't throw an error if the item was already in it - role_permissions.add(PermissionType(add_p)) - - if role_permissions is not None: - role_update.permissions = set(role_permissions) - - return Client().zen_store.update_role( - role_id=role.id, role_update=role_update - ) - - def delete_role(self, name_id_or_prefix: str) -> None: - """Deletes a role. - - Args: - name_id_or_prefix: The name or ID of the role. - """ - role = self.get_role( - name_id_or_prefix=name_id_or_prefix, allow_name_prefix_match=False - ) - self.zen_store.delete_role(role_name_or_id=role.id) - - # --------------------- # - # USER ROLE ASSIGNMENTS # - # --------------------- # - - def get_user_role_assignment( - self, role_assignment_id: UUID - ) -> UserRoleAssignmentResponseModel: - """Get a role assignment. - - Args: - role_assignment_id: The id of the role assignments - - Returns: - The role assignment. - """ - return self.zen_store.get_user_role_assignment( - user_role_assignment_id=role_assignment_id - ) - - def create_user_role_assignment( - self, - role_name_or_id: Union[str, UUID], - user_name_or_id: Union[str, UUID], - workspace_name_or_id: Optional[Union[str, UUID]] = None, - ) -> UserRoleAssignmentResponseModel: - """Create a role assignment. - - Args: - role_name_or_id: Name or ID of the role to assign. - user_name_or_id: Name or ID of the user or team to assign - the role to. - workspace_name_or_id: workspace scope within which to assign the role. - - Returns: - The newly created role assignment. - """ - role = self.get_role(name_id_or_prefix=role_name_or_id) - workspace = None - if workspace_name_or_id: - workspace = self.get_workspace( - name_id_or_prefix=workspace_name_or_id - ) - user = self.get_user(name_id_or_prefix=user_name_or_id) - role_assignment = UserRoleAssignmentRequestModel( - role=role.id, - user=user.id, - workspace=workspace, - ) - return self.zen_store.create_user_role_assignment( - user_role_assignment=role_assignment - ) - - def delete_user_role_assignment(self, role_assignment_id: UUID) -> None: - """Delete a role assignment. - - Args: - role_assignment_id: The id of the role assignments - - """ - self.zen_store.delete_user_role_assignment(role_assignment_id) - - def list_user_role_assignment( - self, - sort_by: str = "created", - page: int = PAGINATION_STARTING_PAGE, - size: int = PAGE_SIZE_DEFAULT, - logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - workspace_id: Optional[Union[str, UUID]] = None, - user_id: Optional[Union[str, UUID]] = None, - role_id: Optional[Union[str, UUID]] = None, - ) -> Page[UserRoleAssignmentResponseModel]: - """List all user role assignments. - - Args: - sort_by: The column to sort by - page: The page of items - size: The maximum size of all pages - logical_operator: Which logical operator to use [and, or] - id: Use the id of the user role assignment to filter by. - created: Use to filter by time of creation - updated: Use the last updated date for filtering - workspace_id: The id of the workspace to filter by. - user_id: The id of the user to filter by. - role_id: The id of the role to filter by. - - Returns: - The Team - """ - return self.zen_store.list_user_role_assignments( - UserRoleAssignmentFilterModel( - sort_by=sort_by, - page=page, - size=size, - logical_operator=logical_operator, - id=id, - created=created, - updated=updated, - workspace_id=workspace_id, - user_id=user_id, - role_id=role_id, - ) - ) - - # --------------------- # - # TEAM ROLE ASSIGNMENTS # - # --------------------- # - - def get_team_role_assignment( - self, team_role_assignment_id: UUID - ) -> TeamRoleAssignmentResponseModel: - """Get a role assignment. - - Args: - team_role_assignment_id: The id of the role assignments - - Returns: - The role assignment. - """ - return self.zen_store.get_team_role_assignment( - team_role_assignment_id=team_role_assignment_id - ) - - def create_team_role_assignment( - self, - role_name_or_id: Union[str, UUID], - team_name_or_id: Union[str, UUID], - workspace_name_or_id: Optional[Union[str, UUID]] = None, - ) -> TeamRoleAssignmentResponseModel: - """Create a role assignment. - - Args: - role_name_or_id: Name or ID of the role to assign. - team_name_or_id: Name or ID of the team to assign - the role to. - workspace_name_or_id: workspace scope within which to assign the role. - - Returns: - The newly created role assignment. - """ - role = self.get_role(name_id_or_prefix=role_name_or_id) - workspace = None - if workspace_name_or_id: - workspace = self.get_workspace( - name_id_or_prefix=workspace_name_or_id - ) - team = self.get_team(name_id_or_prefix=team_name_or_id) - role_assignment = TeamRoleAssignmentRequestModel( - role=role.id, - team=team.id, - workspace=workspace, - ) - return self.zen_store.create_team_role_assignment( - team_role_assignment=role_assignment - ) - - def delete_team_role_assignment(self, role_assignment_id: UUID) -> None: - """Delete a role assignment. - - Args: - role_assignment_id: The id of the role assignments - - """ - self.zen_store.delete_team_role_assignment(role_assignment_id) - - def list_team_role_assignment( - self, - sort_by: str = "created", - page: int = PAGINATION_STARTING_PAGE, - size: int = PAGE_SIZE_DEFAULT, - logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - workspace_id: Optional[Union[str, UUID]] = None, - team_id: Optional[Union[str, UUID]] = None, - role_id: Optional[Union[str, UUID]] = None, - ) -> Page[TeamRoleAssignmentResponseModel]: - """List all team role assignments. - - Args: - sort_by: The column to sort by - page: The page of items - size: The maximum size of all pages - logical_operator: Which logical operator to use [and, or] - id: Use the id of the team role assignment to filter by. - created: Use to filter by time of creation - updated: Use the last updated date for filtering - workspace_id: The id of the workspace to filter by. - team_id: The id of the team to filter by. - role_id: The id of the role to filter by. - - Returns: - The Team - """ - return self.zen_store.list_team_role_assignments( - TeamRoleAssignmentFilterModel( - sort_by=sort_by, - page=page, - size=size, - logical_operator=logical_operator, - id=id, - created=created, - updated=updated, - workspace_id=workspace_id, - team_id=team_id, - role_id=role_id, - ) - ) - # --------- # # WORKSPACE # # --------- # diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index b7e39cc21a9..fa94ba7e21b 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -61,12 +61,6 @@ WorkspaceResponseModel, WorkspaceUpdateModel, ) -from zenml.models.role_models import ( - RoleFilterModel, - RoleRequestModel, - RoleResponseModel, - RoleUpdateModel, -) from zenml.models.run_metadata_models import ( RunMetadataFilterModel, RunMetadataRequestModel, @@ -109,17 +103,6 @@ StepRunResponseModel, StepRunUpdateModel, ) -from zenml.models.team_models import ( - TeamFilterModel, - TeamRequestModel, - TeamResponseModel, - TeamUpdateModel, -) -from zenml.models.team_role_assignment_models import ( - TeamRoleAssignmentFilterModel, - TeamRoleAssignmentRequestModel, - TeamRoleAssignmentResponseModel, -) from zenml.models.user_models import ( UserAuthModel, UserFilterModel, @@ -127,11 +110,6 @@ UserResponseModel, UserUpdateModel, ) -from zenml.models.user_role_assignment_models import ( - UserRoleAssignmentFilterModel, - UserRoleAssignmentRequestModel, - UserRoleAssignmentResponseModel, -) from zenml.models.code_repository_models import ( CodeRepositoryFilterModel, CodeRepositoryRequestModel, @@ -161,23 +139,6 @@ WorkspaceResponseModel=WorkspaceResponseModel, ) -UserResponseModel.update_forward_refs(TeamResponseModel=TeamResponseModel) - -TeamResponseModel.update_forward_refs(UserResponseModel=UserResponseModel) - -UserRoleAssignmentResponseModel.update_forward_refs( - RoleResponseModel=RoleResponseModel, - TeamResponseModel=TeamResponseModel, - UserResponseModel=UserResponseModel, - WorkspaceResponseModel=WorkspaceResponseModel, -) - -TeamRoleAssignmentResponseModel.update_forward_refs( - RoleResponseModel=RoleResponseModel, - TeamResponseModel=TeamResponseModel, - UserResponseModel=UserResponseModel, - WorkspaceResponseModel=WorkspaceResponseModel, -) PipelineResponseModel.update_forward_refs( UserResponseModel=UserResponseModel, @@ -294,16 +255,6 @@ "WorkspaceResponseModel", "WorkspaceUpdateModel", "WorkspaceFilterModel", - "UserRoleAssignmentRequestModel", - "UserRoleAssignmentResponseModel", - "UserRoleAssignmentFilterModel", - "TeamRoleAssignmentRequestModel", - "TeamRoleAssignmentResponseModel", - "TeamRoleAssignmentFilterModel", - "RoleRequestModel", - "RoleResponseModel", - "RoleUpdateModel", - "RoleFilterModel", "RunMetadataFilterModel", "RunMetadataRequestModel", "RunMetadataResponseModel", @@ -334,10 +285,6 @@ "StepRunResponseModel", "StepRunUpdateModel", "StepRunFilterModel", - "TeamRequestModel", - "TeamResponseModel", - "TeamUpdateModel", - "TeamFilterModel", "UserRequestModel", "UserResponseModel", "UserUpdateModel", diff --git a/src/zenml/models/pipeline_models.py b/src/zenml/models/pipeline_models.py index aa49019bb5c..d0fb8d6c676 100644 --- a/src/zenml/models/pipeline_models.py +++ b/src/zenml/models/pipeline_models.py @@ -23,7 +23,6 @@ from zenml.models.base_models import ( WorkspaceScopedRequestModel, WorkspaceScopedResponseModel, - update_model, ) from zenml.models.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH from zenml.models.filter_models import WorkspaceScopedFilterModel @@ -187,6 +186,9 @@ class PipelineRequestModel(PipelineBaseModel, WorkspaceScopedRequestModel): # ------ # -@update_model -class PipelineUpdateModel(PipelineRequestModel): +class PipelineUpdateModel(BaseModel): """Pipeline update model.""" + + # None of the pipeline attributes should be updated ATM, but this model + # and the corresponding endpoint might be useful once we allow adding + # tags/descriptions and updating them from the dashboard diff --git a/src/zenml/models/role_models.py b/src/zenml/models/role_models.py deleted file mode 100644 index 02289c94b34..00000000000 --- a/src/zenml/models/role_models.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Models representing roles that can be assigned to users or teams.""" - -from typing import Optional, Set - -from pydantic import BaseModel, Field - -from zenml.enums import PermissionType -from zenml.models.base_models import ( - BaseRequestModel, - BaseResponseModel, - update_model, -) -from zenml.models.constants import STR_FIELD_MAX_LENGTH -from zenml.models.filter_models import BaseFilterModel - -# ---- # -# BASE # -# ---- # - - -class RoleBaseModel(BaseModel): - """Base model for roles.""" - - name: str = Field( - title="The unique name of the role.", - max_length=STR_FIELD_MAX_LENGTH, - ) - permissions: Set[PermissionType] - - -# -------- # -# RESPONSE # -# -------- # - - -class RoleResponseModel(RoleBaseModel, BaseResponseModel): - """Response model for roles.""" - - -# ------ # -# FILTER # -# ------ # - - -class RoleFilterModel(BaseFilterModel): - """Model to enable advanced filtering of all Users.""" - - name: Optional[str] = Field( - default=None, - description="Name of the role", - ) - - -# ------- # -# REQUEST # -# ------- # - - -class RoleRequestModel(RoleBaseModel, BaseRequestModel): - """Request model for roles.""" - - -# ------ # -# UPDATE # -# ------ # - - -@update_model -class RoleUpdateModel(RoleRequestModel): - """Update model for roles.""" diff --git a/src/zenml/models/server_models.py b/src/zenml/models/server_models.py index 8f31876adb9..dcfae841fcd 100644 --- a/src/zenml/models/server_models.py +++ b/src/zenml/models/server_models.py @@ -54,6 +54,9 @@ class ServerModel(BaseModel): title="The ZenML version that the server is running.", ) + zenml_cloud: bool = Field( + False, title="Flag to indicate whether this is a ZenML cloud server." + ) debug: bool = Field( False, title="Flag to indicate whether ZenML is running on debug mode." ) diff --git a/src/zenml/models/team_models.py b/src/zenml/models/team_models.py deleted file mode 100644 index 276ba781b44..00000000000 --- a/src/zenml/models/team_models.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Models representing teams.""" - -from typing import TYPE_CHECKING, List, Optional -from uuid import UUID - -from pydantic import BaseModel, Field - -from zenml.models.base_models import ( - BaseRequestModel, - BaseResponseModel, - update_model, -) -from zenml.models.constants import STR_FIELD_MAX_LENGTH -from zenml.models.filter_models import BaseFilterModel - -if TYPE_CHECKING: - from zenml.models.user_models import UserResponseModel - - -# ---- # -# BASE # -# ---- # - - -class TeamBaseModel(BaseModel): - """Base model for teams.""" - - name: str = Field( - title="The unique name of the team.", - max_length=STR_FIELD_MAX_LENGTH, - ) - - -# -------- # -# RESPONSE # -# -------- # - - -class TeamResponseModel(TeamBaseModel, BaseResponseModel): - """Response model for teams.""" - - users: List["UserResponseModel"] = Field( - title="The list of users within this team." - ) - - @property - def user_ids(self) -> List[UUID]: - """Returns a list of user IDs that are part of this team. - - Returns: - A list of user IDs. - """ - if self.users: - return [u.id for u in self.users] - else: - return [] - - @property - def user_names(self) -> List[str]: - """Returns a list names of users that are part of this team. - - Returns: - A list of names of users. - """ - if self.users: - return [u.name for u in self.users] - else: - return [] - - -# ------ # -# FILTER # -# ------ # - - -class TeamFilterModel(BaseFilterModel): - """Model to enable advanced filtering of all Teams.""" - - name: Optional[str] = Field( - default=None, - description="Name of the team", - ) - - -# ------- # -# REQUEST # -# ------- # - - -class TeamRequestModel(TeamBaseModel, BaseRequestModel): - """Request model for teams.""" - - users: Optional[List[UUID]] = Field( - default=None, title="The list of users within this team." - ) - - -# ------ # -# UPDATE # -# ------ # - - -@update_model -class TeamUpdateModel(TeamRequestModel): - """Update model for teams.""" diff --git a/src/zenml/models/team_role_assignment_models.py b/src/zenml/models/team_role_assignment_models.py deleted file mode 100644 index 74afd814da3..00000000000 --- a/src/zenml/models/team_role_assignment_models.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Models representing role assignments.""" - -from typing import TYPE_CHECKING, Optional, Union -from uuid import UUID - -from pydantic import BaseModel, Field - -from zenml.models import BaseFilterModel -from zenml.models.base_models import BaseRequestModel, BaseResponseModel - -if TYPE_CHECKING: - from zenml.models.role_models import RoleResponseModel - from zenml.models.team_models import TeamResponseModel - from zenml.models.workspace_models import WorkspaceResponseModel - -# ---- # -# BASE # -# ---- # - - -class TeamRoleAssignmentBaseModel(BaseModel): - """Base model for role assignments.""" - - -# -------- # -# RESPONSE # -# -------- # - - -class TeamRoleAssignmentResponseModel( - TeamRoleAssignmentBaseModel, BaseResponseModel -): - """Response model for role assignments with all entities hydrated.""" - - workspace: Optional["WorkspaceResponseModel"] = Field( - title="The workspace scope of this role assignment.", default=None - ) - team: Optional["TeamResponseModel"] = Field( - title="The team the role is assigned to.", default=None - ) - role: Optional["RoleResponseModel"] = Field( - title="The assigned role.", default=None - ) - - -# ------ # -# FILTER # -# ------ # - - -class TeamRoleAssignmentFilterModel(BaseFilterModel): - """Model to enable advanced filtering of all Role Assignments.""" - - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, description="Workspace of the RoleAssignment" - ) - team_id: Optional[Union[UUID, str]] = Field( - default=None, description="Team in the RoleAssignment" - ) - role_id: Optional[Union[UUID, str]] = Field( - default=None, description="Role in the RoleAssignment" - ) - - -# ------- # -# REQUEST # -# ------- # - - -class TeamRoleAssignmentRequestModel( - TeamRoleAssignmentBaseModel, BaseRequestModel -): - """Request model for role assignments using UUIDs for all entities.""" - - workspace: Optional[UUID] = Field( - default=None, title="The workspace that the role is limited to." - ) - team: UUID = Field( - title="The user that the role is assigned to.", - ) - - role: UUID = Field(title="The role.") diff --git a/src/zenml/models/user_models.py b/src/zenml/models/user_models.py index f6fee6a009a..5caa8e56508 100644 --- a/src/zenml/models/user_models.py +++ b/src/zenml/models/user_models.py @@ -33,7 +33,7 @@ from zenml.config.global_config import GlobalConfiguration from zenml.exceptions import AuthorizationException from zenml.logger import get_logger -from zenml.models import BaseFilterModel, RoleResponseModel +from zenml.models import BaseFilterModel from zenml.models.base_models import ( BaseRequestModel, BaseResponseModel, @@ -45,7 +45,6 @@ if TYPE_CHECKING: from passlib.context import CryptContext # type: ignore[import] - from zenml.models.team_models import TeamResponseModel logger = get_logger(__name__) @@ -223,12 +222,6 @@ class UserResponseModel(UserBaseModel, BaseResponseModel): activation_token: Optional[str] = Field( default=None, max_length=STR_FIELD_MAX_LENGTH ) - teams: Optional[List["TeamResponseModel"]] = Field( - default=None, title="The list of teams for this user." - ) - roles: Optional[List["RoleResponseModel"]] = Field( - default=None, title="The list of roles for this user." - ) email: Optional[str] = Field( default="", title="The email address associated with the account.", @@ -264,9 +257,6 @@ class UserAuthModel(UserBaseModel, BaseResponseModel): activation_token: Optional[SecretStr] = Field(default=None, exclude=True) password: Optional[SecretStr] = Field(default=None, exclude=True) - teams: Optional[List["TeamResponseModel"]] = Field( - default=None, title="The list of teams for this user." - ) def generate_access_token(self, permissions: List[str]) -> str: """Generates an access token. diff --git a/src/zenml/models/user_role_assignment_models.py b/src/zenml/models/user_role_assignment_models.py deleted file mode 100644 index 3400a3da27d..00000000000 --- a/src/zenml/models/user_role_assignment_models.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Models representing role assignments.""" - -from typing import TYPE_CHECKING, Optional, Union -from uuid import UUID - -from pydantic import BaseModel, Field - -from zenml.models import BaseFilterModel -from zenml.models.base_models import BaseRequestModel, BaseResponseModel - -if TYPE_CHECKING: - from zenml.models.role_models import RoleResponseModel - from zenml.models.user_models import UserResponseModel - from zenml.models.workspace_models import WorkspaceResponseModel - -# ---- # -# BASE # -# ---- # - - -class UserRoleAssignmentBaseModel(BaseModel): - """Base model for role assignments.""" - - -# -------- # -# RESPONSE # -# -------- # - - -class UserRoleAssignmentResponseModel( - UserRoleAssignmentBaseModel, BaseResponseModel -): - """Response model for role assignments with all entities hydrated.""" - - workspace: Optional["WorkspaceResponseModel"] = Field( - title="The workspace scope of this role assignment.", default=None - ) - user: Optional["UserResponseModel"] = Field( - title="The user the role is assigned to.", default=None - ) - role: Optional["RoleResponseModel"] = Field( - title="The assigned role.", default=None - ) - - -# ------ # -# FILTER # -# ------ # - - -class UserRoleAssignmentFilterModel(BaseFilterModel): - """Model to enable advanced filtering of all Role Assignments.""" - - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, description="Workspace of the RoleAssignment" - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, description="User in the RoleAssignment" - ) - role_id: Optional[Union[UUID, str]] = Field( - default=None, description="Role in the RoleAssignment" - ) - - -# ------- # -# REQUEST # -# ------- # - - -class UserRoleAssignmentRequestModel( - UserRoleAssignmentBaseModel, BaseRequestModel -): - """Request model for role assignments using UUIDs for all entities.""" - - workspace: Optional[UUID] = Field( - default=None, - title="The workspace that the role is limited to.", - ) - user: UUID = Field(title="The user that the role is assigned to.") - - role: UUID = Field(title="The role.") diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index f3c43ca31ac..bbc41054b5c 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -34,7 +34,8 @@ from zenml.models import UserResponseModel from zenml.models.user_models import JWTToken, JWTTokenType, UserAuthModel from zenml.utils.enum_utils import StrEnum -from zenml.zen_server.utils import ROOT_URL_PATH, zen_store +from zenml.zen_server.rbac_interface import Resource +from zenml.zen_server.utils import ROOT_URL_PATH, rbac, zen_store from zenml.zen_stores.base_zen_store import DEFAULT_USERNAME logger = get_logger(__name__) @@ -223,28 +224,20 @@ def oauth2_password_bearer_authentication( HTTPException: If the JWT token could not be authorized. """ if security_scopes.scopes: - authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' + pass else: - authenticate_value = "Bearer" + pass auth_context = authenticate_credentials(access_token=token) try: - access_token = JWTToken.decode( - token_type=JWTTokenType.ACCESS_TOKEN, token=token - ) + JWTToken.decode(token_type=JWTTokenType.ACCESS_TOKEN, token=token) except AuthorizationException: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials", headers={"WWW-Authenticate": "Bearer"}, ) - for scope in security_scopes.scopes: - if scope not in access_token.permissions: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not enough permissions", - headers={"WWW-Authenticate": authenticate_value}, - ) + if auth_context is None: # We have to return an additional WWW-Authenticate header here with the # value Bearer to be compliant with the OAuth2 spec. @@ -253,6 +246,7 @@ def oauth2_password_bearer_authentication( detail="Invalid authentication credentials", headers={"WWW-Authenticate": "Bearer"}, ) + return auth_context @@ -300,3 +294,62 @@ def authentication_provider() -> Callable[..., AuthContext]: authorize = authentication_provider() + + +def verify_permissions( + resource_type: str, + action: str, + resource_id: Optional[UUID] = None, +) -> None: + """Verifies if a user has permissions to perform an action on a resource. + + Args: + resource: The resource type the user wants to perform the action on. + action: The action the user wants to perform. + resource_id: ID of the resource the user wants to perform the action on. + + Raises: + HTTPException: If the user is not allowed to perform the action. + """ + if "ZENML_CLOUD" not in os.environ: + return + + user_id = get_auth_context().user.id + resource = Resource(type=resource_type, id=resource_id) + + if not rbac().has_permission( + user=user_id, resource=resource, action=action + ): + raise HTTPException(status_code=403) + + +def get_allowed_resource_ids( + resource_type: str, + action: str, +) -> Optional[List[UUID]]: + """Get all resource IDs of a resource type that a user can access. + + Args: + resource_type: The resource type. + action: The action the user wants to perform on the resource. + + Returns: + A list of resource IDs or `None` if the user has full access to the + all instances of the resource. + """ + if "ZENML_CLOUD" not in os.environ: + # Full access in any case + return None + + user_id = get_auth_context().user.id + ( + has_full_resource_access, + allowed_ids, + ) = rbac().list_allowed_resource_ids( + user=user_id, resource=Resource(type=resource_type), action=action + ) + + if has_full_resource_access: + return None + + return [UUID(id) for id in allowed_ids] diff --git a/src/zenml/zen_server/rbac_interface.py b/src/zenml/zen_server/rbac_interface.py new file mode 100644 index 00000000000..bda6b3721b5 --- /dev/null +++ b/src/zenml/zen_server/rbac_interface.py @@ -0,0 +1,62 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple +from uuid import UUID + +from pydantic import BaseModel + +from zenml.enums import StrEnum + + +class Action(StrEnum): + CREATE = "create" + READ = "read" + UPDATE = "update" + DELETE = "delete" + + +class ResourceType(StrEnum): + STACK = "stack" + COMPONENT = "component" + PIPELINE = "pipeline" + + +class Resource(BaseModel): + type: str + id: Optional[UUID] = None + + +class RBACInterface(ABC): + @abstractmethod + def has_permission( + self, user: UUID, resource: Resource, action: str + ) -> bool: + """Checks if a user has permission to perform an action on a resource. + + Args: + user: ID of the user which wants to access a resource. + resource: The resource the user wants to access. + action: The action that the user wants to perform on the resource. + + Returns: + Whether the user has permission to perform an action on a resource. + """ + + @abstractmethod + def list_allowed_resource_ids( + self, user: UUID, resource: Resource, action: str + ) -> Tuple[bool, List[str]]: + """Lists all resource IDs of a resource type that a user can access. + + Args: + user: ID of the user which wants to access a resource. + resource: The resource the user wants to access. + action: The action that the user wants to perform on the resource. + + Returns: + A tuple (full_resource_access, resource_ids). + `full_resource_access` will be `True` if the user can perform the + given action on any instance of the given resource type, `False` + otherwise. If `full_resource_access` is `False`, `resource_ids` + will contain the list of instance IDs that the user can perform + the action on. + """ diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index 181d165ca84..9169bf17127 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -19,10 +19,9 @@ from fastapi.param_functions import Form from zenml.constants import API, LOGIN, VERSION_1 -from zenml.models import UserRoleAssignmentFilterModel from zenml.zen_server.auth import authenticate_credentials from zenml.zen_server.exceptions import error_response -from zenml.zen_server.utils import zen_store +from zenml.zen_server.utils import rbac router = APIRouter( prefix=API + VERSION_1, @@ -99,20 +98,8 @@ def token( detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}, ) - role_assignments = zen_store().list_user_role_assignments( - user_role_assignment_filter_model=UserRoleAssignmentFilterModel( - user_id=auth_context.user.id - ) - ) - # TODO: This needs to happen at the sql level now - permissions = set().union( - *[ - zen_store().get_role(ra.role.id).permissions - for ra in role_assignments.items - if ra.role is not None - ] - ) + permissions = rbac().get_all_permissions(user=auth_context.user.id) access_token = auth_context.user.generate_access_token( permissions=[p.value for p in permissions] diff --git a/src/zenml/zen_server/routers/pipelines_endpoints.py b/src/zenml/zen_server/routers/pipelines_endpoints.py index 8e5418e37be..ca9d3e0cb54 100644 --- a/src/zenml/zen_server/routers/pipelines_endpoints.py +++ b/src/zenml/zen_server/routers/pipelines_endpoints.py @@ -27,7 +27,7 @@ PipelineUpdateModel, ) from zenml.models.page_model import Page -from zenml.zen_server.auth import AuthContext, authorize +from zenml.zen_server.auth import AuthContext, authorize, verify_permissions from zenml.zen_server.exceptions import error_response from zenml.zen_server.utils import ( handle_exceptions, @@ -45,14 +45,21 @@ @router.get( "", response_model=Page[PipelineResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, + responses={ + 401: error_response, + 403: error_response, + 404: error_response, + 422: error_response, + }, ) @handle_exceptions def list_pipelines( pipeline_filter_model: PipelineFilterModel = Depends( make_dependable(PipelineFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + auth_context: AuthContext = Security( + authorize, scopes=[PermissionType.READ] + ), ) -> Page[PipelineResponseModel]: """Gets a list of pipelines. @@ -71,12 +78,19 @@ def list_pipelines( @router.get( "/{pipeline_id}", response_model=PipelineResponseModel, - responses={401: error_response, 404: error_response, 422: error_response}, + responses={ + 401: error_response, + 403: error_response, + 404: error_response, + 422: error_response, + }, ) @handle_exceptions def get_pipeline( pipeline_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + auth_context: AuthContext = Security( + authorize, scopes=[PermissionType.READ] + ), ) -> PipelineResponseModel: """Gets a specific pipeline using its unique id. @@ -86,6 +100,14 @@ def get_pipeline( Returns: A specific pipeline object. """ + from zenml.zen_server.rbac_interface import Action, ResourceType + + verify_permissions( + resource_type=ResourceType.PIPELINE, + action=Action.READ, + resource_id=pipeline_id, + ) + return zen_store().get_pipeline(pipeline_id=pipeline_id) diff --git a/src/zenml/zen_server/routers/role_assignments_endpoints.py b/src/zenml/zen_server/routers/role_assignments_endpoints.py index 155d656e1bd..6ae10af7d21 100644 --- a/src/zenml/zen_server/routers/role_assignments_endpoints.py +++ b/src/zenml/zen_server/routers/role_assignments_endpoints.py @@ -26,11 +26,7 @@ from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response -from zenml.zen_server.utils import ( - handle_exceptions, - make_dependable, - zen_store, -) +from zenml.zen_server.utils import handle_exceptions, make_dependable, rbac router = APIRouter( prefix=API + VERSION_1 + USER_ROLE_ASSIGNMENTS, @@ -59,7 +55,7 @@ def list_user_role_assignments( Returns: List of all role assignments. """ - return zen_store().list_user_role_assignments( + return rbac().list_user_role_assignments( user_role_assignment_filter_model=user_role_assignment_filter_model ) @@ -84,7 +80,7 @@ def create_role_assignment( Returns: The created role assignment. """ - return zen_store().create_user_role_assignment( + return rbac().create_user_role_assignment( user_role_assignment=role_assignment ) @@ -107,7 +103,7 @@ def get_role_assignment( Returns: A specific role assignment. """ - return zen_store().get_user_role_assignment( + return rbac().get_user_role_assignment( user_role_assignment_id=role_assignment_id ) @@ -126,6 +122,6 @@ def delete_role_assignment( Args: role_assignment_id: The ID of the role assignment. """ - zen_store().delete_user_role_assignment( + rbac().delete_user_role_assignment( user_role_assignment_id=role_assignment_id ) diff --git a/src/zenml/zen_server/routers/roles_endpoints.py b/src/zenml/zen_server/routers/roles_endpoints.py index c4b44305742..b877645d336 100644 --- a/src/zenml/zen_server/routers/roles_endpoints.py +++ b/src/zenml/zen_server/routers/roles_endpoints.py @@ -28,11 +28,7 @@ from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response -from zenml.zen_server.utils import ( - handle_exceptions, - make_dependable, - zen_store, -) +from zenml.zen_server.utils import handle_exceptions, make_dependable, rbac router = APIRouter( prefix=API + VERSION_1 + ROLES, @@ -62,7 +58,7 @@ def list_roles( Returns: List of all roles. """ - return zen_store().list_roles(role_filter_model=role_filter_model) + return rbac().list_roles(role_filter_model=role_filter_model) @router.post( @@ -85,7 +81,7 @@ def create_role( Returns: The created role. """ - return zen_store().create_role(role=role) + return rbac().create_role(role=role) @router.get( @@ -106,7 +102,7 @@ def get_role( Returns: A specific role. """ - return zen_store().get_role(role_name_or_id=role_name_or_id) + return rbac().get_role(role_name_or_id=role_name_or_id) @router.put( @@ -131,7 +127,7 @@ def update_role( Returns: The created role. """ - return zen_store().update_role(role_id=role_id, role_update=role_update) + return rbac().update_role(role_id=role_id, role_update=role_update) @router.delete( @@ -148,4 +144,4 @@ def delete_role( Args: role_name_or_id: Name or ID of the role. """ - zen_store().delete_role(role_name_or_id=role_name_or_id) + rbac().delete_role(role_name_or_id=role_name_or_id) diff --git a/src/zenml/zen_server/routers/server_endpoints.py b/src/zenml/zen_server/routers/server_endpoints.py index 9aa3719a763..5fe9b3ccab0 100644 --- a/src/zenml/zen_server/routers/server_endpoints.py +++ b/src/zenml/zen_server/routers/server_endpoints.py @@ -13,6 +13,8 @@ # permissions and limitations under the License. """Endpoint definitions for authentication (login).""" +import os + from fastapi import APIRouter import zenml @@ -50,4 +52,8 @@ def server_info() -> ServerModel: Returns: Information about the server. """ - return zen_store().get_store_info() + info = zen_store().get_store_info() + if "ZENML_CLOUD" in os.environ: + info.zenml_cloud = True + + return info diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index b22ae2c8aa3..98b21a6e171 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -21,7 +21,7 @@ from zenml.enums import PermissionType from zenml.models import StackFilterModel, StackResponseModel, StackUpdateModel from zenml.models.page_model import Page -from zenml.zen_server.auth import AuthContext, authorize +from zenml.zen_server.auth import AuthContext, authorize, verify_permissions from zenml.zen_server.exceptions import error_response from zenml.zen_server.utils import ( handle_exceptions, @@ -67,12 +67,19 @@ def list_stacks( @router.get( "/{stack_id}", response_model=StackResponseModel, - responses={401: error_response, 404: error_response, 422: error_response}, + responses={ + 401: error_response, + 403: error_response, + 404: error_response, + 422: error_response, + }, ) @handle_exceptions def get_stack( stack_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + auth_context: AuthContext = Security( + authorize, scopes=[PermissionType.READ] + ), ) -> StackResponseModel: """Returns the requested stack. @@ -82,6 +89,14 @@ def get_stack( Returns: The requested stack. """ + from zenml.zen_server.rbac_interface import Action, ResourceType + + verify_permissions( + resource_type=ResourceType.STACK, + action=Action.READ, + resource_id=stack_id, + ) + return zen_store().get_stack(stack_id) diff --git a/src/zenml/zen_server/routers/team_role_assignments_endpoints.py b/src/zenml/zen_server/routers/team_role_assignments_endpoints.py index 9254601c576..9d3cdf6cd34 100644 --- a/src/zenml/zen_server/routers/team_role_assignments_endpoints.py +++ b/src/zenml/zen_server/routers/team_role_assignments_endpoints.py @@ -11,122 +11,122 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -"""Endpoint definitions for role assignments.""" -from uuid import UUID - -from fastapi import APIRouter, Depends, Security - -from zenml.constants import API, TEAM_ROLE_ASSIGNMENTS, VERSION_1 -from zenml.enums import PermissionType -from zenml.models import ( - TeamRoleAssignmentFilterModel, - TeamRoleAssignmentRequestModel, - TeamRoleAssignmentResponseModel, -) -from zenml.models.page_model import Page -from zenml.zen_server.auth import AuthContext, authorize -from zenml.zen_server.exceptions import error_response -from zenml.zen_server.utils import ( - handle_exceptions, - make_dependable, - zen_store, -) - -router = APIRouter( - prefix=API + VERSION_1 + TEAM_ROLE_ASSIGNMENTS, - tags=["team_role_assignments"], - responses={401: error_response}, -) - - -@router.get( - "", - response_model=Page[TeamRoleAssignmentResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_team_role_assignments( - team_role_assignment_filter_model: TeamRoleAssignmentFilterModel = Depends( - make_dependable(TeamRoleAssignmentFilterModel) - ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[TeamRoleAssignmentResponseModel]: - """Returns a list of all role assignments. - - Args: - team_role_assignment_filter_model: filter models for team role assignments - - - Returns: - List of all role assignments. - """ - return zen_store().list_team_role_assignments( - team_role_assignment_filter_model=team_role_assignment_filter_model - ) - - -@router.post( - "", - response_model=TeamRoleAssignmentResponseModel, - responses={401: error_response, 409: error_response, 422: error_response}, -) -@handle_exceptions -def create_team_role_assignment( - role_assignment: TeamRoleAssignmentRequestModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> TeamRoleAssignmentResponseModel: - """Creates a role assignment. - - # noqa: DAR401 - - Args: - role_assignment: Role assignment to create. - - Returns: - The created role assignment. - """ - return zen_store().create_team_role_assignment( - team_role_assignment=role_assignment - ) - - -@router.get( - "/{role_assignment_id}", - response_model=TeamRoleAssignmentResponseModel, - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def get_team_role_assignment( - role_assignment_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> TeamRoleAssignmentResponseModel: - """Returns a specific role assignment. - - Args: - role_assignment_id: Name or ID of the role assignment. - - Returns: - A specific role assignment. - """ - return zen_store().get_team_role_assignment( - team_role_assignment_id=role_assignment_id - ) - - -@router.delete( - "/{role_assignment_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_team_role_assignment( - role_assignment_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Deletes a specific role. - - Args: - role_assignment_id: The ID of the role assignment. - """ - zen_store().delete_team_role_assignment( - team_role_assignment_id=role_assignment_id - ) +# """Endpoint definitions for role assignments.""" +# from uuid import UUID + +# from fastapi import APIRouter, Depends, Security + +# from zenml.constants import API, TEAM_ROLE_ASSIGNMENTS, VERSION_1 +# from zenml.enums import PermissionType +# from zenml.models import ( +# TeamRoleAssignmentFilterModel, +# TeamRoleAssignmentRequestModel, +# TeamRoleAssignmentResponseModel, +# ) +# from zenml.models.page_model import Page +# from zenml.zen_server.auth import AuthContext, authorize +# from zenml.zen_server.exceptions import error_response +# from zenml.zen_server.utils import ( +# handle_exceptions, +# make_dependable, +# zen_store, +# ) + +# router = APIRouter( +# prefix=API + VERSION_1 + TEAM_ROLE_ASSIGNMENTS, +# tags=["team_role_assignments"], +# responses={401: error_response}, +# ) + + +# @router.get( +# "", +# response_model=Page[TeamRoleAssignmentResponseModel], +# responses={401: error_response, 404: error_response, 422: error_response}, +# ) +# @handle_exceptions +# def list_team_role_assignments( +# team_role_assignment_filter_model: TeamRoleAssignmentFilterModel = Depends( +# make_dependable(TeamRoleAssignmentFilterModel) +# ), +# _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +# ) -> Page[TeamRoleAssignmentResponseModel]: +# """Returns a list of all role assignments. + +# Args: +# team_role_assignment_filter_model: filter models for team role assignments + + +# Returns: +# List of all role assignments. +# """ +# return zen_store().list_team_role_assignments( +# team_role_assignment_filter_model=team_role_assignment_filter_model +# ) + + +# @router.post( +# "", +# response_model=TeamRoleAssignmentResponseModel, +# responses={401: error_response, 409: error_response, 422: error_response}, +# ) +# @handle_exceptions +# def create_team_role_assignment( +# role_assignment: TeamRoleAssignmentRequestModel, +# _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +# ) -> TeamRoleAssignmentResponseModel: +# """Creates a role assignment. + +# # noqa: DAR401 + +# Args: +# role_assignment: Role assignment to create. + +# Returns: +# The created role assignment. +# """ +# return zen_store().create_team_role_assignment( +# team_role_assignment=role_assignment +# ) + + +# @router.get( +# "/{role_assignment_id}", +# response_model=TeamRoleAssignmentResponseModel, +# responses={401: error_response, 404: error_response, 422: error_response}, +# ) +# @handle_exceptions +# def get_team_role_assignment( +# role_assignment_id: UUID, +# _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +# ) -> TeamRoleAssignmentResponseModel: +# """Returns a specific role assignment. + +# Args: +# role_assignment_id: Name or ID of the role assignment. + +# Returns: +# A specific role assignment. +# """ +# return zen_store().get_team_role_assignment( +# team_role_assignment_id=role_assignment_id +# ) + + +# @router.delete( +# "/{role_assignment_id}", +# responses={401: error_response, 404: error_response, 422: error_response}, +# ) +# @handle_exceptions +# def delete_team_role_assignment( +# role_assignment_id: UUID, +# _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +# ) -> None: +# """Deletes a specific role. + +# Args: +# role_assignment_id: The ID of the role assignment. +# """ +# zen_store().delete_team_role_assignment( +# team_role_assignment_id=role_assignment_id +# ) diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index 86fa71a759a..a13ef128edf 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -33,6 +33,7 @@ LocalServerDeploymentConfig, ) from zenml.zen_server.exceptions import http_exception_from_error +from zenml.zen_server.rbac_interface import RBACInterface from zenml.zen_stores.sql_zen_store import SqlZenStore logger = get_logger(__name__) @@ -42,6 +43,7 @@ _zen_store: Optional["SqlZenStore"] = None +_rbac: Optional[RBACInterface] = None def zen_store() -> "SqlZenStore": @@ -59,6 +61,29 @@ def zen_store() -> "SqlZenStore": return _zen_store +def rbac() -> RBACInterface: + """Return the initialized RBAC component. + + Raises: + RuntimeError: If the RBAC component is not initialized. + + Returns: + The RBAC component. + """ + global _rbac + if _rbac is None: + raise RuntimeError("RBAC component not initialized") + return _rbac + + +def initialize_rbac() -> None: + """Initialize the RBAC component.""" + from zenml.zen_server.cloud_rbac import CloudRBAC + + global _rbac + _rbac = CloudRBAC() + + def initialize_zen_store() -> None: """Initialize the ZenML Store. diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 896c5dc6b21..b7f82551d3d 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -38,8 +38,6 @@ pipeline_builds_endpoints, pipeline_deployments_endpoints, pipelines_endpoints, - role_assignments_endpoints, - roles_endpoints, run_metadata_endpoints, runs_endpoints, schedule_endpoints, @@ -49,12 +47,14 @@ stack_components_endpoints, stacks_endpoints, steps_endpoints, - team_role_assignments_endpoints, - teams_endpoints, users_endpoints, workspaces_endpoints, ) -from zenml.zen_server.utils import ROOT_URL_PATH, initialize_zen_store +from zenml.zen_server.utils import ( + ROOT_URL_PATH, + initialize_rbac, + initialize_zen_store, +) DASHBOARD_DIRECTORY = "dashboard" @@ -144,6 +144,7 @@ def initialize() -> None: # IMPORTANT: this needs to be done before the fastapi app starts, to avoid # race conditions initialize_zen_store() + initialize_rbac() app.mount( @@ -198,9 +199,6 @@ def dashboard(request: Request) -> Any: app.include_router(pipelines_endpoints.router) app.include_router(workspaces_endpoints.router) app.include_router(flavors_endpoints.router) -app.include_router(roles_endpoints.router) -app.include_router(role_assignments_endpoints.router) -app.include_router(team_role_assignments_endpoints.router) app.include_router(runs_endpoints.router) app.include_router(run_metadata_endpoints.router) app.include_router(schedule_endpoints.router) @@ -213,7 +211,6 @@ def dashboard(request: Request) -> Any: app.include_router(stack_components_endpoints.types_router) app.include_router(steps_endpoints.router) app.include_router(artifacts_endpoints.router) -app.include_router(teams_endpoints.router) app.include_router(users_endpoints.router) app.include_router(users_endpoints.current_user_router) app.include_router(users_endpoints.activation_router) diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index be1252215f5..09e490c9b87 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -43,7 +43,6 @@ IS_DEBUG_ENV, ) from zenml.enums import ( - PermissionType, SecretsStoreType, StackComponentType, StoreType, @@ -51,19 +50,14 @@ from zenml.logger import get_logger from zenml.models import ( ComponentRequestModel, - RoleFilterModel, - RoleRequestModel, - RoleResponseModel, StackFilterModel, StackRequestModel, StackResponseModel, UserRequestModel, UserResponseModel, - UserRoleAssignmentRequestModel, WorkspaceRequestModel, WorkspaceResponseModel, ) -from zenml.models.page_model import Page from zenml.models.server_models import ( ServerDatabaseType, ServerDeploymentType, @@ -87,8 +81,6 @@ DEFAULT_WORKSPACE_NAME = "default" DEFAULT_STACK_NAME = "default" DEFAULT_STACK_COMPONENT_NAME = "default" -DEFAULT_ADMIN_ROLE = "admin" -DEFAULT_GUEST_ROLE = "guest" @make_proxy_class(SecretsStoreInterface, "_secrets_store") @@ -305,14 +297,6 @@ def _initialize_database(self) -> None: default_workspace = self._default_workspace except KeyError: default_workspace = self._create_default_workspace() - try: - assert self._admin_role - except KeyError: - self._create_admin_role() - try: - assert self._guest_role - except KeyError: - self._create_guest_role() try: default_user = self._default_user except KeyError: @@ -465,6 +449,7 @@ def get_store_info(self) -> ServerModel: ), database_type=ServerDatabaseType.OTHER, debug=IS_DEBUG_ENV, + zenml_cloud=False, secrets_store_type=self.secrets_store.type if self.secrets_store else SecretsStoreType.NONE, @@ -635,62 +620,6 @@ def _get_default_stack( ) return default_stacks.items[0] - # ----- - # Roles - # ----- - @property - def _admin_role(self) -> RoleResponseModel: - """Get the admin role. - - Returns: - The default admin role. - """ - return self.get_role(DEFAULT_ADMIN_ROLE) - - def _create_admin_role(self) -> RoleResponseModel: - """Creates the admin role. - - Returns: - The admin role - """ - logger.info(f"Creating '{DEFAULT_ADMIN_ROLE}' role ...") - return self.create_role( - RoleRequestModel( - name=DEFAULT_ADMIN_ROLE, - permissions={ - PermissionType.READ, - PermissionType.WRITE, - PermissionType.ME, - }, - ) - ) - - @property - def _guest_role(self) -> RoleResponseModel: - """Get the guest role. - - Returns: - The guest role. - """ - return self.get_role(DEFAULT_GUEST_ROLE) - - def _create_guest_role(self) -> RoleResponseModel: - """Creates the guest role. - - Returns: - The guest role - """ - logger.info(f"Creating '{DEFAULT_GUEST_ROLE}' role ...") - return self.create_role( - RoleRequestModel( - name=DEFAULT_GUEST_ROLE, - permissions={ - PermissionType.READ, - PermissionType.ME, - }, - ) - ) - # ----- # Users # ----- @@ -721,7 +650,7 @@ def _default_user(self) -> UserResponseModel: raise KeyError(f"The default user '{user_name}' is not configured") def _create_default_user(self) -> UserResponseModel: - """Creates a default user with the admin role. + """Creates a default user. Returns: The default user. @@ -739,28 +668,8 @@ def _create_default_user(self) -> UserResponseModel: password=user_password, ) ) - self.create_user_role_assignment( - UserRoleAssignmentRequestModel( - role=self._admin_role.id, - user=new_user.id, - workspace=None, - ) - ) return new_user - # ----- - # Roles - # ----- - - @property - def roles(self) -> Page[RoleResponseModel]: - """All existing roles. - - Returns: - A list of all existing roles. - """ - return self.list_roles(RoleFilterModel()) - # -------- # Workspaces # -------- diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 6a4768cc493..875a9625fa5 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -52,7 +52,6 @@ PIPELINE_BUILDS, PIPELINE_DEPLOYMENTS, PIPELINES, - ROLES, RUN_METADATA, RUNS, SCHEDULES, @@ -64,9 +63,6 @@ STACK_COMPONENTS, STACKS, STEPS, - TEAM_ROLE_ASSIGNMENTS, - TEAMS, - USER_ROLE_ASSIGNMENTS, USERS, VERSION_1, WORKSPACES, @@ -108,10 +104,6 @@ PipelineRunResponseModel, PipelineRunUpdateModel, PipelineUpdateModel, - RoleFilterModel, - RoleRequestModel, - RoleResponseModel, - RoleUpdateModel, RunMetadataRequestModel, RunMetadataResponseModel, ScheduleRequestModel, @@ -131,17 +123,9 @@ StepRunRequestModel, StepRunResponseModel, StepRunUpdateModel, - TeamRequestModel, - TeamResponseModel, - TeamRoleAssignmentFilterModel, - TeamRoleAssignmentRequestModel, - TeamRoleAssignmentResponseModel, UserFilterModel, UserRequestModel, UserResponseModel, - UserRoleAssignmentFilterModel, - UserRoleAssignmentRequestModel, - UserRoleAssignmentResponseModel, UserUpdateModel, WorkspaceFilterModel, WorkspaceRequestModel, @@ -157,7 +141,6 @@ from zenml.models.run_metadata_models import RunMetadataFilterModel from zenml.models.schedule_model import ScheduleFilterModel from zenml.models.server_models import ServerModel -from zenml.models.team_models import TeamFilterModel, TeamUpdateModel from zenml.service_connectors.service_connector_registry import ( service_connector_registry, ) @@ -786,308 +769,6 @@ def delete_user(self, user_name_or_id: Union[str, UUID]) -> None: route=USERS, ) - # ----- - # Teams - # ----- - - def create_team(self, team: TeamRequestModel) -> TeamResponseModel: - """Creates a new team. - - Args: - team: The team model to create. - - Returns: - The newly created team. - """ - return self._create_resource( - resource=team, - route=TEAMS, - response_model=TeamResponseModel, - ) - - def get_team(self, team_name_or_id: Union[str, UUID]) -> TeamResponseModel: - """Gets a specific team. - - Args: - team_name_or_id: Name or ID of the team to get. - - Returns: - The requested team. - """ - return self._get_resource( - resource_id=team_name_or_id, - route=TEAMS, - response_model=TeamResponseModel, - ) - - def list_teams( - self, team_filter_model: TeamFilterModel - ) -> Page[TeamResponseModel]: - """List all teams matching the given filter criteria. - - Args: - team_filter_model: All filter parameters including pagination - params. - - Returns: - A list of all teams matching the filter criteria. - """ - return self._list_paginated_resources( - route=TEAMS, - response_model=TeamResponseModel, - filter_model=team_filter_model, - ) - - def update_team( - self, team_id: UUID, team_update: TeamUpdateModel - ) -> TeamResponseModel: - """Update an existing team. - - Args: - team_id: The ID of the team to be updated. - team_update: The update to be applied to the team. - - Returns: - The updated team. - """ - return self._update_resource( - resource_id=team_id, - resource_update=team_update, - route=TEAMS, - response_model=TeamResponseModel, - ) - - def delete_team(self, team_name_or_id: Union[str, UUID]) -> None: - """Deletes a team. - - Args: - team_name_or_id: Name or ID of the team to delete. - """ - self._delete_resource( - resource_id=team_name_or_id, - route=TEAMS, - ) - - # ----- - # Roles - # ----- - - def create_role(self, role: RoleRequestModel) -> RoleResponseModel: - """Creates a new role. - - Args: - role: The role model to create. - - Returns: - The newly created role. - """ - return self._create_resource( - resource=role, - route=ROLES, - response_model=RoleResponseModel, - ) - - def get_role(self, role_name_or_id: Union[str, UUID]) -> RoleResponseModel: - """Gets a specific role. - - Args: - role_name_or_id: Name or ID of the role to get. - - Returns: - The requested role. - """ - return self._get_resource( - resource_id=role_name_or_id, - route=ROLES, - response_model=RoleResponseModel, - ) - - def list_roles( - self, role_filter_model: RoleFilterModel - ) -> Page[RoleResponseModel]: - """List all roles matching the given filter criteria. - - Args: - role_filter_model: All filter parameters including pagination - params. - - Returns: - A list of all roles matching the filter criteria. - """ - return self._list_paginated_resources( - route=ROLES, - response_model=RoleResponseModel, - filter_model=role_filter_model, - ) - - def update_role( - self, role_id: UUID, role_update: RoleUpdateModel - ) -> RoleResponseModel: - """Update an existing role. - - Args: - role_id: The ID of the role to be updated. - role_update: The update to be applied to the role. - - Returns: - The updated role. - """ - return self._update_resource( - resource_id=role_id, - resource_update=role_update, - route=ROLES, - response_model=RoleResponseModel, - ) - - def delete_role(self, role_name_or_id: Union[str, UUID]) -> None: - """Deletes a role. - - Args: - role_name_or_id: Name or ID of the role to delete. - """ - self._delete_resource( - resource_id=role_name_or_id, - route=ROLES, - ) - - # ---------------- - # Role assignments - # ---------------- - - def list_user_role_assignments( - self, user_role_assignment_filter_model: UserRoleAssignmentFilterModel - ) -> Page[UserRoleAssignmentResponseModel]: - """List all roles assignments matching the given filter criteria. - - Args: - user_role_assignment_filter_model: All filter parameters including - pagination params. - - Returns: - A list of all roles assignments matching the filter criteria. - """ - return self._list_paginated_resources( - route=USER_ROLE_ASSIGNMENTS, - response_model=UserRoleAssignmentResponseModel, - filter_model=user_role_assignment_filter_model, - ) - - def get_user_role_assignment( - self, user_role_assignment_id: UUID - ) -> UserRoleAssignmentResponseModel: - """Get an existing role assignment by name or ID. - - Args: - user_role_assignment_id: Name or ID of the role assignment to get. - - Returns: - The requested workspace. - """ - return self._get_resource( - resource_id=user_role_assignment_id, - route=USER_ROLE_ASSIGNMENTS, - response_model=UserRoleAssignmentResponseModel, - ) - - def delete_user_role_assignment( - self, user_role_assignment_id: UUID - ) -> None: - """Delete a specific role assignment. - - Args: - user_role_assignment_id: The ID of the specific role assignment - """ - self._delete_resource( - resource_id=user_role_assignment_id, - route=USER_ROLE_ASSIGNMENTS, - ) - - def create_user_role_assignment( - self, user_role_assignment: UserRoleAssignmentRequestModel - ) -> UserRoleAssignmentResponseModel: - """Creates a new role assignment. - - Args: - user_role_assignment: The role assignment to create. - - Returns: - The newly created workspace. - """ - return self._create_resource( - resource=user_role_assignment, - route=USER_ROLE_ASSIGNMENTS, - response_model=UserRoleAssignmentResponseModel, - ) - - # --------------------- - # Team Role assignments - # --------------------- - - def create_team_role_assignment( - self, team_role_assignment: TeamRoleAssignmentRequestModel - ) -> TeamRoleAssignmentResponseModel: - """Creates a new team role assignment. - - Args: - team_role_assignment: The role assignment model to create. - - Returns: - The newly created role assignment. - """ - return self._create_resource( - resource=team_role_assignment, - route=TEAM_ROLE_ASSIGNMENTS, - response_model=TeamRoleAssignmentResponseModel, - ) - - def get_team_role_assignment( - self, team_role_assignment_id: UUID - ) -> TeamRoleAssignmentResponseModel: - """Gets a specific role assignment. - - Args: - team_role_assignment_id: ID of the role assignment to get. - - Returns: - The requested role assignment. - """ - return self._get_resource( - resource_id=team_role_assignment_id, - route=TEAM_ROLE_ASSIGNMENTS, - response_model=TeamRoleAssignmentResponseModel, - ) - - def delete_team_role_assignment( - self, team_role_assignment_id: UUID - ) -> None: - """Delete a specific role assignment. - - Args: - team_role_assignment_id: The ID of the specific role assignment - """ - self._delete_resource( - resource_id=team_role_assignment_id, - route=TEAM_ROLE_ASSIGNMENTS, - ) - - def list_team_role_assignments( - self, team_role_assignment_filter_model: TeamRoleAssignmentFilterModel - ) -> Page[TeamRoleAssignmentResponseModel]: - """List all roles assignments matching the given filter criteria. - - Args: - team_role_assignment_filter_model: All filter parameters including - pagination params. - - Returns: - A list of all roles assignments matching the filter criteria. - """ - return self._list_paginated_resources( - route=TEAM_ROLE_ASSIGNMENTS, - response_model=TeamRoleAssignmentResponseModel, - filter_model=team_role_assignment_filter_model, - ) - # -------- # Workspaces # -------- diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index a3efc01a231..55b72f6c616 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -29,12 +29,6 @@ from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.pipeline_schemas import PipelineSchema from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema -from zenml.zen_stores.schemas.role_schemas import ( - RolePermissionSchema, - RoleSchema, - TeamRoleAssignmentSchema, - UserRoleAssignmentSchema, -) from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema from zenml.zen_stores.schemas.secret_schemas import SecretSchema @@ -51,10 +45,6 @@ StepRunParentsSchema, StepRunSchema, ) -from zenml.zen_stores.schemas.team_schemas import ( - TeamAssignmentSchema, - TeamSchema, -) from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.logs_schemas import LogsSchema @@ -72,8 +62,6 @@ "PipelineRunSchema", "PipelineSchema", "WorkspaceSchema", - "RoleSchema", - "RolePermissionSchema", "RunMetadataSchema", "ScheduleSchema", "SecretSchema", @@ -85,10 +73,6 @@ "StepRunOutputArtifactSchema", "StepRunParentsSchema", "StepRunSchema", - "TeamRoleAssignmentSchema", - "TeamSchema", - "TeamAssignmentSchema", - "UserRoleAssignmentSchema", "UserSchema", "LogsSchema", ] diff --git a/src/zenml/zen_stores/schemas/role_schemas.py b/src/zenml/zen_stores/schemas/role_schemas.py deleted file mode 100644 index 2f09b7fb102..00000000000 --- a/src/zenml/zen_stores/schemas/role_schemas.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""SQLModel implementation of roles that can be assigned to users or teams.""" - -from datetime import datetime -from typing import List, Optional -from uuid import UUID, uuid4 - -from sqlmodel import Field, Relationship, SQLModel - -from zenml.enums import PermissionType -from zenml.models import ( - RoleRequestModel, - RoleResponseModel, - RoleUpdateModel, - TeamRoleAssignmentRequestModel, - TeamRoleAssignmentResponseModel, - UserRoleAssignmentRequestModel, -) -from zenml.models.user_role_assignment_models import ( - UserRoleAssignmentResponseModel, -) -from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema -from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field -from zenml.zen_stores.schemas.team_schemas import TeamSchema -from zenml.zen_stores.schemas.user_schemas import UserSchema -from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema - - -class RoleSchema(NamedSchema, table=True): - """SQL Model for roles.""" - - __tablename__ = "role" - - permissions: List["RolePermissionSchema"] = Relationship( - back_populates="roles", sa_relationship_kwargs={"cascade": "delete"} - ) - user_role_assignments: List["UserRoleAssignmentSchema"] = Relationship( - back_populates="role", sa_relationship_kwargs={"cascade": "delete"} - ) - team_role_assignments: List["TeamRoleAssignmentSchema"] = Relationship( - back_populates="role", sa_relationship_kwargs={"cascade": "delete"} - ) - - @classmethod - def from_request(cls, model: RoleRequestModel) -> "RoleSchema": - """Create a `RoleSchema` from a `RoleResponseModel`. - - Args: - model: The `RoleResponseModel` from which to create the schema. - - Returns: - The created `RoleSchema`. - """ - return cls(name=model.name) - - def update(self, role_update: RoleUpdateModel) -> "RoleSchema": - """Update a `RoleSchema` from a `RoleUpdateModel`. - - Args: - role_update: The `RoleUpdateModel` from which to update the schema. - - Returns: - The updated `RoleSchema`. - """ - for field, value in role_update.dict( - exclude_unset=True, exclude={"permissions"} - ).items(): - setattr(self, field, value) - - self.updated = datetime.utcnow() - return self - - def to_model(self) -> RoleResponseModel: - """Convert a `RoleSchema` to a `RoleResponseModel`. - - Returns: - The converted `RoleResponseModel`. - """ - return RoleResponseModel( - id=self.id, - name=self.name, - created=self.created, - updated=self.updated, - permissions={PermissionType(p.name) for p in self.permissions}, - ) - - -class UserRoleAssignmentSchema(BaseSchema, table=True): - """SQL Model for assigning roles to users for a given workspace.""" - - __tablename__ = "user_role_assignment" - - id: UUID = Field(primary_key=True, default_factory=uuid4) - role_id: UUID = build_foreign_key_field( - source=__tablename__, - target=RoleSchema.__tablename__, - source_column="role_id", - target_column="id", - ondelete="CASCADE", - nullable=False, - ) - user_id: UUID = build_foreign_key_field( - source=__tablename__, - target=UserSchema.__tablename__, - source_column="user_id", - target_column="id", - ondelete="CASCADE", - nullable=False, - ) - workspace_id: Optional[UUID] = build_foreign_key_field( - source=__tablename__, - target=WorkspaceSchema.__tablename__, - source_column="workspace_id", - target_column="id", - ondelete="CASCADE", - nullable=True, - ) - - role: RoleSchema = Relationship(back_populates="user_role_assignments") - user: Optional["UserSchema"] = Relationship( - back_populates="assigned_roles" - ) - workspace: Optional["WorkspaceSchema"] = Relationship( - back_populates="user_role_assignments" - ) - - @classmethod - def from_request( - cls, role_assignment: UserRoleAssignmentRequestModel - ) -> "UserRoleAssignmentSchema": - """Create a `UserRoleAssignmentSchema` from a `RoleAssignmentRequestModel`. - - Args: - role_assignment: The `RoleAssignmentRequestModel` from which to - create the schema. - - Returns: - The created `UserRoleAssignmentSchema`. - """ - return cls( - role_id=role_assignment.role, - user_id=role_assignment.user, - workspace_id=role_assignment.workspace, - ) - - def to_model(self) -> UserRoleAssignmentResponseModel: - """Convert a `UserRoleAssignmentSchema` to a `RoleAssignmentModel`. - - Returns: - The converted `RoleAssignmentModel`. - """ - return UserRoleAssignmentResponseModel( - id=self.id, - workspace=self.workspace.to_model() if self.workspace else None, - user=self.user.to_model(_block_recursion=True) - if self.user - else None, - role=self.role.to_model(), - created=self.created, - updated=self.updated, - ) - - -class TeamRoleAssignmentSchema(BaseSchema, table=True): - """SQL Model for assigning roles to teams for a given workspace.""" - - __tablename__ = "team_role_assignment" - - id: UUID = Field(primary_key=True, default_factory=uuid4) - role_id: UUID = build_foreign_key_field( - source=__tablename__, - target=RoleSchema.__tablename__, - source_column="role_id", - target_column="id", - ondelete="CASCADE", - nullable=False, - ) - team_id: UUID = build_foreign_key_field( - source=__tablename__, - target=TeamSchema.__tablename__, - source_column="team_id", - target_column="id", - ondelete="CASCADE", - nullable=False, - ) - workspace_id: Optional[UUID] = build_foreign_key_field( - source=__tablename__, - target=WorkspaceSchema.__tablename__, - source_column="workspace_id", - target_column="id", - ondelete="CASCADE", - nullable=True, - ) - role: RoleSchema = Relationship(back_populates="team_role_assignments") - team: "TeamSchema" = Relationship(back_populates="assigned_roles") - workspace: Optional["WorkspaceSchema"] = Relationship( - back_populates="team_role_assignments" - ) - - @classmethod - def from_request( - cls, role_assignment: TeamRoleAssignmentRequestModel - ) -> "TeamRoleAssignmentSchema": - """Create a `TeamRoleAssignmentSchema` from a `RoleAssignmentRequestModel`. - - Args: - role_assignment: The `RoleAssignmentRequestModel` from which to - create the schema. - - Returns: - The created `TeamRoleAssignmentSchema`. - """ - return cls( - role_id=role_assignment.role, - team_id=role_assignment.team, - workspace_id=role_assignment.workspace, - ) - - def to_model(self) -> TeamRoleAssignmentResponseModel: - """Convert a `TeamRoleAssignmentSchema` to a `RoleAssignmentModel`. - - Returns: - The converted `RoleAssignmentModel`. - """ - return TeamRoleAssignmentResponseModel( - id=self.id, - workspace=self.workspace.to_model() if self.workspace else None, - team=self.team.to_model(_block_recursion=True), - role=self.role.to_model(), - created=self.created, - updated=self.updated, - ) - - -class RolePermissionSchema(SQLModel, table=True): - """SQL Model for team assignments.""" - - __tablename__ = "role_permission" - - name: PermissionType = Field(primary_key=True) - role_id: UUID = build_foreign_key_field( - source=__tablename__, - target=RoleSchema.__tablename__, - source_column="role_id", - target_column="id", - ondelete="CASCADE", - nullable=False, - primary_key=True, - ) - roles: List["RoleSchema"] = Relationship(back_populates="permissions") diff --git a/src/zenml/zen_stores/schemas/team_schemas.py b/src/zenml/zen_stores/schemas/team_schemas.py deleted file mode 100644 index 17206137ce2..00000000000 --- a/src/zenml/zen_stores/schemas/team_schemas.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""SQLModel implementation of team tables.""" - -from datetime import datetime -from typing import TYPE_CHECKING, List -from uuid import UUID - -from sqlmodel import Relationship, SQLModel - -from zenml.models import TeamResponseModel -from zenml.models.team_models import TeamUpdateModel -from zenml.zen_stores.schemas.base_schemas import NamedSchema -from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field - -if TYPE_CHECKING: - from zenml.zen_stores.schemas.role_schemas import TeamRoleAssignmentSchema - from zenml.zen_stores.schemas.user_schemas import UserSchema - - -class TeamAssignmentSchema(SQLModel, table=True): - """SQL Model for team assignments.""" - - __tablename__ = "team_assignment" - - user_id: UUID = build_foreign_key_field( - source=__tablename__, - target="user", # TODO: how to reference `UserSchema.__tablename__`? - source_column="user_id", - target_column="id", - ondelete="CASCADE", - nullable=False, - primary_key=True, - ) - team_id: UUID = build_foreign_key_field( - source=__tablename__, - target="team", # TODO: how to reference `TeamSchema.__tablename__`? - source_column="team_id", - target_column="id", - ondelete="CASCADE", - nullable=False, - primary_key=True, - ) - - -class TeamSchema(NamedSchema, table=True): - """SQL Model for teams.""" - - __tablename__ = "team" - - users: List["UserSchema"] = Relationship( - back_populates="teams", link_model=TeamAssignmentSchema - ) - assigned_roles: List["TeamRoleAssignmentSchema"] = Relationship( - back_populates="team", sa_relationship_kwargs={"cascade": "delete"} - ) - - def update(self, team_update: TeamUpdateModel) -> "TeamSchema": - """Update a `TeamSchema` with a `TeamUpdateModel`. - - Args: - team_update: The `TeamUpdateModel` to update the schema with. - - Returns: - The updated `TeamSchema`. - """ - for field, value in team_update.dict(exclude_unset=True).items(): - if field == "users": - pass - else: - setattr(self, field, value) - - self.updated = datetime.utcnow() - return self - - def to_model(self, _block_recursion: bool = False) -> TeamResponseModel: - """Convert a `TeamSchema` to a `TeamResponseModel`. - - Args: - _block_recursion: Don't recursively fill attributes - - Returns: - The converted `TeamResponseModel`. - """ - if _block_recursion: - return TeamResponseModel( - id=self.id, - name=self.name, - created=self.created, - updated=self.updated, - users=[], - ) - else: - return TeamResponseModel( - id=self.id, - name=self.name, - created=self.created, - updated=self.updated, - users=[u.to_model(_block_recursion=False) for u in self.users], - ) diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index 5710e05b735..87436e26e53 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -20,7 +20,6 @@ from zenml.models import UserRequestModel, UserResponseModel, UserUpdateModel from zenml.zen_stores.schemas.base_schemas import NamedSchema -from zenml.zen_stores.schemas.team_schemas import TeamAssignmentSchema if TYPE_CHECKING: from zenml.zen_stores.schemas import ( @@ -38,8 +37,6 @@ StackComponentSchema, StackSchema, StepRunSchema, - TeamSchema, - UserRoleAssignmentSchema, ) @@ -56,12 +53,6 @@ class UserSchema(NamedSchema, table=True): hub_token: Optional[str] = Field(nullable=True) email_opted_in: Optional[bool] = Field(nullable=True) - teams: List["TeamSchema"] = Relationship( - back_populates="users", link_model=TeamAssignmentSchema - ) - assigned_roles: List["UserRoleAssignmentSchema"] = Relationship( - back_populates="user", sa_relationship_kwargs={"cascade": "delete"} - ) stacks: List["StackSchema"] = Relationship(back_populates="user") components: List["StackComponentSchema"] = Relationship( back_populates="user", @@ -166,9 +157,7 @@ def to_model( email_opted_in=self.email_opted_in, email=self.email if include_private else None, hub_token=self.hub_token if include_private else None, - teams=[t.to_model(_block_recursion=True) for t in self.teams], full_name=self.full_name, created=self.created, updated=self.updated, - roles=[ra.role.to_model() for ra in self.assigned_roles], ) diff --git a/src/zenml/zen_stores/schemas/workspace_schemas.py b/src/zenml/zen_stores/schemas/workspace_schemas.py index a1b7d7bbe2c..89ba8f22168 100644 --- a/src/zenml/zen_stores/schemas/workspace_schemas.py +++ b/src/zenml/zen_stores/schemas/workspace_schemas.py @@ -40,8 +40,6 @@ StackComponentSchema, StackSchema, StepRunSchema, - TeamRoleAssignmentSchema, - UserRoleAssignmentSchema, ) @@ -52,14 +50,6 @@ class WorkspaceSchema(NamedSchema, table=True): description: str - user_role_assignments: List["UserRoleAssignmentSchema"] = Relationship( - back_populates="workspace", - sa_relationship_kwargs={"cascade": "delete"}, - ) - team_role_assignments: List["TeamRoleAssignmentSchema"] = Relationship( - back_populates="workspace", - sa_relationship_kwargs={"cascade": "all, delete"}, - ) stacks: List["StackSchema"] = Relationship( back_populates="workspace", sa_relationship_kwargs={"cascade": "delete"}, diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 1fb101d8180..0cc236d9227 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -28,6 +28,7 @@ Dict, List, Optional, + Set, Tuple, Type, TypeVar, @@ -108,10 +109,6 @@ PipelineRunResponseModel, PipelineRunUpdateModel, PipelineUpdateModel, - RoleFilterModel, - RoleRequestModel, - RoleResponseModel, - RoleUpdateModel, RunMetadataRequestModel, RunMetadataResponseModel, ScheduleRequestModel, @@ -133,20 +130,10 @@ StepRunRequestModel, StepRunResponseModel, StepRunUpdateModel, - TeamFilterModel, - TeamRequestModel, - TeamResponseModel, - TeamRoleAssignmentFilterModel, - TeamRoleAssignmentRequestModel, - TeamRoleAssignmentResponseModel, - TeamUpdateModel, UserAuthModel, UserFilterModel, UserRequestModel, UserResponseModel, - UserRoleAssignmentFilterModel, - UserRoleAssignmentRequestModel, - UserRoleAssignmentResponseModel, UserUpdateModel, WorkspaceFilterModel, WorkspaceRequestModel, @@ -171,8 +158,6 @@ ) from zenml.utils.string_utils import random_str from zenml.zen_stores.base_zen_store import ( - DEFAULT_ADMIN_ROLE, - DEFAULT_GUEST_ROLE, DEFAULT_STACK_COMPONENT_NAME, DEFAULT_STACK_NAME, BaseZenStore, @@ -194,8 +179,6 @@ PipelineDeploymentSchema, PipelineRunSchema, PipelineSchema, - RolePermissionSchema, - RoleSchema, RunMetadataSchema, ScheduleSchema, ServiceConnectorSchema, @@ -206,9 +189,6 @@ StepRunOutputArtifactSchema, StepRunParentsSchema, StepRunSchema, - TeamRoleAssignmentSchema, - TeamSchema, - UserRoleAssignmentSchema, UserSchema, WorkspaceSchema, ) @@ -235,6 +215,9 @@ logger = get_logger(__name__) ZENML_SQLITE_DB_FILENAME = "zenml.db" +ZENML_CLOUD_ONLY_FEATURE_ERROR_MESSAGE = ( + "This feature is only available in ZenML cloud." +) def _is_mysql_missing_database_error(error: OperationalError) -> bool: @@ -719,6 +702,7 @@ def filter_and_paginate( List[AnySchema], ] ] = None, + resource_ids: Optional[Set[UUID]] = None, ) -> Page[B]: """Given a query, return a Page instance with a list of filtered Models. @@ -747,6 +731,9 @@ def filter_and_paginate( """ query = filter_model.apply_filter(query=query, table=table) + if resource_ids: + query = query.where(table.id.in_(resource_ids)) + # Get the total amount of items in the database for a given query if custom_fetch: total = len(custom_fetch(session, query, filter_model)) @@ -2010,597 +1997,6 @@ def delete_user(self, user_name_or_id: Union[str, UUID]) -> None: session.delete(user) session.commit() - # ----- - # Teams - # ----- - - def create_team(self, team: TeamRequestModel) -> TeamResponseModel: - """Creates a new team. - - Args: - team: The team model to create. - - Returns: - The newly created team. - - Raises: - EntityExistsError: If a team with the given name already exists. - """ - with Session(self.engine) as session: - # Check if team with the given name already exists - existing_team = session.exec( - select(TeamSchema).where(TeamSchema.name == team.name) - ).first() - if existing_team is not None: - raise EntityExistsError( - f"Unable to create team with name '{team.name}': " - f"Found existing team with this name." - ) - - defined_users = [] - if team.users: - # Get the Schemas of all users mentioned - filters = [ - (UserSchema.id == user_id) for user_id in team.users - ] - - defined_users = session.exec( - select(UserSchema).where(or_(*filters)) - ).all() - - # Create the team - new_team = TeamSchema(name=team.name, users=defined_users) - session.add(new_team) - session.commit() - - return new_team.to_model() - - def get_team(self, team_name_or_id: Union[str, UUID]) -> TeamResponseModel: - """Gets a specific team. - - Args: - team_name_or_id: Name or ID of the team to get. - - Returns: - The requested team. - """ - with Session(self.engine) as session: - team = self._get_team_schema(team_name_or_id, session=session) - return team.to_model() - - def list_teams( - self, team_filter_model: TeamFilterModel - ) -> Page[TeamResponseModel]: - """List all teams matching the given filter criteria. - - Args: - team_filter_model: All filter parameters including pagination - params. - - Returns: - A list of all teams matching the filter criteria. - """ - with Session(self.engine) as session: - query = select(TeamSchema) - return self.filter_and_paginate( - session=session, - query=query, - table=TeamSchema, - filter_model=team_filter_model, - ) - - def update_team( - self, team_id: UUID, team_update: TeamUpdateModel - ) -> TeamResponseModel: - """Update an existing team. - - Args: - team_id: The ID of the team to be updated. - team_update: The update to be applied to the team. - - Returns: - The updated team. - - Raises: - KeyError: if the team does not exist. - """ - with Session(self.engine) as session: - existing_team = session.exec( - select(TeamSchema).where(TeamSchema.id == team_id) - ).first() - - if existing_team is None: - raise KeyError( - f"Unable to update team with id " - f"'{team_id}': Found no" - f"existing teams with this id." - ) - - # Update the team - existing_team.update(team_update=team_update) - existing_team.users = [] - if "users" in team_update.__fields_set__ and team_update.users: - for user in team_update.users: - existing_team.users.append( - self._get_user_schema( - user_name_or_id=user, session=session - ) - ) - - session.add(existing_team) - session.commit() - - # Refresh the Model that was just created - session.refresh(existing_team) - return existing_team.to_model() - - def delete_team(self, team_name_or_id: Union[str, UUID]) -> None: - """Deletes a team. - - Args: - team_name_or_id: Name or ID of the team to delete. - """ - with Session(self.engine) as session: - team = self._get_team_schema(team_name_or_id, session=session) - session.delete(team) - session.commit() - - # ----- - # Roles - # ----- - - def create_role(self, role: RoleRequestModel) -> RoleResponseModel: - """Creates a new role. - - Args: - role: The role model to create. - - Returns: - The newly created role. - - Raises: - EntityExistsError: If a role with the given name already exists. - """ - with Session(self.engine) as session: - # Check if role with the given name already exists - existing_role = session.exec( - select(RoleSchema).where(RoleSchema.name == role.name) - ).first() - if existing_role is not None: - raise EntityExistsError( - f"Unable to create role '{role.name}': Role already exists." - ) - - # Create role - role_schema = RoleSchema.from_request(role) - session.add(role_schema) - session.commit() - # Add all permissions - for p in role.permissions: - session.add( - RolePermissionSchema(name=p, role_id=role_schema.id) - ) - - session.commit() - return role_schema.to_model() - - def get_role(self, role_name_or_id: Union[str, UUID]) -> RoleResponseModel: - """Gets a specific role. - - Args: - role_name_or_id: Name or ID of the role to get. - - Returns: - The requested role. - """ - with Session(self.engine) as session: - role = self._get_role_schema(role_name_or_id, session=session) - return role.to_model() - - def list_roles( - self, role_filter_model: RoleFilterModel - ) -> Page[RoleResponseModel]: - """List all roles matching the given filter criteria. - - Args: - role_filter_model: All filter parameters including pagination - params. - - Returns: - A list of all roles matching the filter criteria. - """ - with Session(self.engine) as session: - query = select(RoleSchema) - return self.filter_and_paginate( - session=session, - query=query, - table=RoleSchema, - filter_model=role_filter_model, - ) - - def update_role( - self, role_id: UUID, role_update: RoleUpdateModel - ) -> RoleResponseModel: - """Update an existing role. - - Args: - role_id: The ID of the role to be updated. - role_update: The update to be applied to the role. - - Returns: - The updated role. - - Raises: - KeyError: if the role does not exist. - IllegalOperationError: if the role is a system role. - """ - with Session(self.engine) as session: - existing_role = session.exec( - select(RoleSchema).where(RoleSchema.id == role_id) - ).first() - - if existing_role is None: - raise KeyError( - f"Unable to update role with id " - f"'{role_id}': Found no" - f"existing roles with this id." - ) - - if existing_role.name in [DEFAULT_ADMIN_ROLE, DEFAULT_GUEST_ROLE]: - raise IllegalOperationError( - f"The built-in role '{existing_role.name}' cannot be " - f"updated." - ) - - # The relationship table for roles behaves different from the other - # ones. As such the required updates on the permissions have to be - # done manually. - if "permissions" in role_update.__fields_set__: - existing_permissions = { - p.name for p in existing_role.permissions - } - - diff = existing_permissions.symmetric_difference( - role_update.permissions - ) - - for permission in diff: - if permission not in role_update.permissions: - permission_to_delete = session.exec( - select(RolePermissionSchema) - .where(RolePermissionSchema.name == permission) - .where( - RolePermissionSchema.role_id - == existing_role.id - ) - ).one_or_none() - session.delete(permission_to_delete) - - elif permission not in existing_permissions: - session.add( - RolePermissionSchema( - name=permission, role_id=existing_role.id - ) - ) - - # Update the role - existing_role.update(role_update=role_update) - session.add(existing_role) - session.commit() - - session.commit() - - # Refresh the Model that was just created - session.refresh(existing_role) - return existing_role.to_model() - - def delete_role(self, role_name_or_id: Union[str, UUID]) -> None: - """Deletes a role. - - Args: - role_name_or_id: Name or ID of the role to delete. - - Raises: - IllegalOperationError: If the role is still assigned to users or - the role is one of the built-in roles. - """ - with Session(self.engine) as session: - role = self._get_role_schema(role_name_or_id, session=session) - if role.name in [DEFAULT_ADMIN_ROLE, DEFAULT_GUEST_ROLE]: - raise IllegalOperationError( - f"The built-in role '{role.name}' cannot be deleted." - ) - user_role = session.exec( - select(UserRoleAssignmentSchema).where( - UserRoleAssignmentSchema.role_id == role.id - ) - ).all() - team_role = session.exec( - select(TeamRoleAssignmentSchema).where( - TeamRoleAssignmentSchema.role_id == role.id - ) - ).all() - - if len(user_role) > 0 or len(team_role) > 0: - raise IllegalOperationError( - f"Role `{role.name}` of type cannot be " - f"deleted as it is in use by multiple users and teams. " - f"Before deleting this role make sure to remove all " - f"instances where this role is used." - ) - else: - # Delete role - session.delete(role) - session.commit() - - # ---------------- - # Role assignments - # ---------------- - - def list_user_role_assignments( - self, user_role_assignment_filter_model: UserRoleAssignmentFilterModel - ) -> Page[UserRoleAssignmentResponseModel]: - """List all roles assignments matching the given filter criteria. - - Args: - user_role_assignment_filter_model: All filter parameters including - pagination params. - - Returns: - A list of all roles assignments matching the filter criteria. - """ - with Session(self.engine) as session: - query = select(UserRoleAssignmentSchema) - return self.filter_and_paginate( - session=session, - query=query, - table=UserRoleAssignmentSchema, - filter_model=user_role_assignment_filter_model, - ) - - def create_user_role_assignment( - self, user_role_assignment: UserRoleAssignmentRequestModel - ) -> UserRoleAssignmentResponseModel: - """Assigns a role to a user or team, scoped to a specific workspace. - - Args: - user_role_assignment: The role assignment to create. - - Returns: - The created role assignment. - - Raises: - EntityExistsError: if the role assignment already exists. - """ - with Session(self.engine) as session: - role = self._get_role_schema( - user_role_assignment.role, session=session - ) - workspace: Optional[WorkspaceSchema] = None - if user_role_assignment.workspace: - workspace = self._get_workspace_schema( - user_role_assignment.workspace, session=session - ) - user = self._get_user_schema( - user_role_assignment.user, session=session - ) - query = select(UserRoleAssignmentSchema).where( - UserRoleAssignmentSchema.user_id == user.id, - UserRoleAssignmentSchema.role_id == role.id, - ) - if workspace is not None: - query = query.where( - UserRoleAssignmentSchema.workspace_id == workspace.id - ) - existing_role_assignment = session.exec(query).first() - if existing_role_assignment is not None: - raise EntityExistsError( - f"Unable to assign role '{role.name}' to user " - f"'{user.name}': Role already assigned in this workspace." - ) - role_assignment = UserRoleAssignmentSchema( - role_id=role.id, - user_id=user.id, - workspace_id=workspace.id if workspace else None, - role=role, - user=user, - workspace=workspace, - ) - session.add(role_assignment) - session.commit() - return role_assignment.to_model() - - def get_user_role_assignment( - self, user_role_assignment_id: UUID - ) -> UserRoleAssignmentResponseModel: - """Gets a role assignment by ID. - - Args: - user_role_assignment_id: ID of the role assignment to get. - - Returns: - The role assignment. - - Raises: - KeyError: If the role assignment does not exist. - """ - with Session(self.engine) as session: - user_role = session.exec( - select(UserRoleAssignmentSchema).where( - UserRoleAssignmentSchema.id == user_role_assignment_id - ) - ).one_or_none() - - if user_role: - return user_role.to_model() - else: - raise KeyError( - f"Unable to get user role assignment with ID " - f"'{user_role_assignment_id}': No user role assignment " - f"with this ID found." - ) - - def delete_user_role_assignment( - self, user_role_assignment_id: UUID - ) -> None: - """Delete a specific role assignment. - - Args: - user_role_assignment_id: The ID of the specific role assignment. - - Raises: - KeyError: If the role assignment does not exist. - """ - with Session(self.engine) as session: - user_role = session.exec( - select(UserRoleAssignmentSchema).where( - UserRoleAssignmentSchema.id == user_role_assignment_id - ) - ).one_or_none() - if not user_role: - raise KeyError( - f"No user role assignment with id " - f"{user_role_assignment_id} exists." - ) - - session.delete(user_role) - - session.commit() - - # --------------------- - # Team Role assignments - # --------------------- - - def create_team_role_assignment( - self, team_role_assignment: TeamRoleAssignmentRequestModel - ) -> TeamRoleAssignmentResponseModel: - """Creates a new team role assignment. - - Args: - team_role_assignment: The role assignment model to create. - - Returns: - The newly created role assignment. - - Raises: - EntityExistsError: If the role assignment already exists. - """ - with Session(self.engine) as session: - role = self._get_role_schema( - team_role_assignment.role, session=session - ) - workspace: Optional[WorkspaceSchema] = None - if team_role_assignment.workspace: - workspace = self._get_workspace_schema( - team_role_assignment.workspace, session=session - ) - team = self._get_team_schema( - team_role_assignment.team, session=session - ) - query = select(UserRoleAssignmentSchema).where( - UserRoleAssignmentSchema.user_id == team.id, - UserRoleAssignmentSchema.role_id == role.id, - ) - if workspace is not None: - query = query.where( - UserRoleAssignmentSchema.workspace_id == workspace.id - ) - existing_role_assignment = session.exec(query).first() - if existing_role_assignment is not None: - raise EntityExistsError( - f"Unable to assign role '{role.name}' to team " - f"'{team.name}': Role already assigned in this workspace." - ) - role_assignment = TeamRoleAssignmentSchema( - role_id=role.id, - team_id=team.id, - workspace_id=workspace.id if workspace else None, - role=role, - team=team, - workspace=workspace, - ) - session.add(role_assignment) - session.commit() - return role_assignment.to_model() - - def get_team_role_assignment( - self, team_role_assignment_id: UUID - ) -> TeamRoleAssignmentResponseModel: - """Gets a specific role assignment. - - Args: - team_role_assignment_id: ID of the role assignment to get. - - Returns: - The requested role assignment. - - Raises: - KeyError: If no role assignment with the given ID exists. - """ - with Session(self.engine) as session: - team_role = session.exec( - select(TeamRoleAssignmentSchema).where( - TeamRoleAssignmentSchema.id == team_role_assignment_id - ) - ).one_or_none() - - if team_role: - return team_role.to_model() - else: - raise KeyError( - f"Unable to get team role assignment with ID " - f"'{team_role_assignment_id}': No team role assignment " - f"with this ID found." - ) - - def delete_team_role_assignment( - self, team_role_assignment_id: UUID - ) -> None: - """Delete a specific role assignment. - - Args: - team_role_assignment_id: The ID of the specific role assignment - - Raises: - KeyError: If the role assignment does not exist. - """ - with Session(self.engine) as session: - team_role = session.exec( - select(TeamRoleAssignmentSchema).where( - TeamRoleAssignmentSchema.id == team_role_assignment_id - ) - ).one_or_none() - if not team_role: - raise KeyError( - f"No team role assignment with id " - f"{team_role_assignment_id} exists." - ) - - session.delete(team_role) - - session.commit() - - def list_team_role_assignments( - self, team_role_assignment_filter_model: TeamRoleAssignmentFilterModel - ) -> Page[TeamRoleAssignmentResponseModel]: - """List all roles assignments matching the given filter criteria. - - Args: - team_role_assignment_filter_model: All filter parameters including - pagination params. - - Returns: - A list of all roles assignments matching the filter criteria. - """ - with Session(self.engine) as session: - query = select(TeamRoleAssignmentSchema) - return self.filter_and_paginate( - session=session, - query=query, - table=TeamRoleAssignmentSchema, - filter_model=team_role_assignment_filter_model, - ) - # -------- # Workspaces # -------- @@ -5241,53 +4637,53 @@ def _get_user_schema( session=session, ) - def _get_team_schema( - self, - team_name_or_id: Union[str, UUID], - session: Session, - ) -> TeamSchema: - """Gets a team schema by name or ID. - - This is a helper method that is used in various places to find a team - by its name or ID. - - Args: - team_name_or_id: The name or ID of the team to get. - session: The database session to use. - - Returns: - The team schema. - """ - return self._get_schema_by_name_or_id( - object_name_or_id=team_name_or_id, - schema_class=TeamSchema, - schema_name="team", - session=session, - ) - - def _get_role_schema( - self, - role_name_or_id: Union[str, UUID], - session: Session, - ) -> RoleSchema: - """Gets a role schema by name or ID. - - This is a helper method that is used in various places to find a role - by its name or ID. - - Args: - role_name_or_id: The name or ID of the role to get. - session: The database session to use. - - Returns: - The role schema. - """ - return self._get_schema_by_name_or_id( - object_name_or_id=role_name_or_id, - schema_class=RoleSchema, - schema_name="role", - session=session, - ) + # def _get_team_schema( + # self, + # team_name_or_id: Union[str, UUID], + # session: Session, + # ) -> TeamSchema: + # """Gets a team schema by name or ID. + + # This is a helper method that is used in various places to find a team + # by its name or ID. + + # Args: + # team_name_or_id: The name or ID of the team to get. + # session: The database session to use. + + # Returns: + # The team schema. + # """ + # return self._get_schema_by_name_or_id( + # object_name_or_id=team_name_or_id, + # schema_class=TeamSchema, + # schema_name="team", + # session=session, + # ) + + # def _get_role_schema( + # self, + # role_name_or_id: Union[str, UUID], + # session: Session, + # ) -> RoleSchema: + # """Gets a role schema by name or ID. + + # This is a helper method that is used in various places to find a role + # by its name or ID. + + # Args: + # role_name_or_id: The name or ID of the role to get. + # session: The database session to use. + + # Returns: + # The role schema. + # """ + # return self._get_schema_by_name_or_id( + # object_name_or_id=role_name_or_id, + # schema_class=RoleSchema, + # schema_name="role", + # session=session, + # ) def _get_run_schema( self, diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 336e6d49874..3f9db718154 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -46,10 +46,6 @@ PipelineRunResponseModel, PipelineRunUpdateModel, PipelineUpdateModel, - RoleFilterModel, - RoleRequestModel, - RoleResponseModel, - RoleUpdateModel, RunMetadataRequestModel, RunMetadataResponseModel, ScheduleRequestModel, @@ -68,19 +64,9 @@ StepRunRequestModel, StepRunResponseModel, StepRunUpdateModel, - TeamFilterModel, - TeamRequestModel, - TeamResponseModel, - TeamRoleAssignmentFilterModel, - TeamRoleAssignmentRequestModel, - TeamRoleAssignmentResponseModel, - TeamUpdateModel, UserFilterModel, UserRequestModel, UserResponseModel, - UserRoleAssignmentFilterModel, - UserRoleAssignmentRequestModel, - UserRoleAssignmentResponseModel, UserUpdateModel, WorkspaceFilterModel, WorkspaceRequestModel, @@ -492,263 +478,6 @@ def delete_user(self, user_name_or_id: Union[str, UUID]) -> None: KeyError: If no user with the given ID exists. """ - # ----- - # Teams - # ----- - - @abstractmethod - def create_team(self, team: TeamRequestModel) -> TeamResponseModel: - """Creates a new team. - - Args: - team: The team model to create. - - Returns: - The newly created team. - """ - - @abstractmethod - def get_team(self, team_name_or_id: Union[str, UUID]) -> TeamResponseModel: - """Gets a specific team. - - Args: - team_name_or_id: Name or ID of the team to get. - - Returns: - The requested team. - - Raises: - KeyError: If no team with the given name or ID exists. - """ - - @abstractmethod - def list_teams( - self, team_filter_model: TeamFilterModel - ) -> Page[TeamResponseModel]: - """List all teams matching the given filter criteria. - - Args: - team_filter_model: All filter parameters including pagination - params. - - Returns: - A list of all teams matching the filter criteria. - """ - - @abstractmethod - def update_team( - self, team_id: UUID, team_update: TeamUpdateModel - ) -> TeamResponseModel: - """Update an existing team. - - Args: - team_id: The ID of the team to be updated. - team_update: The update to be applied to the team. - - Returns: - The updated team. - - Raises: - KeyError: if the team does not exist. - """ - - @abstractmethod - def delete_team(self, team_name_or_id: Union[str, UUID]) -> None: - """Deletes a team. - - Args: - team_name_or_id: Name or ID of the team to delete. - - Raises: - KeyError: If no team with the given ID exists. - """ - - # ----- - # Roles - # ----- - - @abstractmethod - def create_role(self, role: RoleRequestModel) -> RoleResponseModel: - """Creates a new role. - - Args: - role: The role model to create. - - Returns: - The newly created role. - - Raises: - EntityExistsError: If a role with the given name already exists. - """ - - @abstractmethod - def get_role(self, role_name_or_id: Union[str, UUID]) -> RoleResponseModel: - """Gets a specific role. - - Args: - role_name_or_id: Name or ID of the role to get. - - Returns: - The requested role. - - Raises: - KeyError: If no role with the given name exists. - """ - - @abstractmethod - def list_roles( - self, role_filter_model: RoleFilterModel - ) -> Page[RoleResponseModel]: - """List all roles matching the given filter criteria. - - Args: - role_filter_model: All filter parameters including pagination - params. - - Returns: - A list of all roles matching the filter criteria. - """ - - @abstractmethod - def update_role( - self, role_id: UUID, role_update: RoleUpdateModel - ) -> RoleResponseModel: - """Update an existing role. - - Args: - role_id: The ID of the role to be updated. - role_update: The update to be applied to the role. - - Returns: - The updated role. - - Raises: - KeyError: if the role does not exist. - """ - - @abstractmethod - def delete_role(self, role_name_or_id: Union[str, UUID]) -> None: - """Deletes a role. - - Args: - role_name_or_id: Name or ID of the role to delete. - - Raises: - KeyError: If no role with the given ID exists. - """ - - # --------------------- - # User Role assignments - # --------------------- - @abstractmethod - def create_user_role_assignment( - self, user_role_assignment: UserRoleAssignmentRequestModel - ) -> UserRoleAssignmentResponseModel: - """Creates a new role assignment. - - Args: - user_role_assignment: The role assignment model to create. - - Returns: - The newly created role assignment. - """ - - @abstractmethod - def get_user_role_assignment( - self, user_role_assignment_id: UUID - ) -> UserRoleAssignmentResponseModel: - """Gets a specific role assignment. - - Args: - user_role_assignment_id: ID of the role assignment to get. - - Returns: - The requested role assignment. - - Raises: - KeyError: If no role assignment with the given ID exists. - """ - - @abstractmethod - def delete_user_role_assignment( - self, user_role_assignment_id: UUID - ) -> None: - """Delete a specific role assignment. - - Args: - user_role_assignment_id: The ID of the specific role assignment - """ - - @abstractmethod - def list_user_role_assignments( - self, user_role_assignment_filter_model: UserRoleAssignmentFilterModel - ) -> Page[UserRoleAssignmentResponseModel]: - """List all roles assignments matching the given filter criteria. - - Args: - user_role_assignment_filter_model: All filter parameters including - pagination params. - - Returns: - A list of all roles assignments matching the filter criteria. - """ - - # --------------------- - # Team Role assignments - # --------------------- - @abstractmethod - def create_team_role_assignment( - self, team_role_assignment: TeamRoleAssignmentRequestModel - ) -> TeamRoleAssignmentResponseModel: - """Creates a new team role assignment. - - Args: - team_role_assignment: The role assignment model to create. - - Returns: - The newly created role assignment. - """ - - @abstractmethod - def get_team_role_assignment( - self, team_role_assignment_id: UUID - ) -> TeamRoleAssignmentResponseModel: - """Gets a specific role assignment. - - Args: - team_role_assignment_id: ID of the role assignment to get. - - Returns: - The requested role assignment. - - Raises: - KeyError: If no role assignment with the given ID exists. - """ - - @abstractmethod - def delete_team_role_assignment( - self, team_role_assignment_id: UUID - ) -> None: - """Delete a specific role assignment. - - Args: - team_role_assignment_id: The ID of the specific role assignment - """ - - @abstractmethod - def list_team_role_assignments( - self, team_role_assignment_filter_model: TeamRoleAssignmentFilterModel - ) -> Page[TeamRoleAssignmentResponseModel]: - """List all roles assignments matching the given filter criteria. - - Args: - team_role_assignment_filter_model: All filter parameters including - pagination params. - - Returns: - A list of all roles assignments matching the filter criteria. - """ - # -------- # Workspaces # -------- From 667af0691cec63420e956f240953c8ba02e2aabc Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 13 Oct 2023 16:56:24 +0200 Subject: [PATCH 002/103] Add dehydration POC --- src/zenml/models/base_models.py | 3 + src/zenml/zen_server/auth.py | 147 ++++++++++++++- src/zenml/zen_server/rbac_interface.py | 18 +- .../zen_server/routers/pipelines_endpoints.py | 10 +- .../routers/role_assignments_endpoints.py | 127 ------------- .../zen_server/routers/roles_endpoints.py | 147 --------------- .../zen_server/routers/stacks_endpoints.py | 18 +- .../team_role_assignments_endpoints.py | 132 ------------- .../zen_server/routers/teams_endpoints.py | 178 ------------------ .../zen_server/routers/users_endpoints.py | 28 --- .../routers/workspaces_endpoints.py | 66 ------- 11 files changed, 170 insertions(+), 704 deletions(-) delete mode 100644 src/zenml/zen_server/routers/role_assignments_endpoints.py delete mode 100644 src/zenml/zen_server/routers/roles_endpoints.py delete mode 100644 src/zenml/zen_server/routers/team_role_assignments_endpoints.py delete mode 100644 src/zenml/zen_server/routers/teams_endpoints.py diff --git a/src/zenml/models/base_models.py b/src/zenml/models/base_models.py index df0551797c7..e0448f3dd74 100644 --- a/src/zenml/models/base_models.py +++ b/src/zenml/models/base_models.py @@ -76,6 +76,9 @@ class BaseResponseModel(BaseZenModel): title="Time when this resource was last updated." ) + missing_permissions: bool = False + partial: bool = False + def __hash__(self) -> int: """Implementation of hash magic method. diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 355e50f8caa..35cfbb33aa0 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -16,13 +16,7 @@ import os from contextvars import ContextVar from datetime import datetime -from typing import ( - Callable, - List, - Optional, - Set, - Union, -) +from typing import Any, Callable, List, Optional, Set, Union from urllib.parse import urlencode from uuid import UUID @@ -56,9 +50,10 @@ UserResponseModel, UserUpdateModel, ) -from zenml.models.user_models import JWTToken, UserAuthModel +from zenml.models.base_models import BaseResponseModel, UserScopedResponseModel +from zenml.models.user_models import UserAuthModel from zenml.zen_server.jwt import JWTToken -from zenml.zen_server.rbac_interface import Resource +from zenml.zen_server.rbac_interface import RESOURCE_TYPE_MAPPING, Resource from zenml.zen_server.utils import rbac, server_config, zen_store from zenml.zen_stores.base_zen_store import DEFAULT_USERNAME @@ -679,6 +674,135 @@ def authentication_provider() -> Callable[..., AuthContext]: authorize = authentication_provider() +from enum import Enum +from typing import Dict, List, Set, Tuple + + +def dehydrate_response_model( + model: "BaseResponseModel", +) -> "BaseResponseModel": + dehydrated_fields = {} + did_dehydrate = False + + for field_name in model.__fields__.keys(): + value = getattr(model, field_name) + new_value, value_dehydrated = _maybe_dehydrate_value(value) + dehydrated_fields[field_name] = new_value + did_dehydrate = did_dehydrate or value_dehydrated + + if did_dehydrate: + dehydrated_fields["partial"] = True + + return type(model).parse_obj(dehydrated_fields) + + +def _maybe_dehydrate_value(value: Any) -> Tuple[Any, bool]: + if isinstance(value, BaseResponseModel): + if has_read_permissions_for_model(value): + dehydrated_model = dehydrate_response_model(value) + return dehydrated_model, dehydrated_model.partial + else: + return get_403_model(value), True + elif isinstance(value, Dict): + dict_ = {} + did_dehydrate = False + for k, v in value.items(): + dict_[k], d = _maybe_dehydrate_value(v) + did_dehydrate = did_dehydrate or d + return dict_, did_dehydrate + elif isinstance(value, (List, Set, Tuple)): + items = [] + did_dehydrate = False + for v in value: + item, d = _maybe_dehydrate_value(v) + items.append(item) + did_dehydrate = did_dehydrate or d + + type_ = type(value) + return type_(items), did_dehydrate + else: + return value, False + + +def has_read_permissions_for_model(model: "BaseResponseModel") -> bool: + try: + verify_permissions_for_model(model=model, action="READ") + return True + except HTTPException: + return False + + +def get_403_model( + model: "BaseResponseModel", keep_name: bool = True +) -> "BaseResponseModel": + values = {} + + for field_name, field in model.__fields__.items(): + value = getattr(model, field_name) + + if keep_name and field_name == "name" and isinstance(value, str): + pass + elif field.allow_none: + value = None + elif isinstance(value, BaseResponseModel): + value = get_403_model(value, keep_name=False) + elif isinstance(value, UUID): + value = UUID(int=0) + elif isinstance(value, datetime): + value = datetime.utcnow() + elif isinstance(value, Enum): + # TODO: handle enums in a more sensible way + value = list(type(value))[0] + else: + type_ = type(value) + # For the remaining cases (dict, list, set, tuple, int, float, str), + # simply return an empty value + value = type_() + + values[field_name] = value + + # TODO: With the new hydration models, make sure we clear metadata here + values["missing_permissions"] = True + + return type(model).parse_obj(values) + + +def verify_permissions_for_model( + model: "BaseResponseModel", + action: str, +) -> None: + """Verifies if a user has permissions to perform an action on a resource. + + Args: + resource: The resource type the user wants to perform the action on. + action: The action the user wants to perform. + resource_id: ID of the resource the user wants to perform the action on. + + Raises: + HTTPException: If the user is not allowed to perform the action. + """ + if "ZENML_CLOUD" not in os.environ: + return + + if ( + isinstance(model, UserScopedResponseModel) + and model.user + and model.user.id == get_auth_context().user.id + ): + # User is the owner of the model + return + + resource_type = RESOURCE_TYPE_MAPPING.get(type(model)) + if not resource_type: + # This model is not tied to any RBAC resource type and therefore doesn't + # require any special permissions + return + + verify_permissions( + resource_type=resource_type, resource_id=model.id, action=action + ) + + def verify_permissions( resource_type: str, action: str, @@ -697,6 +821,11 @@ def verify_permissions( if "ZENML_CLOUD" not in os.environ: return + if resource_type != "stack": + raise HTTPException(status_code=403) + + return + user_id = get_auth_context().user.external_user_id assert user_id resource = Resource(type=resource_type, id=resource_id) diff --git a/src/zenml/zen_server/rbac_interface.py b/src/zenml/zen_server/rbac_interface.py index bda6b3721b5..332d234a21b 100644 --- a/src/zenml/zen_server/rbac_interface.py +++ b/src/zenml/zen_server/rbac_interface.py @@ -1,10 +1,12 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Type from uuid import UUID from pydantic import BaseModel from zenml.enums import StrEnum +from zenml.models import ComponentResponseModel, StackResponseModel +from zenml.models.base_models import BaseResponseModel class Action(StrEnum): @@ -16,8 +18,20 @@ class Action(StrEnum): class ResourceType(StrEnum): STACK = "stack" - COMPONENT = "component" + FLAVOR = "flavor" + STACK_COMPONENT = "stack_component" PIPELINE = "pipeline" + CODE_REPOSITORY = "code-repository" + MODEL = "model" + SERVICE_CONNECTOR = "service_connector" + ARTIFACT = "artifact" + SECRET = "secret" + + +RESOURCE_TYPE_MAPPING: Dict[Type[BaseResponseModel], ResourceType] = { + StackResponseModel: ResourceType.STACK, + ComponentResponseModel: ResourceType.STACK_COMPONENT, +} class Resource(BaseModel): diff --git a/src/zenml/zen_server/routers/pipelines_endpoints.py b/src/zenml/zen_server/routers/pipelines_endpoints.py index ca9d3e0cb54..f90cdd17a22 100644 --- a/src/zenml/zen_server/routers/pipelines_endpoints.py +++ b/src/zenml/zen_server/routers/pipelines_endpoints.py @@ -27,7 +27,7 @@ PipelineUpdateModel, ) from zenml.models.page_model import Page -from zenml.zen_server.auth import AuthContext, authorize, verify_permissions +from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response from zenml.zen_server.utils import ( handle_exceptions, @@ -100,14 +100,6 @@ def get_pipeline( Returns: A specific pipeline object. """ - from zenml.zen_server.rbac_interface import Action, ResourceType - - verify_permissions( - resource_type=ResourceType.PIPELINE, - action=Action.READ, - resource_id=pipeline_id, - ) - return zen_store().get_pipeline(pipeline_id=pipeline_id) diff --git a/src/zenml/zen_server/routers/role_assignments_endpoints.py b/src/zenml/zen_server/routers/role_assignments_endpoints.py deleted file mode 100644 index 6ae10af7d21..00000000000 --- a/src/zenml/zen_server/routers/role_assignments_endpoints.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Endpoint definitions for role assignments.""" -from uuid import UUID - -from fastapi import APIRouter, Depends, Security - -from zenml.constants import API, USER_ROLE_ASSIGNMENTS, VERSION_1 -from zenml.enums import PermissionType -from zenml.models import ( - UserRoleAssignmentFilterModel, - UserRoleAssignmentRequestModel, - UserRoleAssignmentResponseModel, -) -from zenml.models.page_model import Page -from zenml.zen_server.auth import AuthContext, authorize -from zenml.zen_server.exceptions import error_response -from zenml.zen_server.utils import handle_exceptions, make_dependable, rbac - -router = APIRouter( - prefix=API + VERSION_1 + USER_ROLE_ASSIGNMENTS, - tags=["role_assignments"], - responses={401: error_response}, -) - - -@router.get( - "", - response_model=Page[UserRoleAssignmentResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_user_role_assignments( - user_role_assignment_filter_model: UserRoleAssignmentFilterModel = Depends( - make_dependable(UserRoleAssignmentFilterModel) - ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[UserRoleAssignmentResponseModel]: - """Returns a list of all role assignments. - - Args: - user_role_assignment_filter_model: filter models for user role assignments - - Returns: - List of all role assignments. - """ - return rbac().list_user_role_assignments( - user_role_assignment_filter_model=user_role_assignment_filter_model - ) - - -@router.post( - "", - response_model=UserRoleAssignmentResponseModel, - responses={401: error_response, 409: error_response, 422: error_response}, -) -@handle_exceptions -def create_role_assignment( - role_assignment: UserRoleAssignmentRequestModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> UserRoleAssignmentResponseModel: - """Creates a role assignment. - - # noqa: DAR401 - - Args: - role_assignment: Role assignment to create. - - Returns: - The created role assignment. - """ - return rbac().create_user_role_assignment( - user_role_assignment=role_assignment - ) - - -@router.get( - "/{role_assignment_id}", - response_model=UserRoleAssignmentResponseModel, - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def get_role_assignment( - role_assignment_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> UserRoleAssignmentResponseModel: - """Returns a specific role assignment. - - Args: - role_assignment_id: Name or ID of the role assignment. - - Returns: - A specific role assignment. - """ - return rbac().get_user_role_assignment( - user_role_assignment_id=role_assignment_id - ) - - -@router.delete( - "/{role_assignment_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_role_assignment( - role_assignment_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Deletes a specific role. - - Args: - role_assignment_id: The ID of the role assignment. - """ - rbac().delete_user_role_assignment( - user_role_assignment_id=role_assignment_id - ) diff --git a/src/zenml/zen_server/routers/roles_endpoints.py b/src/zenml/zen_server/routers/roles_endpoints.py deleted file mode 100644 index b877645d336..00000000000 --- a/src/zenml/zen_server/routers/roles_endpoints.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Endpoint definitions for roles and role assignment.""" -from typing import Union -from uuid import UUID - -from fastapi import APIRouter, Depends, Security - -from zenml.constants import API, ROLES, VERSION_1 -from zenml.enums import PermissionType -from zenml.models import ( - RoleFilterModel, - RoleRequestModel, - RoleResponseModel, - RoleUpdateModel, -) -from zenml.models.page_model import Page -from zenml.zen_server.auth import AuthContext, authorize -from zenml.zen_server.exceptions import error_response -from zenml.zen_server.utils import handle_exceptions, make_dependable, rbac - -router = APIRouter( - prefix=API + VERSION_1 + ROLES, - tags=["roles"], - responses={401: error_response}, -) - - -@router.get( - "", - response_model=Page[RoleResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_roles( - role_filter_model: RoleFilterModel = Depends( - make_dependable(RoleFilterModel) - ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[RoleResponseModel]: - """Returns a list of all roles. - - Args: - role_filter_model: Filter model used for pagination, sorting, filtering - - - Returns: - List of all roles. - """ - return rbac().list_roles(role_filter_model=role_filter_model) - - -@router.post( - "", - response_model=RoleResponseModel, - responses={401: error_response, 409: error_response, 422: error_response}, -) -@handle_exceptions -def create_role( - role: RoleRequestModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> RoleResponseModel: - """Creates a role. - - # noqa: DAR401 - - Args: - role: Role to create. - - Returns: - The created role. - """ - return rbac().create_role(role=role) - - -@router.get( - "/{role_name_or_id}", - response_model=RoleResponseModel, - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def get_role( - role_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> RoleResponseModel: - """Returns a specific role. - - Args: - role_name_or_id: Name or ID of the role. - - Returns: - A specific role. - """ - return rbac().get_role(role_name_or_id=role_name_or_id) - - -@router.put( - "/{role_id}", - response_model=RoleResponseModel, - responses={401: error_response, 409: error_response, 422: error_response}, -) -@handle_exceptions -def update_role( - role_id: UUID, - role_update: RoleUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> RoleResponseModel: - """Updates a role. - - # noqa: DAR401 - - Args: - role_id: The ID of the role. - role_update: Role update. - - Returns: - The created role. - """ - return rbac().update_role(role_id=role_id, role_update=role_update) - - -@router.delete( - "/{role_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_role( - role_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Deletes a specific role. - - Args: - role_name_or_id: Name or ID of the role. - """ - rbac().delete_role(role_name_or_id=role_name_or_id) diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index 98b21a6e171..ed6ee348897 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -21,7 +21,12 @@ from zenml.enums import PermissionType from zenml.models import StackFilterModel, StackResponseModel, StackUpdateModel from zenml.models.page_model import Page -from zenml.zen_server.auth import AuthContext, authorize, verify_permissions +from zenml.zen_server.auth import ( + AuthContext, + authorize, + dehydrate_response_model, + verify_permissions_for_model, +) from zenml.zen_server.exceptions import error_response from zenml.zen_server.utils import ( handle_exceptions, @@ -89,15 +94,16 @@ def get_stack( Returns: The requested stack. """ - from zenml.zen_server.rbac_interface import Action, ResourceType + from zenml.zen_server.rbac_interface import Action + + stack = zen_store().get_stack(stack_id) - verify_permissions( - resource_type=ResourceType.STACK, + verify_permissions_for_model( + model=stack, action=Action.READ, - resource_id=stack_id, ) - return zen_store().get_stack(stack_id) + return dehydrate_response_model(stack) @router.put( diff --git a/src/zenml/zen_server/routers/team_role_assignments_endpoints.py b/src/zenml/zen_server/routers/team_role_assignments_endpoints.py deleted file mode 100644 index 9d3cdf6cd34..00000000000 --- a/src/zenml/zen_server/routers/team_role_assignments_endpoints.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -# """Endpoint definitions for role assignments.""" -# from uuid import UUID - -# from fastapi import APIRouter, Depends, Security - -# from zenml.constants import API, TEAM_ROLE_ASSIGNMENTS, VERSION_1 -# from zenml.enums import PermissionType -# from zenml.models import ( -# TeamRoleAssignmentFilterModel, -# TeamRoleAssignmentRequestModel, -# TeamRoleAssignmentResponseModel, -# ) -# from zenml.models.page_model import Page -# from zenml.zen_server.auth import AuthContext, authorize -# from zenml.zen_server.exceptions import error_response -# from zenml.zen_server.utils import ( -# handle_exceptions, -# make_dependable, -# zen_store, -# ) - -# router = APIRouter( -# prefix=API + VERSION_1 + TEAM_ROLE_ASSIGNMENTS, -# tags=["team_role_assignments"], -# responses={401: error_response}, -# ) - - -# @router.get( -# "", -# response_model=Page[TeamRoleAssignmentResponseModel], -# responses={401: error_response, 404: error_response, 422: error_response}, -# ) -# @handle_exceptions -# def list_team_role_assignments( -# team_role_assignment_filter_model: TeamRoleAssignmentFilterModel = Depends( -# make_dependable(TeamRoleAssignmentFilterModel) -# ), -# _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -# ) -> Page[TeamRoleAssignmentResponseModel]: -# """Returns a list of all role assignments. - -# Args: -# team_role_assignment_filter_model: filter models for team role assignments - - -# Returns: -# List of all role assignments. -# """ -# return zen_store().list_team_role_assignments( -# team_role_assignment_filter_model=team_role_assignment_filter_model -# ) - - -# @router.post( -# "", -# response_model=TeamRoleAssignmentResponseModel, -# responses={401: error_response, 409: error_response, 422: error_response}, -# ) -# @handle_exceptions -# def create_team_role_assignment( -# role_assignment: TeamRoleAssignmentRequestModel, -# _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -# ) -> TeamRoleAssignmentResponseModel: -# """Creates a role assignment. - -# # noqa: DAR401 - -# Args: -# role_assignment: Role assignment to create. - -# Returns: -# The created role assignment. -# """ -# return zen_store().create_team_role_assignment( -# team_role_assignment=role_assignment -# ) - - -# @router.get( -# "/{role_assignment_id}", -# response_model=TeamRoleAssignmentResponseModel, -# responses={401: error_response, 404: error_response, 422: error_response}, -# ) -# @handle_exceptions -# def get_team_role_assignment( -# role_assignment_id: UUID, -# _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -# ) -> TeamRoleAssignmentResponseModel: -# """Returns a specific role assignment. - -# Args: -# role_assignment_id: Name or ID of the role assignment. - -# Returns: -# A specific role assignment. -# """ -# return zen_store().get_team_role_assignment( -# team_role_assignment_id=role_assignment_id -# ) - - -# @router.delete( -# "/{role_assignment_id}", -# responses={401: error_response, 404: error_response, 422: error_response}, -# ) -# @handle_exceptions -# def delete_team_role_assignment( -# role_assignment_id: UUID, -# _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -# ) -> None: -# """Deletes a specific role. - -# Args: -# role_assignment_id: The ID of the role assignment. -# """ -# zen_store().delete_team_role_assignment( -# team_role_assignment_id=role_assignment_id -# ) diff --git a/src/zenml/zen_server/routers/teams_endpoints.py b/src/zenml/zen_server/routers/teams_endpoints.py deleted file mode 100644 index 9712749a2d7..00000000000 --- a/src/zenml/zen_server/routers/teams_endpoints.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Endpoint definitions for teams and team membership.""" -from typing import Union -from uuid import UUID - -from fastapi import APIRouter, Depends, Security - -from zenml.constants import API, ROLES, TEAMS, VERSION_1 -from zenml.enums import PermissionType -from zenml.models import ( - TeamFilterModel, - TeamRequestModel, - TeamResponseModel, - TeamRoleAssignmentFilterModel, - TeamRoleAssignmentResponseModel, - TeamUpdateModel, -) -from zenml.models.page_model import Page -from zenml.zen_server.auth import AuthContext, authorize -from zenml.zen_server.exceptions import error_response -from zenml.zen_server.utils import ( - handle_exceptions, - make_dependable, - zen_store, -) - -router = APIRouter( - prefix=API + VERSION_1 + TEAMS, - tags=["teams"], - responses={401: error_response}, -) - - -@router.get( - "", - response_model=Page[TeamResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_teams( - team_filter_model: TeamFilterModel = Depends( - make_dependable(TeamFilterModel) - ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[TeamResponseModel]: - """Returns a list of all teams. - - Args: - team_filter_model: All filter parameters including pagination params. - - Returns: - List of all teams. - """ - return zen_store().list_teams(team_filter_model) - - -@router.post( - "", - response_model=TeamResponseModel, - responses={401: error_response, 409: error_response, 422: error_response}, -) -@handle_exceptions -def create_team( - team: TeamRequestModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> TeamResponseModel: - """Creates a team. - - # noqa: DAR401 - - Args: - team: Team to create. - - Returns: - The created team. - """ - return zen_store().create_team(team=team) - - -@router.get( - "/{team_name_or_id}", - response_model=TeamResponseModel, - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def get_team( - team_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> TeamResponseModel: - """Returns a specific team. - - Args: - team_name_or_id: Name or ID of the team. - - Returns: - A specific team. - """ - return zen_store().get_team(team_name_or_id=team_name_or_id) - - -@router.put( - "/{team_id}", - response_model=TeamResponseModel, - responses={401: error_response, 409: error_response, 422: error_response}, -) -@handle_exceptions -def update_team( - team_id: UUID, - team_update: TeamUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> TeamResponseModel: - """Updates a team. - - # noqa: DAR401 - - Args: - team_id: ID of the team to update. - team_update: Team update. - - Returns: - The updated team. - """ - return zen_store().update_team(team_id=team_id, team_update=team_update) - - -@router.delete( - "/{team_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_team( - team_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Deletes a specific team. - - Args: - team_name_or_id: Name or ID of the team. - """ - zen_store().delete_team(team_name_or_id=team_name_or_id) - - -@router.get( - "/{team_name_or_id}" + ROLES, - response_model=Page[TeamRoleAssignmentResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_role_assignments_for_team( - team_role_assignment_filter_model: TeamRoleAssignmentFilterModel = Depends( - make_dependable(TeamRoleAssignmentFilterModel) - ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[TeamRoleAssignmentResponseModel]: - """Returns a list of all roles that are assigned to a team. - - Args: - team_role_assignment_filter_model: All filter parameters including - pagination params. - - Returns: - A list of all roles that are assigned to a team. - """ - return zen_store().list_team_role_assignments( - team_role_assignment_filter_model - ) diff --git a/src/zenml/zen_server/routers/users_endpoints.py b/src/zenml/zen_server/routers/users_endpoints.py index 89d6264aad9..914a6766c14 100644 --- a/src/zenml/zen_server/routers/users_endpoints.py +++ b/src/zenml/zen_server/routers/users_endpoints.py @@ -24,7 +24,6 @@ API, DEACTIVATE, EMAIL_ANALYTICS, - ROLES, USERS, VERSION_1, ) @@ -35,8 +34,6 @@ UserFilterModel, UserRequestModel, UserResponseModel, - UserRoleAssignmentFilterModel, - UserRoleAssignmentResponseModel, UserUpdateModel, ) from zenml.models.page_model import Page @@ -365,31 +362,6 @@ def email_opt_in_response( ) -@router.get( - "/{user_name_or_id}" + ROLES, - response_model=Page[UserRoleAssignmentResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_role_assignments_for_user( - user_role_assignment_filter_model: UserRoleAssignmentFilterModel = Depends( - make_dependable(UserRoleAssignmentFilterModel) - ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[UserRoleAssignmentResponseModel]: - """Returns a list of all roles that are assigned to a user. - - Args: - user_role_assignment_filter_model: filter models for user role assignments - - Returns: - A list of all roles that are assigned to a user. - """ - return zen_store().list_user_role_assignments( - user_role_assignment_filter_model=user_role_assignment_filter_model - ) - - @current_user_router.get( "/current-user", response_model=UserResponseModel, diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 48be1544610..6a3249f4964 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -36,8 +36,6 @@ STACK_COMPONENTS, STACKS, STATISTICS, - TEAM_ROLE_ASSIGNMENTS, - USER_ROLE_ASSIGNMENTS, VERSION_1, WORKSPACES, ) @@ -87,10 +85,6 @@ StackFilterModel, StackRequestModel, StackResponseModel, - TeamRoleAssignmentFilterModel, - TeamRoleAssignmentResponseModel, - UserRoleAssignmentFilterModel, - UserRoleAssignmentResponseModel, WorkspaceFilterModel, WorkspaceRequestModel, WorkspaceResponseModel, @@ -229,66 +223,6 @@ def delete_workspace( zen_store().delete_workspace(workspace_name_or_id=workspace_name_or_id) -@router.get( - WORKSPACES + "/{workspace_name_or_id}" + USER_ROLE_ASSIGNMENTS, - response_model=Page[UserRoleAssignmentResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_user_role_assignments_for_workspace( - workspace_name_or_id: Union[str, UUID], - user_role_assignment_filter_model: UserRoleAssignmentFilterModel = Depends( - make_dependable(UserRoleAssignmentFilterModel) - ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[UserRoleAssignmentResponseModel]: - """Returns a list of all roles that are assigned to a team. - - Args: - workspace_name_or_id: Name or ID of the workspace. - user_role_assignment_filter_model: Filter model used for pagination, - sorting, filtering - - Returns: - A list of all roles that are assigned to a team. - """ - workspace = zen_store().get_workspace(workspace_name_or_id) - user_role_assignment_filter_model.workspace_id = workspace.id - return zen_store().list_user_role_assignments( - user_role_assignment_filter_model=user_role_assignment_filter_model - ) - - -@router.get( - WORKSPACES + "/{workspace_name_or_id}" + TEAM_ROLE_ASSIGNMENTS, - response_model=Page[TeamRoleAssignmentResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_team_role_assignments_for_workspace( - workspace_name_or_id: Union[str, UUID], - team_role_assignment_filter_model: TeamRoleAssignmentFilterModel = Depends( - make_dependable(TeamRoleAssignmentFilterModel) - ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[TeamRoleAssignmentResponseModel]: - """Returns a list of all roles that are assigned to a team. - - Args: - workspace_name_or_id: Name or ID of the workspace. - team_role_assignment_filter_model: Filter model used for pagination, - sorting, filtering - - Returns: - A list of all roles that are assigned to a team. - """ - workspace = zen_store().get_workspace(workspace_name_or_id) - team_role_assignment_filter_model.workspace_id = workspace.id - return zen_store().list_team_role_assignments( - team_role_assignment_filter_model=team_role_assignment_filter_model - ) - - @router.get( WORKSPACES + "/{workspace_name_or_id}" + STACKS, response_model=Page[StackResponseModel], From 718303430fa6ce0b66dd05aa14667a94329bfec8 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 16 Oct 2023 15:09:33 +0200 Subject: [PATCH 003/103] rename default stack/components --- src/zenml/models/base_models.py | 1 - src/zenml/zen_server/auth.py | 49 +++++------- src/zenml/zen_server/rbac_interface.py | 33 ++++++-- .../zen_server/routers/stacks_endpoints.py | 20 +++-- src/zenml/zen_stores/base_zen_store.py | 79 ++++++++++--------- 5 files changed, 99 insertions(+), 83 deletions(-) diff --git a/src/zenml/models/base_models.py b/src/zenml/models/base_models.py index e0448f3dd74..8ac8ab82029 100644 --- a/src/zenml/models/base_models.py +++ b/src/zenml/models/base_models.py @@ -77,7 +77,6 @@ class BaseResponseModel(BaseZenModel): ) missing_permissions: bool = False - partial: bool = False def __hash__(self) -> int: """Implementation of hash magic method. diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 35cfbb33aa0..58685b9c9c9 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -16,7 +16,8 @@ import os from contextvars import ContextVar from datetime import datetime -from typing import Any, Callable, List, Optional, Set, Union +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from urllib.parse import urlencode from uuid import UUID @@ -53,7 +54,10 @@ from zenml.models.base_models import BaseResponseModel, UserScopedResponseModel from zenml.models.user_models import UserAuthModel from zenml.zen_server.jwt import JWTToken -from zenml.zen_server.rbac_interface import RESOURCE_TYPE_MAPPING, Resource +from zenml.zen_server.rbac_interface import ( + Resource, + get_resource_type_for_model, +) from zenml.zen_server.utils import rbac, server_config, zen_store from zenml.zen_stores.base_zen_store import DEFAULT_USERNAME @@ -674,52 +678,37 @@ def authentication_provider() -> Callable[..., AuthContext]: authorize = authentication_provider() -from enum import Enum -from typing import Dict, List, Set, Tuple +def verify_read_permissions_and_dehydrate( + model: "BaseResponseModel", +) -> "BaseResponseModel": + verify_permissions_for_model(model=model, action="READ") + + return dehydrate_response_model(model=model) def dehydrate_response_model( model: "BaseResponseModel", ) -> "BaseResponseModel": dehydrated_fields = {} - did_dehydrate = False for field_name in model.__fields__.keys(): value = getattr(model, field_name) - new_value, value_dehydrated = _maybe_dehydrate_value(value) - dehydrated_fields[field_name] = new_value - did_dehydrate = did_dehydrate or value_dehydrated - - if did_dehydrate: - dehydrated_fields["partial"] = True + dehydrated_fields[field_name] = _maybe_dehydrate_value(value) return type(model).parse_obj(dehydrated_fields) -def _maybe_dehydrate_value(value: Any) -> Tuple[Any, bool]: +def _maybe_dehydrate_value(value: Any) -> Any: if isinstance(value, BaseResponseModel): if has_read_permissions_for_model(value): - dehydrated_model = dehydrate_response_model(value) - return dehydrated_model, dehydrated_model.partial + return dehydrate_response_model(value) else: - return get_403_model(value), True + return get_403_model(value) elif isinstance(value, Dict): - dict_ = {} - did_dehydrate = False - for k, v in value.items(): - dict_[k], d = _maybe_dehydrate_value(v) - did_dehydrate = did_dehydrate or d - return dict_, did_dehydrate + return {k: _maybe_dehydrate_value(v) for k, v in value.items()} elif isinstance(value, (List, Set, Tuple)): - items = [] - did_dehydrate = False - for v in value: - item, d = _maybe_dehydrate_value(v) - items.append(item) - did_dehydrate = did_dehydrate or d - type_ = type(value) - return type_(items), did_dehydrate + return type_(_maybe_dehydrate_value(v) for v in value) else: return value, False @@ -792,7 +781,7 @@ def verify_permissions_for_model( # User is the owner of the model return - resource_type = RESOURCE_TYPE_MAPPING.get(type(model)) + resource_type = get_resource_type_for_model(model) if not resource_type: # This model is not tied to any RBAC resource type and therefore doesn't # require any special permissions diff --git a/src/zenml/zen_server/rbac_interface.py b/src/zenml/zen_server/rbac_interface.py index 332d234a21b..b1cfbec68fa 100644 --- a/src/zenml/zen_server/rbac_interface.py +++ b/src/zenml/zen_server/rbac_interface.py @@ -5,7 +5,6 @@ from pydantic import BaseModel from zenml.enums import StrEnum -from zenml.models import ComponentResponseModel, StackResponseModel from zenml.models.base_models import BaseResponseModel @@ -28,10 +27,34 @@ class ResourceType(StrEnum): SECRET = "secret" -RESOURCE_TYPE_MAPPING: Dict[Type[BaseResponseModel], ResourceType] = { - StackResponseModel: ResourceType.STACK, - ComponentResponseModel: ResourceType.STACK_COMPONENT, -} +def get_resource_type_for_model( + model: "BaseResponseModel", +) -> Optional[ResourceType]: + from zenml.models import ( + ArtifactResponseModel, + CodeRepositoryResponseModel, + ComponentResponseModel, + FlavorResponseModel, + ModelResponseModel, + PipelineResponseModel, + SecretResponseModel, + ServiceConnectorResponseModel, + StackResponseModel, + ) + + mapping: Dict[Type[BaseResponseModel], ResourceType] = { + FlavorResponseModel: ResourceType.FLAVOR, + ServiceConnectorResponseModel: ResourceType.SERVICE_CONNECTOR, + ComponentResponseModel: ResourceType.STACK_COMPONENT, + StackResponseModel: ResourceType.STACK, + PipelineResponseModel: ResourceType.PIPELINE, + CodeRepositoryResponseModel: ResourceType.CODE_REPOSITORY, + SecretResponseModel: ResourceType.SECRET, + ModelResponseModel: ResourceType.MODEL, + ArtifactResponseModel: ResourceType.ARTIFACT, + } + + return mapping.get(type(model)) class Resource(BaseModel): diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index ed6ee348897..3f077817bc9 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -24,8 +24,8 @@ from zenml.zen_server.auth import ( AuthContext, authorize, - dehydrate_response_model, verify_permissions_for_model, + verify_read_permissions_and_dehydrate, ) from zenml.zen_server.exceptions import error_response from zenml.zen_server.utils import ( @@ -94,16 +94,8 @@ def get_stack( Returns: The requested stack. """ - from zenml.zen_server.rbac_interface import Action - stack = zen_store().get_stack(stack_id) - - verify_permissions_for_model( - model=stack, - action=Action.READ, - ) - - return dehydrate_response_model(stack) + return verify_read_permissions_and_dehydrate(stack) @router.put( @@ -126,6 +118,9 @@ def update_stack( Returns: The updated stack. """ + stack = zen_store().get_stack(stack_id) + verify_permissions_for_model(stack, action="update") + return zen_store().update_stack( stack_id=stack_id, stack_update=stack_update, @@ -146,4 +141,7 @@ def delete_stack( Args: stack_id: Name of the stack. """ - zen_store().delete_stack(stack_id) # aka 'delete_stack' + stack = zen_store().get_stack(stack_id) + verify_permissions_for_model(stack, action="delete") + + zen_store().delete_stack(stack_id) diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index 6cf08e0e018..ec771c2683f 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -80,8 +80,7 @@ DEFAULT_USERNAME = "default" DEFAULT_PASSWORD = "" DEFAULT_WORKSPACE_NAME = "default" -DEFAULT_STACK_NAME = "default" -DEFAULT_STACK_COMPONENT_NAME = "default" +DEFAULT_STACK_AND_COMPONENT_NAME = "default" @make_proxy_class(SecretsStoreInterface, "_secrets_store") @@ -307,16 +306,10 @@ def _initialize_database(self) -> None: default_user = self._default_user except KeyError: default_user = self._create_default_user() - try: - self._get_default_stack( - workspace_name_or_id=default_workspace.id, - user_name_or_id=default_user.id, - ) - except KeyError: - self._create_default_stack( - workspace_name_or_id=default_workspace.id, - user_name_or_id=default_user.id, - ) + self._get_or_create_default_stack( + workspace_id=default_workspace.id, + user_id=default_user.id, + ) @property def url(self) -> str: @@ -476,13 +469,13 @@ def _get_or_create_default_stack( ) -> "StackResponseModel": try: return self._get_default_stack( - workspace_name_or_id=workspace.id, - user_name_or_id=self.get_user().id, + workspace_id=workspace.id, + user_id=self.get_user().id, ) except KeyError: return self._create_default_stack( - workspace_name_or_id=workspace.id, - user_name_or_id=self.get_user().id, + workspace_id=workspace.id, + user_id=self.get_user().id, ) def _get_or_create_default_workspace(self) -> "WorkspaceResponseModel": @@ -533,8 +526,8 @@ def _trigger_event(self, event: StoreEvent, **kwargs: Any) -> None: def _create_default_stack( self, - workspace_name_or_id: Union[str, UUID], - user_name_or_id: Union[str, UUID], + workspace_id: UUID, + user_id: UUID, ) -> StackResponseModel: """Create the default stack components and stack. @@ -542,30 +535,30 @@ def _create_default_stack( store. Args: - workspace_name_or_id: Name or ID of the workspace to which the stack + workspace_id: ID of the workspace to which the stack belongs. - user_name_or_id: The name or ID of the user that owns the stack. + user_id: ID of the user that owns the stack. Returns: The model of the created default stack. """ with analytics_disabler(): - workspace = self.get_workspace( - workspace_name_or_id=workspace_name_or_id - ) - user = self.get_user(user_name_or_id=user_name_or_id) + workspace = self.get_workspace(workspace_name_or_id=workspace_id) + user = self.get_user(user_name_or_id=user_id) logger.info( f"Creating default stack for user '{user.name}' in workspace " f"{workspace.name}..." ) + name = self._get_default_stack_and_component_name(user_id=user_id) + # Register the default orchestrator orchestrator = self.create_stack_component( component=ComponentRequestModel( user=user.id, workspace=workspace.id, - name=DEFAULT_STACK_COMPONENT_NAME, + name=name, type=StackComponentType.ORCHESTRATOR, flavor="local", configuration={}, @@ -577,7 +570,7 @@ def _create_default_stack( component=ComponentRequestModel( user=user.id, workspace=workspace.id, - name=DEFAULT_STACK_COMPONENT_NAME, + name=name, type=StackComponentType.ARTIFACT_STORE, flavor="local", configuration={}, @@ -589,7 +582,7 @@ def _create_default_stack( } # Register the default stack stack = StackRequestModel( - name=DEFAULT_STACK_NAME, + name=name, components=components, is_shared=False, workspace=workspace.id, @@ -597,16 +590,27 @@ def _create_default_stack( ) return self.create_stack(stack=stack) + def _get_default_stack_and_component_name(self, user_id: UUID) -> str: + """Get the name for the default stack and its components. + + Args: + user_id: ID of the user to which the default stack belongs. + + Returns: + The default stack/component name. + """ + return f"{DEFAULT_STACK_AND_COMPONENT_NAME}-{user_id}" + def _get_default_stack( self, - workspace_name_or_id: Union[str, UUID], - user_name_or_id: Union[str, UUID], + workspace_id: UUID, + user_id: UUID, ) -> StackResponseModel: """Get the default stack for a user in a workspace. Args: - workspace_name_or_id: Name or ID of the workspace. - user_name_or_id: Name or ID of the user. + workspace_id: ID of the workspace. + user_id: ID of the user. Returns: The default stack in the workspace owned by the supplied user. @@ -614,17 +618,20 @@ def _get_default_stack( Raises: KeyError: if the workspace or default stack doesn't exist. """ + stack_name = self._get_default_stack_and_component_name( + user_id=user_id + ) default_stacks = self.list_stacks( StackFilterModel( - workspace_id=workspace_name_or_id, - user_id=user_name_or_id, - name=DEFAULT_STACK_NAME, + workspace_id=workspace_id, + user_id=user_id, + name=stack_name, ) ) if default_stacks.total == 0: raise KeyError( - f"No default stack found for user {str(user_name_or_id)} in " - f"workspace {str(workspace_name_or_id)}" + f"No default stack found for user {str(user_id)} in " + f"workspace {str(workspace_id)}" ) return default_stacks.items[0] From 5baec80707a81c9795d420591e8ae9c76d525198 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 16 Oct 2023 16:13:24 +0200 Subject: [PATCH 004/103] Start removing sharing --- .../artifact_stores/local_artifact_store.py | 3 - src/zenml/cli/__init__.py | 23 +- src/zenml/cli/service_connectors.py | 54 +--- src/zenml/cli/stack.py | 76 +---- src/zenml/cli/stack_components.py | 82 ------ src/zenml/cli/utils.py | 39 +-- src/zenml/client.py | 123 +------- .../base_container_registry.py | 3 - .../flavors/feast_feature_store_flavor.py | 3 - ...reat_expectations_data_validator_flavor.py | 3 - .../flavors/kubeflow_orchestrator_flavor.py | 3 - .../orchestrators/kubeflow_orchestrator.py | 2 +- .../flavors/kubernetes_orchestrator_flavor.py | 3 - .../orchestrators/manifest_utils.py | 2 +- .../mlflow_experiment_tracker_flavor.py | 3 - .../flavors/mlflow_model_deployer_flavor.py | 3 - .../skypilot_orchestrator_base_vm_config.py | 3 - .../flavors/tekton_orchestrator_flavor.py | 3 - .../orchestrators/tekton_orchestrator.py | 2 +- src/zenml/models/base_models.py | 51 ---- src/zenml/models/component_models.py | 19 +- src/zenml/models/filter_models.py | 52 ---- src/zenml/models/schedule_model.py | 4 +- src/zenml/models/service_connector_models.py | 20 +- src/zenml/models/stack_models.py | 17 +- .../orchestrators/local/local_orchestrator.py | 3 - .../local_docker/local_docker_orchestrator.py | 3 - .../local/local_secrets_manager.py | 3 - .../service_connectors/service_connector.py | 6 - src/zenml/stack/stack_component.py | 3 - src/zenml/utils/mlstacks_utils.py | 1 - src/zenml/zen_stores/base_zen_store.py | 28 +- .../7500f434b71c_remove_shared_columns.py | 51 ++++ src/zenml/zen_stores/schemas/base_schemas.py | 6 - .../zen_stores/schemas/component_schemas.py | 5 +- .../schemas/service_connector_schemas.py | 6 +- src/zenml/zen_stores/schemas/stack_schemas.py | 5 +- src/zenml/zen_stores/sql_zen_store.py | 276 ++++-------------- 38 files changed, 189 insertions(+), 803 deletions(-) create mode 100644 src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py diff --git a/src/zenml/artifact_stores/local_artifact_store.py b/src/zenml/artifact_stores/local_artifact_store.py index 7608d3d4800..3d600c30273 100644 --- a/src/zenml/artifact_stores/local_artifact_store.py +++ b/src/zenml/artifact_stores/local_artifact_store.py @@ -75,9 +75,6 @@ def ensure_path_local(cls, path: str) -> str: def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/cli/__init__.py b/src/zenml/cli/__init__.py index 79ab62cf770..107ab13f003 100644 --- a/src/zenml/cli/__init__.py +++ b/src/zenml/cli/__init__.py @@ -182,12 +182,7 @@ ``` For fields marked as being of type `BOOL`, you can use the 'True' or 'False' -values to filter the output. For example, to find all orchestrators that are -currently shared, you would type: - -```shell -zenml orchestrator list --is_shared="True" -``` +values to filter the output. Finally, for fields marked as being of type `DATETIME`, you can pass in datetime values in the `%Y-%m-%d %H:%M:%S` format. These can be combined with the `gte`, @@ -228,8 +223,6 @@ zenml artifact-store list ``` -If you wish to update/share - If you wish to delete a particular artifact store, pass the name of the artifact store into the CLI with the following command: @@ -737,20 +730,6 @@ zenml stack register STACK_NAME ... --set ``` -If you want to share the stack and all of its components with everyone using -the same ZenML deployment, simply pass along the `--share` flag. - -```bash -zenml stack register STACK_NAME ... --share -``` - -Even if you haven't done so at creation time of the stack, you can always -decide to do so at a later stage. - -```bash -zenml stack share STACK_NAME -``` - To list the stacks that you have registered within your current ZenML workspace, type: diff --git a/src/zenml/cli/service_connectors.py b/src/zenml/cli/service_connectors.py index e57109a4dd1..3991fac9c52 100644 --- a/src/zenml/cli/service_connectors.py +++ b/src/zenml/cli/service_connectors.py @@ -293,13 +293,13 @@ def prompt_expiration_time( Non-interactive examples: -- register a shared, multi-purpose AWS service connector capable of accessing +- register a multi-purpose AWS service connector capable of accessing any of the resource types that it supports (e.g. S3 buckets, EKS Kubernetes clusters) using auto-configured credentials (i.e. extracted from the environment variables or AWS CLI configuration files): $ zenml service-connector register aws-auto-multi --description \\ -"Multi-purpose AWS connector" --type aws --share --auto-configure \\ +"Multi-purpose AWS connector" --type aws --auto-configure \\ --label auto=true --label purpose=multi - register a Docker service connector providing access to a single DockerHub @@ -382,14 +382,6 @@ def prompt_expiration_time( "-l key1=value1 and can be used multiple times.", multiple=True, ) -@click.option( - "--share", - "share", - is_flag=True, - default=False, - help="Share this service connector with other users.", - type=click.BOOL, -) @click.option( "--no-verify", "no_verify", @@ -444,7 +436,6 @@ def register_service_connector( resource_id: Optional[str] = None, auth_method: Optional[str] = None, expiration_seconds: Optional[int] = None, - share: bool = False, no_verify: bool = False, labels: Optional[List[str]] = None, interactive: bool = False, @@ -464,7 +455,6 @@ def register_service_connector( auth_method: The authentication method to use. expiration_seconds: The duration, in seconds, that the temporary credentials generated by this connector should remain valid. - share: Share the service connector with other users. no_verify: Do not verify the service connector before registering. labels: Labels to be associated with the service connector. @@ -603,7 +593,6 @@ def register_service_connector( description=description or "", connector_type=connector_type, resource_type=resource_type, - is_shared=share, auto_configure=True, verify=True, register=False, @@ -760,7 +749,6 @@ def register_service_connector( resource_type=resource_type, configuration=config_dict, expiration_seconds=expiration_seconds, - is_shared=share, auto_configure=False, verify=True, register=False, @@ -834,7 +822,6 @@ def register_service_connector( description=description or "", expiration_seconds=expiration_seconds, expires_at=expires_at, - is_shared=share, labels=parsed_labels, verify=not no_verify, auto_configure=auto_configure, @@ -1532,43 +1519,6 @@ def update_service_connector( ) -@service_connector.command( - "share", - help="""Share a service connector with other users. -""", -) -@click.argument( - "name_id_or_prefix", - type=str, - required=False, -) -def share_service_connector_command( - name_id_or_prefix: str, -) -> None: - """Shares a service connector. - - Args: - name_id_or_prefix: The name or id of the service connector to share. - """ - client = Client() - - with console.status( - f"Updating service connector '{name_id_or_prefix}'...\n" - ): - try: - client.update_service_connector( - name_id_or_prefix=name_id_or_prefix, - is_shared=True, - verify=False, - ) - except (KeyError, IllegalOperationError) as err: - cli_utils.error(str(err)) - - cli_utils.declare( - "Successfully shared service connector " f"`{name_id_or_prefix}`." - ) - - @service_connector.command( "delete", help="""Delete a service connector. diff --git a/src/zenml/cli/stack.py b/src/zenml/cli/stack.py index 9b7c8f21152..43f6a132339 100644 --- a/src/zenml/cli/stack.py +++ b/src/zenml/cli/stack.py @@ -48,7 +48,6 @@ from zenml.exceptions import ( IllegalOperationError, ProvisioningError, - StackExistsError, ) from zenml.io.fileio import rmtree from zenml.logger import get_logger @@ -196,13 +195,6 @@ def stack() -> None: help="Immediately set this stack as active.", type=click.BOOL, ) -@click.option( - "--share", - "share", - is_flag=True, - help="Use this flag to share this stack with other users.", - type=click.BOOL, -) def register_stack( stack_name: str, artifact_store: str, @@ -219,7 +211,6 @@ def register_stack( data_validator: Optional[str] = None, image_builder: Optional[str] = None, set_stack: bool = False, - share: bool = False, ) -> None: """Register a stack. @@ -239,7 +230,6 @@ def register_stack( data_validator: Name of the data validator for this stack. image_builder: Name of the new image builder for this stack. set_stack: Immediately set this stack as active. - share: Share the stack with other users. """ with console.status(f"Registering stack '{stack_name}'...\n"): client = Client() @@ -276,15 +266,10 @@ def register_stack( StackComponentType.CONTAINER_REGISTRY ] = container_registry - # click<8.0.0 gives flags a default of None - if share is None: - share = False - try: created_stack = client.create_stack( name=stack_name, components=components, - is_shared=share, ) except (KeyError, IllegalOperationError) as err: cli_utils.error(str(err)) @@ -499,52 +484,6 @@ def update_stack( print_model_url(get_stack_url(updated_stack)) -@stack.command( - "share", - context_settings=dict(ignore_unknown_options=True), - help="Share a stack and all its components.", -) -@click.argument("stack_name_or_id", type=str, required=False) -@click.option( - "--recursive", - "-r", - "recursive", - is_flag=True, - help="Recursively also share all stack components if they are private.", -) -def share_stack( - stack_name_or_id: Optional[str], recursive: bool = False -) -> None: - """Share a stack with your team. - - Args: - stack_name_or_id: Name or id of the stack to share. - recursive: Recursively also share all components - """ - client = Client() - - with console.status("Sharing the stack...\n"): - try: - if recursive: - stack_to_update = client.get_stack( - name_id_or_prefix=stack_name_or_id - ) - for c_type, components in stack_to_update.components.items(): - for component in components: - client.update_stack_component( - name_id_or_prefix=component.id, - component_type=c_type, - is_shared=True, - ) - updated_stack = client.update_stack( - name_id_or_prefix=stack_name_or_id, - is_shared=True, - ) - except (KeyError, IllegalOperationError, StackExistsError) as err: - cli_utils.error(str(err)) - cli_utils.declare(f"Stack `{updated_stack.name}` successfully shared!") - - @stack.command( "remove-component", context_settings=dict(ignore_unknown_options=True), @@ -1131,7 +1070,7 @@ def import_stack( component_ids[component_type] = component_id imported_stack = Client().create_stack( - name=stack_name, components=component_ids, is_shared=False + name=stack_name, components=component_ids ) print_model_url(get_stack_url(imported_stack)) @@ -1140,22 +1079,12 @@ def import_stack( @stack.command("copy", help="Copy a stack to a new stack name.") @click.argument("source_stack_name_or_id", type=str, required=True) @click.argument("target_stack", type=str, required=True) -@click.option( - "--share", - "share", - is_flag=True, - help="Use this flag to share this stack with other users.", - type=click.BOOL, -) -def copy_stack( - source_stack_name_or_id: str, target_stack: str, share: bool = False -) -> None: +def copy_stack(source_stack_name_or_id: str, target_stack: str) -> None: """Copy a stack. Args: source_stack_name_or_id: The name or id of the stack to copy. target_stack: Name of the copied stack. - share: Share the stack with other users. """ client = Client() @@ -1176,7 +1105,6 @@ def copy_stack( copied_stack = client.create_stack( name=target_stack, components=component_mapping, - is_shared=share, ) print_model_url(get_stack_url(copied_stack)) diff --git a/src/zenml/cli/stack_components.py b/src/zenml/cli/stack_components.py index 65596227e02..538ec9e8f3a 100644 --- a/src/zenml/cli/stack_components.py +++ b/src/zenml/cli/stack_components.py @@ -230,20 +230,6 @@ def generate_stack_component_register_command( "-l key1=value1 -l key2=value2.", multiple=True, ) - @click.option( - "--share", - "share", - is_flag=True, - help="Use this flag to share this stack component with other users.", - type=click.BOOL, - ) - @click.option( - "--connector", - "-c", - "connector", - help="Use this flag to connect this stack component to a service connector.", - type=str, - ) @click.option( "--connector", "-c", @@ -265,7 +251,6 @@ def generate_stack_component_register_command( def register_stack_component_command( name: str, flavor: str, - share: bool, args: List[str], labels: Optional[List[str]] = None, connector: Optional[str] = None, @@ -276,7 +261,6 @@ def register_stack_component_command( Args: name: Name of the component to register. flavor: Flavor of the component to register. - share: Share the stack with other users. args: Additional arguments to pass to the component. labels: Labels to be associated with the component. connector: Name of the service connector to connect the component to. @@ -296,10 +280,6 @@ def register_stack_component_command( parsed_labels = cli_utils.get_parsed_labels(labels) - # click<8.0.0 gives flags a default of None - if share is None: - share = False - if connector: try: client.get_service_connector(connector) @@ -316,7 +296,6 @@ def register_stack_component_command( component_type=component_type, configuration=parsed_args, labels=parsed_labels, - is_shared=share, ) cli_utils.declare( @@ -414,57 +393,6 @@ def update_stack_component_command( return update_stack_component_command -def generate_stack_component_share_command( - component_type: StackComponentType, -) -> Callable[[str], None]: - """Generates an `share` command for the specific stack component type. - - Args: - component_type: Type of the component to generate the command for. - - Returns: - A function that can be used as a `click` command. - """ - display_name = _component_display_name(component_type) - - @click.argument( - "name_id_or_prefix", - type=str, - required=False, - ) - def share_stack_component_command( - name_id_or_prefix: str, - ) -> None: - """Shares a stack component. - - Args: - name_id_or_prefix: The name or id of the stack component to update. - """ - if component_type == StackComponentType.SECRETS_MANAGER: - warn_deprecated_secrets_manager() - - client = Client() - - with console.status( - f"Updating {display_name} '{name_id_or_prefix}'...\n" - ): - try: - client.update_stack_component( - name_id_or_prefix=name_id_or_prefix, - component_type=component_type, - is_shared=True, - ) - except (KeyError, IllegalOperationError) as err: - cli_utils.error(str(err)) - - cli_utils.declare( - f"Successfully shared {display_name} " - f"`{name_id_or_prefix}`." - ) - - return share_stack_component_command - - def generate_stack_component_remove_attribute_command( component_type: StackComponentType, ) -> Callable[[str, List[str]], None]: @@ -676,7 +604,6 @@ def copy_stack_component_command( component_type=component_to_copy.type, configuration=component_to_copy.configuration, labels=component_to_copy.labels, - is_shared=component_to_copy.is_shared, component_spec_path=component_to_copy.component_spec_path, ) print_model_url(get_component_url(copied_component)) @@ -1812,15 +1739,6 @@ def command_group() -> None: help=f"Update a registered {singular_display_name}.", )(update_command) - # zenml stack-component share - share_command = generate_stack_component_share_command(component_type) - context_settings = {"ignore_unknown_options": True} - command_group.command( - "share", - context_settings=context_settings, - help=f"Share a registered {singular_display_name}.", - )(share_command) - # zenml stack-component remove-attribute remove_attribute_command = ( generate_stack_component_remove_attribute_command(component_type) diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index 05bfda802a7..8240cc9e904 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -497,8 +497,7 @@ def print_stack_configuration( console.print(rich_table) declare( f"Stack '{stack.name}' with id '{stack.id}' is " - f"{f'owned by user {stack.user.name} and is ' if stack.user else ''}" - f"'{'shared' if stack.is_shared else 'private'}'." + f"{f'owned by user {stack.user.name}.' if stack.user else 'unowned.'}" ) if stack.stack_spec_path: @@ -543,8 +542,7 @@ def print_stack_component_configuration( declare( f"{component.type.value.title()} '{component.name}' of flavor " f"'{component.flavor}' with id '{component.id}' is owned by " - f"user '{user_name}' and is " - f"'{'shared' if component.is_shared else 'private'}'." + f"user '{user_name}'." ) if len(component.configuration) == 0: @@ -1333,16 +1331,16 @@ def describe_pydantic_object(schema_json: Dict[str, Any]) -> None: declare(f" {prop_schema['description']}", width=80) -def get_shared_emoji(is_shared: bool) -> str: - """Returns the emoji for whether a stack is shared or not. +def get_boolean_emoji(value: bool) -> str: + """Returns the emoji for displaying a boolean. Args: - is_shared: Whether the stack is shared or not. + value: The boolean value to display Returns: - The emoji for whether the stack is shared or not. + The emoji for the boolean """ - return ":white_heavy_check_mark:" if is_shared else ":heavy_minus_sign:" + return ":white_heavy_check_mark:" if value else ":heavy_minus_sign:" def replace_emojis(text: str) -> str: @@ -1403,7 +1401,6 @@ def print_stacks_table( "ACTIVE": ":point_right:" if is_active else "", "STACK NAME": stack.name, "STACK ID": stack.id, - "SHARED": get_shared_emoji(stack.is_shared), "OWNER": user_name, **{ component_type.upper(): components[0].name @@ -1466,7 +1463,6 @@ def print_components_table( "NAME": component.name, "COMPONENT ID": component.id, "FLAVOR": component.flavor, - "SHARED": get_shared_emoji(component.is_shared), "OWNER": f"{component.user.name if component.user else 'DELETED!'}", } configurations.append(component_config) @@ -1575,7 +1571,6 @@ def print_service_connectors_table( "TYPE": connector.emojified_connector_type, "RESOURCE TYPES": "\n".join(connector.emojified_resource_types), "RESOURCE NAME": resource_name, - "SHARED": get_shared_emoji(connector.is_shared), "OWNER": f"{connector.user.name if connector.user else 'DELETED!'}", "EXPIRES IN": expires_in( connector.expires_at, ":name_badge: Expired!" @@ -1688,19 +1683,15 @@ def print_service_connector_configuration( declare( f"Service connector '{connector.name}' of type " f"'{connector.type}' with id '{connector.id}' is owned by " - f"user '{user_name}' and is " - f"'{'shared' if connector.is_shared else 'private'}'." + f"user '{user_name}'." ) else: declare( f"Service connector '{connector.name}' of type " - f"'{connector.type}' is " - f"'{'shared' if connector.is_shared else 'private'}'." + f"'{connector.type}'." ) - title_ = ( - f"'{connector.name}' {connector.type} Service Connector " "Details" - ) + title_ = f"'{connector.name}' {connector.type} Service Connector Details" if active_status: title_ += " (ACTIVE)" @@ -1734,7 +1725,6 @@ def print_service_connector_configuration( else "N/A", "OWNER": user_name, "WORKSPACE": connector.workspace.name, - "SHARED": get_shared_emoji(connector.is_shared), "CREATED_AT": connector.created, "UPDATED_AT": connector.updated, } @@ -1751,7 +1741,6 @@ def print_service_connector_configuration( ) if connector.expires_at else "N/A", - "SHARED": get_shared_emoji(connector.is_shared), } for item in properties.items(): @@ -1831,8 +1820,8 @@ def print_service_connector_types_table( connector_type.emojified_resource_types ), "AUTH METHODS": "\n".join(supported_auth_methods), - "LOCAL": get_shared_emoji(connector_type.local), - "REMOTE": get_shared_emoji(connector_type.remote), + "LOCAL": get_boolean_emoji(connector_type.local), + "REMOTE": get_boolean_emoji(connector_type.remote), } configurations.append(connector_type_config) print_table(configurations) @@ -2091,7 +2080,6 @@ def print_debug_stack() -> None: declare("\nCURRENT STACK\n", bold=True) console.print(f"Name: {stack.name}") console.print(f"ID: {str(stack.id)}") - console.print(f"Shared: {'Yes' if stack.is_shared else 'No'}") if stack.user and stack.user.name and stack.user.id: # mypy check console.print(f"User: {stack.user.name} / {str(stack.user.id)}") console.print( @@ -2110,9 +2098,6 @@ def print_debug_stack() -> None: console.print(f"Type: {component.type.value}") console.print(f"Flavor: {component.flavor}") console.print(f"Configuration: {_scrub_secret(component.config)}") - console.print( - f"Shared: {'Yes' if component_response.is_shared else 'No'}" - ) if ( component_response.user and component_response.user.name diff --git a/src/zenml/client.py b/src/zenml/client.py index ef5ffdda761..540a0dc189e 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -1026,6 +1026,13 @@ def get_stack( Returns: The stack. """ + if name_id_or_prefix == "default": + name_id_or_prefix = ( + self.zen_store._get_default_stack_and_component_name( + user_id=self.active_user.id + ) + ) + if name_id_or_prefix is not None: return self._get_entity_by_id_or_name_or_prefix( get_method=self.zen_store.get_stack, @@ -1040,7 +1047,6 @@ def create_stack( self, name: str, components: Mapping[StackComponentType, Union[str, UUID]], - is_shared: bool = False, stack_spec_file: Optional[str] = None, ) -> "StackResponseModel": """Registers a stack and its components. @@ -1048,15 +1054,10 @@ def create_stack( Args: name: The name of the stack to register. components: dictionary which maps component types to component names - is_shared: boolean to decide whether the stack is shared stack_spec_file: path to the stack spec file Returns: The model of the registered stack. - - Raises: - ValueError: If the stack contains private components and is - attempted to be registered as shared. """ stack_components = {} @@ -1072,22 +1073,9 @@ def create_stack( ) stack_components[c_type] = [component.id] - # Raise an error if private components are used in a shared stack. - if is_shared and not component.is_shared: - raise ValueError( - f"You attempted to include the private {c_type} " - f"'{component.name}' in a shared stack. This is not " - f"supported. You can either share the {c_type} with the " - f"following command:\n" - f"`zenml {c_type.replace('_', '-')} share`{component.id}`\n" - f"or create the stack privately and then share it and all " - f"of its components using:\n`zenml stack share {name} -r`" - ) - stack = StackRequestModel( name=name, components=stack_components, - is_shared=is_shared, stack_spec_path=stack_spec_file, workspace=self.active_workspace.id, user=self.active_user.id, @@ -1101,7 +1089,6 @@ def update_stack( self, name_id_or_prefix: Optional[Union[UUID, str]] = None, name: Optional[str] = None, - is_shared: Optional[bool] = None, stack_spec_file: Optional[str] = None, description: Optional[str] = None, component_updates: Optional[ @@ -1113,7 +1100,6 @@ def update_stack( Args: name_id_or_prefix: The name, id or prefix of the stack to update. name: the new name of the stack. - is_shared: the new shared status of the stack. stack_spec_file: path to the stack spec file description: the new description of the stack. component_updates: dictionary which maps stack component types to @@ -1123,8 +1109,6 @@ def update_stack( The model of the updated stack. Raises: - ValueError: If the stack contains private components and is - attempted to be shared. EntityExistsError: If the stack name is already taken. """ # First, get the stack @@ -1139,10 +1123,8 @@ def update_stack( stack_spec_path=stack_spec_file, ) - shared_status = is_shared or stack.is_shared - if name: - if self.list_stacks(name=name, is_shared=shared_status): + if self.list_stacks(name=name): raise EntityExistsError( "There are already existing stacks with the name " f"'{name}'." @@ -1150,31 +1132,6 @@ def update_stack( update_model.name = name - if is_shared: - current_name = update_model.name or stack.name - if self.list_stacks(name=current_name, is_shared=True): - raise EntityExistsError( - "There are already existing shared stacks with the name " - f"'{current_name}'." - ) - - for component_type, components in stack.components.items(): - for c in components: - if not c.is_shared: - raise ValueError( - f"A Stack can only be shared when all its " - f"components are also shared. Component " - f"'{component_type}:{c.name}' is not shared. Set " - f"the {component_type} to shared like this and " - f"then try re-sharing your stack:\n" - f"`zenml {component_type.replace('_', '-')} " - f"share {c.id}`\nAlternatively, you can rerun " - f"your command with `-r` to recursively " - f"share all components within the stack." - ) - - update_model.is_shared = is_shared - if description: update_model.description = description @@ -1192,22 +1149,6 @@ def update_stack( for component_id in component_id_list ] - # If the stack is shared, ensure all new components are also shared - if shared_status: - for component_list in components_dict.values(): - for component in component_list: - if not component.is_shared: - raise ValueError( - "Private components cannot be added to a " - "shared stack. Component " - f"'{component.type}:{component.name}' is not " - "shared. Set the component to shared like " - "this and then try adding it to your stack " - "again:\n" - f"`zenml {component.type.replace('_', '-')} " - f"share {component.id}`." - ) - update_model.components = { c_type: [c.id for c in c_list] for c_type, c_list in components_dict.items() @@ -1295,7 +1236,6 @@ def list_stacks( id: Optional[Union[UUID, str]] = None, created: Optional[datetime] = None, updated: Optional[datetime] = None, - is_shared: Optional[bool] = None, name: Optional[str] = None, description: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, @@ -1317,7 +1257,6 @@ def list_stacks( user_id: The id of the user to filter by. component_id: The id of the component to filter by. name: The name of the stack to filter by. - is_shared: The shared status of the stack to filter by. Returns: A page of stacks. @@ -1331,7 +1270,6 @@ def list_stacks( user_id=user_id, component_id=component_id, name=name, - is_shared=is_shared, description=description, id=id, created=created, @@ -1463,6 +1401,11 @@ def get_stack_component( KeyError: If no name_id_or_prefix is provided and no such component is part of the active stack. """ + if name_id_or_prefix == "default": + self.zen_store._get_default_stack_and_component_name( + user_id=self.active_user.id + ) + # If no `name_id_or_prefix` provided, try to get the active component. if not name_id_or_prefix: components = self.active_stack_model.components.get( @@ -1514,7 +1457,6 @@ def list_stack_components( id: Optional[Union[UUID, str]] = None, created: Optional[datetime] = None, updated: Optional[datetime] = None, - is_shared: Optional[bool] = None, name: Optional[str] = None, flavor: Optional[str] = None, type: Optional[str] = None, @@ -1538,7 +1480,6 @@ def list_stack_components( user_id: The id of the user to filter by. connector_id: The id of the connector to filter by. name: The name of the component to filter by. - is_shared: The shared status of the component to filter by. Returns: A page of stack components. @@ -1552,7 +1493,6 @@ def list_stack_components( user_id=user_id, connector_id=connector_id, name=name, - is_shared=is_shared, flavor=flavor, type=type, id=id, @@ -1573,7 +1513,6 @@ def create_stack_component( configuration: Dict[str, str], component_spec_path: Optional[str] = None, labels: Optional[Dict[str, Any]] = None, - is_shared: bool = False, ) -> "ComponentResponseModel": """Registers a stack component. @@ -1584,7 +1523,6 @@ def create_stack_component( component_type: The type of the stack component. configuration: The configuration of the stack component. labels: The labels of the stack component. - is_shared: Whether the stack component is shared or not. Returns: The model of the registered component. @@ -1613,7 +1551,6 @@ def create_stack_component( flavor=flavor, component_spec_path=component_spec_path, configuration=configuration, - is_shared=is_shared, user=self.active_user.id, workspace=self.active_workspace.id, labels=labels, @@ -1632,7 +1569,6 @@ def update_stack_component( component_spec_path: Optional[str] = None, configuration: Optional[Dict[str, Any]] = None, labels: Optional[Dict[str, Any]] = None, - is_shared: Optional[bool] = None, connector_id: Optional[UUID] = None, connector_resource_id: Optional[str] = None, ) -> "ComponentResponseModel": @@ -1646,7 +1582,6 @@ def update_stack_component( component_spec_path: The new path to the stack spec file. configuration: The new configuration of the stack component. labels: The new labels of the stack component. - is_shared: The new shared status of the stack component. connector_id: The new connector id of the stack component. connector_resource_id: The new connector resource id of the stack component. @@ -1671,33 +1606,17 @@ def update_stack_component( ) if name is not None: - shared_status = is_shared or component.is_shared - existing_components = self.list_stack_components( name=name, - is_shared=shared_status, type=component_type, ) if existing_components.total > 0: raise EntityExistsError( - f"There are already existing " - f"{'shared' if shared_status else 'unshared'} components " - f"with the name '{name}'." + f"There are already existing components with the " + f"name '{name}'." ) update_model.name = name - if is_shared is not None: - current_name = update_model.name or component.name - existing_components = self.list_stack_components( - name=current_name, is_shared=is_shared, type=component_type - ) - if any(e.id != component.id for e in existing_components.items): - raise EntityExistsError( - f"There are already existing shared components with " - f"the name '{current_name}'" - ) - update_model.is_shared = is_shared - if configuration is not None: existing_configuration = component.configuration existing_configuration.update(configuration) @@ -3652,7 +3571,6 @@ def list_service_connectors( id: Optional[Union[UUID, str]] = None, created: Optional[datetime] = None, updated: Optional[datetime] = None, - is_shared: Optional[bool] = None, name: Optional[str] = None, connector_type: Optional[str] = None, auth_method: Optional[str] = None, @@ -3682,7 +3600,6 @@ def list_service_connectors( workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. name: The name of the service connector to filter by. - is_shared: The shared status of the service connector to filter by. labels: The labels of the service connector to filter by. secret_id: Filter by the id of the secret that is referenced by the service connector. @@ -3698,7 +3615,6 @@ def list_service_connectors( workspace_id=workspace_id or self.active_workspace.id, user_id=user_id, name=name, - is_shared=is_shared, connector_type=connector_type, auth_method=auth_method, resource_type=resource_type, @@ -3725,7 +3641,6 @@ def create_service_connector( description: str = "", expiration_seconds: Optional[int] = None, expires_at: Optional[datetime] = None, - is_shared: bool = False, labels: Optional[Dict[str, str]] = None, auto_configure: bool = False, verify: bool = True, @@ -3754,7 +3669,6 @@ def create_service_connector( expiration_seconds: The expiration time of the service connector. expires_at: The expiration time of the service connector credentials. - is_shared: Whether the service connector is shared or not. labels: The labels of the service connector. auto_configure: Whether to automatically configure the service connector from the local environment. @@ -3828,7 +3742,6 @@ def create_service_connector( user=self.active_user.id, workspace=self.active_workspace.id, description=description or "", - is_shared=is_shared, labels=labels, ) @@ -3873,7 +3786,6 @@ def create_service_connector( auth_method=auth_method, expiration_seconds=expiration_seconds, expires_at=expires_at, - is_shared=is_shared, user=self.active_user.id, workspace=self.active_workspace.id, labels=labels or {}, @@ -3948,7 +3860,6 @@ def update_service_connector( resource_id: Optional[str] = None, description: Optional[str] = None, expiration_seconds: Optional[int] = None, - is_shared: Optional[bool] = None, labels: Optional[Dict[str, Optional[str]]] = None, verify: bool = True, list_resources: bool = True, @@ -3992,7 +3903,6 @@ def update_service_connector( description: The description of the service connector. expiration_seconds: The expiration time of the service connector. If set to 0, the existing expiration time will be removed. - is_shared: Whether the service connector is shared or not. labels: The service connector to update or remove. If a label value is set to None, the label will be removed. verify: Whether to verify that the service connector configuration @@ -4056,9 +3966,6 @@ def update_service_connector( description=description or connector_model.description, auth_method=auth_method or connector_model.auth_method, expiration_seconds=expiration_seconds, - is_shared=is_shared - if is_shared is not None - else connector_model.is_shared, user=self.active_user.id, workspace=self.active_workspace.id, ) diff --git a/src/zenml/container_registries/base_container_registry.py b/src/zenml/container_registries/base_container_registry.py index f31c5570649..12539d2643b 100644 --- a/src/zenml/container_registries/base_container_registry.py +++ b/src/zenml/container_registries/base_container_registry.py @@ -58,9 +58,6 @@ def strip_trailing_slash(cls, uri: str) -> str: def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/integrations/feast/flavors/feast_feature_store_flavor.py b/src/zenml/integrations/feast/flavors/feast_feature_store_flavor.py index 4b24da8df84..6dd25d8464b 100644 --- a/src/zenml/integrations/feast/flavors/feast_feature_store_flavor.py +++ b/src/zenml/integrations/feast/flavors/feast_feature_store_flavor.py @@ -36,9 +36,6 @@ class FeastFeatureStoreConfig(BaseFeatureStoreConfig): def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py b/src/zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py index ea6daa47177..95f78d61159 100644 --- a/src/zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py +++ b/src/zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py @@ -82,9 +82,6 @@ def _ensure_valid_context_root_dir( def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py b/src/zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py index 5c3aa251456..8dfa6f965ca 100644 --- a/src/zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py +++ b/src/zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py @@ -230,9 +230,6 @@ def is_remote(self) -> bool: def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py b/src/zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py index a00dc9a97b0..a2d4a7e06f9 100644 --- a/src/zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py +++ b/src/zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py @@ -401,7 +401,7 @@ def _configure_container_op( pass else: # Run KFP containers in the context of the local UID/GID - # to ensure that the artifact and metadata stores can be shared + # to ensure that the local stores can be shared # with the local pipeline runs. container_op.container.security_context = ( k8s_client.V1SecurityContext( diff --git a/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py b/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py index 1c853c620ee..6cd92af15af 100644 --- a/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +++ b/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py @@ -97,9 +97,6 @@ def is_remote(self) -> bool: def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py b/src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py index d89b3ef603a..45aa3736230 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py +++ b/src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py @@ -77,7 +77,7 @@ def add_local_stores_mount( pass else: # Run KFP containers in the context of the local UID/GID - # to ensure that the artifact and metadata stores can be shared + # to ensure that the local stores can be shared # with the local pipeline runs. pod_spec.security_context = k8s_client.V1SecurityContext( run_as_user=os.getuid(), diff --git a/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py b/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py index 1e6b9079f58..67bc693d444 100644 --- a/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +++ b/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py @@ -169,9 +169,6 @@ def _ensure_authentication_if_necessary( def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py b/src/zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py index c0dd6e50dda..7de40370465 100644 --- a/src/zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py +++ b/src/zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py @@ -39,9 +39,6 @@ class MLFlowModelDeployerConfig(BaseModelDeployerConfig): def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py b/src/zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py index 10cc1a36415..daca191d7fe 100644 --- a/src/zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py +++ b/src/zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py @@ -117,9 +117,6 @@ class SkypilotBaseOrchestratorConfig( # type: ignore[misc] # https://github.com def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py b/src/zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py index 01896e2d1a3..9cd6bdc3501 100644 --- a/src/zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py +++ b/src/zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py @@ -115,9 +115,6 @@ def is_remote(self) -> bool: def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/integrations/tekton/orchestrators/tekton_orchestrator.py b/src/zenml/integrations/tekton/orchestrators/tekton_orchestrator.py index 101d7031411..da5a34c24b3 100644 --- a/src/zenml/integrations/tekton/orchestrators/tekton_orchestrator.py +++ b/src/zenml/integrations/tekton/orchestrators/tekton_orchestrator.py @@ -296,7 +296,7 @@ def _configure_container_op( pass else: # Run KFP containers in the context of the local UID/GID - # to ensure that the artifact and metadata stores can be shared + # to ensure that the local stores can be shared # with the local pipeline runs. container_op.container.security_context = ( k8s_client.V1SecurityContext( diff --git a/src/zenml/models/base_models.py b/src/zenml/models/base_models.py index 8ac8ab82029..91b952ec632 100644 --- a/src/zenml/models/base_models.py +++ b/src/zenml/models/base_models.py @@ -154,31 +154,6 @@ def get_analytics_metadata(self) -> Dict[str, Any]: return metadata -class ShareableResponseModel(WorkspaceScopedResponseModel): - """Base shareable workspace-scoped domain model. - - Used as a base class for all domain models that are workspace-scoped and are - shareable. - """ - - is_shared: bool = Field( - title=( - "Flag describing if this resource is shared with other users in " - "the same workspace." - ), - ) - - def get_analytics_metadata(self) -> Dict[str, Any]: - """Fetches the analytics metadata for workspace scoped models. - - Returns: - The analytics metadata. - """ - metadata = super().get_analytics_metadata() - metadata["is_shared"] = self.is_shared - return metadata - - # -------------- # # REQUEST MODELS # # -------------- # @@ -231,32 +206,6 @@ def get_analytics_metadata(self) -> Dict[str, Any]: return metadata -class ShareableRequestModel(WorkspaceScopedRequestModel): - """Base shareable workspace-scoped domain model. - - Used as a base class for all domain models that are workspace-scoped and are - shareable. - """ - - is_shared: bool = Field( - default=False, - title=( - "Flag describing if this resource is shared with other users in " - "the same workspace." - ), - ) - - def get_analytics_metadata(self) -> Dict[str, Any]: - """Fetches the analytics metadata for workspace scoped models. - - Returns: - The analytics metadata. - """ - metadata = super().get_analytics_metadata() - metadata["is_shared"] = self.is_shared - return metadata - - # ------------- # # UPDATE MODELS # # ------------- # diff --git a/src/zenml/models/component_models.py b/src/zenml/models/component_models.py index 256ed20c997..dcf24a30a72 100644 --- a/src/zenml/models/component_models.py +++ b/src/zenml/models/component_models.py @@ -30,12 +30,12 @@ from zenml.enums import StackComponentType from zenml.logger import get_logger from zenml.models.base_models import ( - ShareableRequestModel, - ShareableResponseModel, + WorkspaceScopedRequestModel, + WorkspaceScopedResponseModel, update_model, ) from zenml.models.constants import STR_FIELD_MAX_LENGTH -from zenml.models.filter_models import ShareableWorkspaceScopedFilterModel +from zenml.models.filter_models import WorkspaceScopedFilterModel from zenml.models.service_connector_models import ServiceConnectorResponseModel from zenml.utils import secret_utils @@ -92,7 +92,7 @@ class ComponentBaseModel(BaseModel): # -------- # -class ComponentResponseModel(ComponentBaseModel, ShareableResponseModel): +class ComponentResponseModel(ComponentBaseModel, WorkspaceScopedResponseModel): """Response model for stack components.""" ANALYTICS_FIELDS: ClassVar[List[str]] = ["type", "flavor"] @@ -108,7 +108,7 @@ class ComponentResponseModel(ComponentBaseModel, ShareableResponseModel): # ------ # -class ComponentFilterModel(ShareableWorkspaceScopedFilterModel): +class ComponentFilterModel(WorkspaceScopedFilterModel): """Model to enable advanced filtering of all ComponentModels. The Component Model needs additional scoping. As such the `_scope_user` @@ -118,11 +118,11 @@ class ComponentFilterModel(ShareableWorkspaceScopedFilterModel): """ FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *ShareableWorkspaceScopedFilterModel.FILTER_EXCLUDE_FIELDS, + *WorkspaceScopedFilterModel.FILTER_EXCLUDE_FIELDS, "scope_type", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *ShareableWorkspaceScopedFilterModel.CLI_EXCLUDE_FIELDS, + *WorkspaceScopedFilterModel.CLI_EXCLUDE_FIELDS, "scope_type", ] scope_type: Optional[str] = Field( @@ -130,9 +130,6 @@ class ComponentFilterModel(ShareableWorkspaceScopedFilterModel): description="The type to scope this query to.", ) - is_shared: Optional[Union[bool, str]] = Field( - default=None, description="If the stack is shared or private" - ) name: Optional[str] = Field( default=None, description="Name of the stack component", @@ -190,7 +187,7 @@ def generate_filter( # ------- # -class ComponentRequestModel(ComponentBaseModel, ShareableRequestModel): +class ComponentRequestModel(ComponentBaseModel, WorkspaceScopedRequestModel): """Request model for stack components.""" ANALYTICS_FIELDS: ClassVar[List[str]] = ["type", "flavor"] diff --git a/src/zenml/models/filter_models.py b/src/zenml/models/filter_models.py index 14dda9d14b3..6c5dbddb610 100644 --- a/src/zenml/models/filter_models.py +++ b/src/zenml/models/filter_models.py @@ -814,58 +814,6 @@ def apply_filter( return query -class ShareableWorkspaceScopedFilterModel(WorkspaceScopedFilterModel): - """Model to enable advanced scoping with workspace and user scoped shareable things.""" - - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedFilterModel.FILTER_EXCLUDE_FIELDS, - "scope_user", - ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedFilterModel.CLI_EXCLUDE_FIELDS, - "scope_user", - ] - scope_user: Optional[UUID] = Field( - default=None, - description="The user to scope this query to.", - ) - - def set_scope_user(self, user_id: UUID) -> None: - """Set the user that is performing the filtering to scope the response. - - Args: - user_id: The user ID to scope the response to. - """ - self.scope_user = user_id - - def apply_filter( - self, - query: Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"], - table: Type["AnySchema"], - ) -> Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"]: - """Applies the filter to a query. - - Args: - query: The query to which to apply the filter. - table: The query table. - - Returns: - The query with filter applied. - """ - from sqlmodel import or_ - - query = super().apply_filter(query=query, table=table) - - if self.scope_user: - scope_filter = or_( - getattr(table, "user_id") == self.scope_user, - getattr(table, "is_shared").is_(True), - ) - query = query.where(scope_filter) - - return query - - class UserScopedFilterModel(BaseFilterModel): """Model to enable advanced user-based scoping.""" diff --git a/src/zenml/models/schedule_model.py b/src/zenml/models/schedule_model.py index bc29bcb1a69..75b75c5d0bc 100644 --- a/src/zenml/models/schedule_model.py +++ b/src/zenml/models/schedule_model.py @@ -24,7 +24,7 @@ WorkspaceScopedRequestModel, WorkspaceScopedResponseModel, ) -from zenml.models.filter_models import ShareableWorkspaceScopedFilterModel +from zenml.models.filter_models import WorkspaceScopedFilterModel # ---- # # BASE # @@ -56,7 +56,7 @@ class ScheduleResponseModel(ScheduleBaseModel, WorkspaceScopedResponseModel): # ------ # -class ScheduleFilterModel(ShareableWorkspaceScopedFilterModel): +class ScheduleFilterModel(WorkspaceScopedFilterModel): """Model to enable advanced filtering of all Users.""" workspace_id: Optional[Union[UUID, str]] = Field( diff --git a/src/zenml/models/service_connector_models.py b/src/zenml/models/service_connector_models.py index 18d4c5a23b2..efeab6b218a 100644 --- a/src/zenml/models/service_connector_models.py +++ b/src/zenml/models/service_connector_models.py @@ -38,12 +38,12 @@ from zenml.logger import get_logger from zenml.models.base_models import ( - ShareableRequestModel, - ShareableResponseModel, + WorkspaceScopedRequestModel, + WorkspaceScopedResponseModel, update_model, ) from zenml.models.constants import STR_FIELD_MAX_LENGTH -from zenml.models.filter_models import ShareableWorkspaceScopedFilterModel +from zenml.models.filter_models import WorkspaceScopedFilterModel if TYPE_CHECKING: from zenml.models.component_models import ComponentBaseModel @@ -1132,7 +1132,7 @@ def from_connector_model( class ServiceConnectorResponseModel( - ServiceConnectorBaseModel, ShareableResponseModel + ServiceConnectorBaseModel, WorkspaceScopedResponseModel ): """Response model for service connectors.""" @@ -1169,18 +1169,18 @@ def get_analytics_metadata(self) -> Dict[str, Any]: # ------ # -class ServiceConnectorFilterModel(ShareableWorkspaceScopedFilterModel): +class ServiceConnectorFilterModel(WorkspaceScopedFilterModel): """Model to enable advanced filtering of service connectors.""" FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *ShareableWorkspaceScopedFilterModel.FILTER_EXCLUDE_FIELDS, + *WorkspaceScopedFilterModel.FILTER_EXCLUDE_FIELDS, "scope_type", "resource_type", "labels_str", "labels", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *ShareableWorkspaceScopedFilterModel.CLI_EXCLUDE_FIELDS, + *WorkspaceScopedFilterModel.CLI_EXCLUDE_FIELDS, "scope_type", "labels_str", "labels", @@ -1190,10 +1190,6 @@ class ServiceConnectorFilterModel(ShareableWorkspaceScopedFilterModel): description="The type to scope this query to.", ) - is_shared: Optional[Union[bool, str]] = Field( - default=None, - description="If the service connector is shared or private", - ) name: Optional[str] = Field( default=None, description="The name to filter by", @@ -1288,7 +1284,7 @@ class Config: class ServiceConnectorRequestModel( - ServiceConnectorBaseModel, ShareableRequestModel + ServiceConnectorBaseModel, WorkspaceScopedRequestModel ): """Request model for service connectors.""" diff --git a/src/zenml/models/stack_models.py b/src/zenml/models/stack_models.py index 47bb187dd1e..2fc9cc72476 100644 --- a/src/zenml/models/stack_models.py +++ b/src/zenml/models/stack_models.py @@ -21,13 +21,13 @@ from zenml.enums import StackComponentType from zenml.models.base_models import ( - ShareableRequestModel, - ShareableResponseModel, + WorkspaceScopedRequestModel, + WorkspaceScopedResponseModel, update_model, ) from zenml.models.component_models import ComponentResponseModel from zenml.models.constants import STR_FIELD_MAX_LENGTH -from zenml.models.filter_models import ShareableWorkspaceScopedFilterModel +from zenml.models.filter_models import WorkspaceScopedFilterModel # ---- # # BASE # @@ -57,7 +57,7 @@ class StackBaseModel(BaseModel): # -------- # -class StackResponseModel(StackBaseModel, ShareableResponseModel): +class StackResponseModel(StackBaseModel, WorkspaceScopedResponseModel): """Stack model with Components, User and Workspace fully hydrated.""" components: Dict[StackComponentType, List[ComponentResponseModel]] = Field( @@ -117,7 +117,7 @@ def to_yaml(self) -> Dict[str, Any]: # ------ # -class StackFilterModel(ShareableWorkspaceScopedFilterModel): +class StackFilterModel(WorkspaceScopedFilterModel): """Model to enable advanced filtering of all StackModels. The Stack Model needs additional scoping. As such the `_scope_user` field @@ -130,13 +130,10 @@ class StackFilterModel(ShareableWorkspaceScopedFilterModel): # rather than a field in the db, hence it needs to be handled # explicitly FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *ShareableWorkspaceScopedFilterModel.FILTER_EXCLUDE_FIELDS, + *WorkspaceScopedFilterModel.FILTER_EXCLUDE_FIELDS, "component_id", # This is a relationship, not a field ] - is_shared: Optional[Union[bool, str]] = Field( - default=None, description="If the stack is shared or private" - ) name: Optional[str] = Field( default=None, description="Name of the stack", @@ -160,7 +157,7 @@ class StackFilterModel(ShareableWorkspaceScopedFilterModel): # ------- # -class StackRequestModel(StackBaseModel, ShareableRequestModel): +class StackRequestModel(StackBaseModel, WorkspaceScopedRequestModel): """Stack model with components, user and workspace as UUIDs.""" components: Optional[Dict[StackComponentType, List[UUID]]] = Field( diff --git a/src/zenml/orchestrators/local/local_orchestrator.py b/src/zenml/orchestrators/local/local_orchestrator.py index ef0adc53e32..433e233972b 100644 --- a/src/zenml/orchestrators/local/local_orchestrator.py +++ b/src/zenml/orchestrators/local/local_orchestrator.py @@ -117,9 +117,6 @@ class LocalOrchestratorConfig(BaseOrchestratorConfig): def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py b/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py index 76b8684a588..508a39d7bb0 100644 --- a/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py +++ b/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py @@ -261,9 +261,6 @@ class LocalDockerOrchestratorConfig( # type: ignore[misc] # https://github.com/ def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/secrets_managers/local/local_secrets_manager.py b/src/zenml/secrets_managers/local/local_secrets_manager.py index 8b4cf29e6c8..d648ceacdc6 100644 --- a/src/zenml/secrets_managers/local/local_secrets_manager.py +++ b/src/zenml/secrets_managers/local/local_secrets_manager.py @@ -54,9 +54,6 @@ class LocalSecretsManagerConfig(BaseSecretsManagerConfig): def is_local(self) -> bool: """Checks if this stack component is running locally. - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Returns: True if this config is for a local component, False otherwise. """ diff --git a/src/zenml/service_connectors/service_connector.py b/src/zenml/service_connectors/service_connector.py index 06ad9e406aa..b1ece47ecf2 100644 --- a/src/zenml/service_connectors/service_connector.py +++ b/src/zenml/service_connectors/service_connector.py @@ -693,7 +693,6 @@ def to_model( user: UUID, workspace: UUID, name: Optional[str] = None, - is_shared: bool = False, description: str = "", labels: Optional[Dict[str, str]] = None, ) -> "ServiceConnectorRequestModel": @@ -703,7 +702,6 @@ def to_model( name: The name of the connector. user: The ID of the user that created the connector. workspace: The ID of the workspace that the connector belongs to. - is_shared: Whether the connector is shared with other users. description: The description of the connector. labels: The labels of the connector. @@ -728,7 +726,6 @@ def to_model( description=description, user=user, workspace=workspace, - is_shared=is_shared, auth_method=self.auth_method, expires_at=self.expires_at, expiration_seconds=self.expiration_seconds, @@ -752,7 +749,6 @@ def to_response_model( user: Optional[UserResponseModel] = None, name: Optional[str] = None, id: Optional[UUID] = None, - is_shared: bool = False, description: str = "", labels: Optional[Dict[str, str]] = None, ) -> "ServiceConnectorResponseModel": @@ -763,7 +759,6 @@ def to_response_model( user: The user that created the connector. name: The name of the connector. id: The ID of the connector. - is_shared: Whether the connector is shared with other users. description: The description of the connector. labels: The labels of the connector. @@ -792,7 +787,6 @@ def to_response_model( description=description, user=user, workspace=workspace, - is_shared=is_shared, auth_method=self.auth_method, expires_at=self.expires_at, expiration_seconds=self.expiration_seconds, diff --git a/src/zenml/stack/stack_component.py b/src/zenml/stack/stack_component.py index 9a186345977..0e3b3814380 100644 --- a/src/zenml/stack/stack_component.py +++ b/src/zenml/stack/stack_component.py @@ -162,9 +162,6 @@ def is_local(self) -> bool: resources or capabilities (e.g. local filesystem, local database or other services). - This designation is used to determine if the stack component can be - shared with other users or if it is only usable on the local host. - Examples: * Artifact Stores that store artifacts in the local filesystem * Orchestrators that are connected to local orchestration runtime diff --git a/src/zenml/utils/mlstacks_utils.py b/src/zenml/utils/mlstacks_utils.py index 5938fa587ec..3610a81e9b1 100644 --- a/src/zenml/utils/mlstacks_utils.py +++ b/src/zenml/utils/mlstacks_utils.py @@ -510,7 +510,6 @@ def import_new_mlstacks_stack( imported_stack = Client().create_stack( name=stack_name, components=component_ids, - is_shared=False, stack_spec_file=stack_spec_file, ) diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index ec771c2683f..1b477d7e4f9 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -80,7 +80,7 @@ DEFAULT_USERNAME = "default" DEFAULT_PASSWORD = "" DEFAULT_WORKSPACE_NAME = "default" -DEFAULT_STACK_AND_COMPONENT_NAME = "default" +DEFAULT_STACK_AND_COMPONENT_NAME_PREFIX = "default" @make_proxy_class(SecretsStoreInterface, "_secrets_store") @@ -307,7 +307,7 @@ def _initialize_database(self) -> None: except KeyError: default_user = self._create_default_user() self._get_or_create_default_stack( - workspace_id=default_workspace.id, + workspace=default_workspace, user_id=default_user.id, ) @@ -412,19 +412,6 @@ def validate_active_config( active_stack = self._get_or_create_default_stack( active_workspace ) - elif not active_stack.is_shared and ( - not active_stack.user - or (active_stack.user.id != self.get_user().id) - ): - logger.warning( - "The current %s active stack is not shared and not " - "owned by the active user. " - "Resetting the active stack to default.", - config_name, - ) - active_stack = self._get_or_create_default_stack( - active_workspace - ) else: logger.warning( "Setting the %s active stack to default.", @@ -465,17 +452,19 @@ def is_local_store(self) -> bool: return self.get_store_info().is_local() def _get_or_create_default_stack( - self, workspace: "WorkspaceResponseModel" + self, + workspace: "WorkspaceResponseModel", + user_id: Optional[UUID] = None, ) -> "StackResponseModel": try: return self._get_default_stack( workspace_id=workspace.id, - user_id=self.get_user().id, + user_id=user_id or self.get_user().id, ) except KeyError: return self._create_default_stack( workspace_id=workspace.id, - user_id=self.get_user().id, + user_id=user_id or self.get_user().id, ) def _get_or_create_default_workspace(self) -> "WorkspaceResponseModel": @@ -584,7 +573,6 @@ def _create_default_stack( stack = StackRequestModel( name=name, components=components, - is_shared=False, workspace=workspace.id, user=user.id, ) @@ -599,7 +587,7 @@ def _get_default_stack_and_component_name(self, user_id: UUID) -> str: Returns: The default stack/component name. """ - return f"{DEFAULT_STACK_AND_COMPONENT_NAME}-{user_id}" + return f"{DEFAULT_STACK_AND_COMPONENT_NAME_PREFIX}-{user_id}" def _get_default_stack( self, diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py new file mode 100644 index 00000000000..f9498806463 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -0,0 +1,51 @@ +"""Remove shared columns [7500f434b71c]. + +Revision ID: 7500f434b71c +Revises: 0.45.1 +Create Date: 2023-10-16 15:15:34.865337 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "7500f434b71c" +down_revision = "0.45.1" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table("service_connector", schema=None) as batch_op: + batch_op.drop_column("is_shared") + + with op.batch_alter_table("stack", schema=None) as batch_op: + batch_op.drop_column("is_shared") + + with op.batch_alter_table("stack_component", schema=None) as batch_op: + batch_op.drop_column("is_shared") + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("stack_component", schema=None) as batch_op: + batch_op.add_column( + sa.Column("is_shared", sa.BOOLEAN(), nullable=False) + ) + + with op.batch_alter_table("stack", schema=None) as batch_op: + batch_op.add_column( + sa.Column("is_shared", sa.BOOLEAN(), nullable=False) + ) + + with op.batch_alter_table("service_connector", schema=None) as batch_op: + batch_op.add_column( + sa.Column("is_shared", sa.BOOLEAN(), nullable=False) + ) + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/schemas/base_schemas.py b/src/zenml/zen_stores/schemas/base_schemas.py index 43a2a22e1bc..97523123723 100644 --- a/src/zenml/zen_stores/schemas/base_schemas.py +++ b/src/zenml/zen_stores/schemas/base_schemas.py @@ -31,9 +31,3 @@ class NamedSchema(BaseSchema): """Base Named SQL Model.""" name: str - - -class ShareableSchema(NamedSchema): - """Base shareable SQL Model.""" - - is_shared: bool diff --git a/src/zenml/zen_stores/schemas/component_schemas.py b/src/zenml/zen_stores/schemas/component_schemas.py index 4024191fc25..33eb7eb3ece 100644 --- a/src/zenml/zen_stores/schemas/component_schemas.py +++ b/src/zenml/zen_stores/schemas/component_schemas.py @@ -26,7 +26,7 @@ ComponentResponseModel, ComponentUpdateModel, ) -from zenml.zen_stores.schemas.base_schemas import ShareableSchema +from zenml.zen_stores.schemas.base_schemas import NamedSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.service_connector_schemas import ( ServiceConnectorSchema, @@ -44,7 +44,7 @@ from zenml.zen_stores.schemas import ScheduleSchema -class StackComponentSchema(ShareableSchema, table=True): +class StackComponentSchema(NamedSchema, table=True): """SQL Model for stack components.""" __tablename__ = "stack_component" @@ -150,7 +150,6 @@ def to_model( workspace=self.workspace.to_model(), connector=self.connector.to_model() if self.connector else None, connector_resource_id=self.connector_resource_id, - is_shared=self.is_shared, configuration=json.loads( base64.b64decode(self.configuration).decode() ), diff --git a/src/zenml/zen_stores/schemas/service_connector_schemas.py b/src/zenml/zen_stores/schemas/service_connector_schemas.py index 9f3cfe5300c..d24be586f2a 100644 --- a/src/zenml/zen_stores/schemas/service_connector_schemas.py +++ b/src/zenml/zen_stores/schemas/service_connector_schemas.py @@ -27,7 +27,7 @@ ServiceConnectorResponseModel, ServiceConnectorUpdateModel, ) -from zenml.zen_stores.schemas.base_schemas import ShareableSchema +from zenml.zen_stores.schemas.base_schemas import NamedSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema @@ -36,7 +36,7 @@ from zenml.zen_stores.schemas.component_schemas import StackComponentSchema -class ServiceConnectorSchema(ShareableSchema, table=True): +class ServiceConnectorSchema(NamedSchema, table=True): """SQL Model for service connectors.""" __tablename__ = "service_connector" @@ -144,7 +144,6 @@ def from_request( return cls( workspace_id=connector_request.workspace, user_id=connector_request.user, - is_shared=connector_request.is_shared, name=connector_request.name, description=connector_request.description, connector_type=connector_request.type, @@ -241,7 +240,6 @@ def to_model( description=self.description, user=self.user.to_model(True) if self.user else None, workspace=self.workspace.to_model(), - is_shared=self.is_shared, created=self.created, updated=self.updated, connector_type=self.connector_type, diff --git a/src/zenml/zen_stores/schemas/stack_schemas.py b/src/zenml/zen_stores/schemas/stack_schemas.py index b0deacb0055..5a9f30e5ab4 100644 --- a/src/zenml/zen_stores/schemas/stack_schemas.py +++ b/src/zenml/zen_stores/schemas/stack_schemas.py @@ -20,7 +20,7 @@ from sqlmodel import Relationship, SQLModel from zenml.models import StackResponseModel -from zenml.zen_stores.schemas.base_schemas import ShareableSchema +from zenml.zen_stores.schemas.base_schemas import NamedSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema @@ -62,7 +62,7 @@ class StackCompositionSchema(SQLModel, table=True): ) -class StackSchema(ShareableSchema, table=True): +class StackSchema(NamedSchema, table=True): """SQL Model for stacks.""" __tablename__ = "stack" @@ -139,7 +139,6 @@ def to_model(self) -> "StackResponseModel": stack_spec_path=self.stack_spec_path, user=self.user.to_model(True) if self.user else None, workspace=self.workspace.to_model(), - is_shared=self.is_shared, components={c.type: [c.to_model()] for c in self.components}, created=self.created, updated=self.updated, diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index fd7e3a7ac52..789be3ea89c 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -180,8 +180,7 @@ ) from zenml.utils.string_utils import random_str from zenml.zen_stores.base_zen_store import ( - DEFAULT_STACK_COMPONENT_NAME, - DEFAULT_STACK_NAME, + DEFAULT_STACK_AND_COMPONENT_NAME_PREFIX, BaseZenStore, ) from zenml.zen_stores.enums import StoreEvent @@ -1022,14 +1021,8 @@ def create_stack(self, stack: StackRequestModel) -> StackResponseModel: The registered stack. """ with Session(self.engine) as session: - self._fail_if_stack_with_name_exists_for_user( - stack=stack, session=session - ) - - if stack.is_shared: - self._fail_if_stack_with_name_already_shared( - stack=stack, session=session - ) + self._fail_if_stack_name_reserved(stack_name=stack.name) + self._fail_if_stack_with_name_exists(stack=stack, session=session) # Get the Schemas of all components mentioned component_ids = ( @@ -1053,7 +1046,6 @@ def create_stack(self, stack: StackRequestModel) -> StackResponseModel: new_stack_schema = StackSchema( workspace_id=stack.workspace, user_id=stack.user, - is_shared=stack.is_shared, stack_spec_path=stack.stack_spec_path, name=stack.name, description=stack.description, @@ -1156,7 +1148,12 @@ def update_stack( f"Unable to update stack with id '{stack_id}': Found no" f"existing stack with this id." ) - if existing_stack.name == DEFAULT_STACK_NAME: + if ( + existing_stack.name + == self._get_default_stack_and_component_name( + existing_stack.user_id + ) + ): raise IllegalOperationError( "The default stack cannot be modified." ) @@ -1164,16 +1161,10 @@ def update_stack( # with that name if stack_update.name: if existing_stack.name != stack_update.name: - self._fail_if_stack_with_name_exists_for_user( - stack=stack_update, session=session + self._fail_if_stack_name_reserved( + stack_name=stack_update.name ) - - # Check if stack update makes the stack a shared stack. In that - # case, check if a stack with the same name is already shared - # within the workspace - if stack_update.is_shared: - if not existing_stack.is_shared and stack_update.is_shared: - self._fail_if_stack_with_name_already_shared( + self._fail_if_stack_with_name_exists( stack=stack_update, session=session ) @@ -1217,7 +1208,9 @@ def delete_stack(self, stack_id: UUID) -> None: if stack is None: raise KeyError(f"Stack with ID {stack_id} not found.") - if stack.name == DEFAULT_STACK_NAME: + if stack.name == self._get_default_stack_and_component_name( + user_id=stack.user_id + ): raise IllegalOperationError( "The default stack cannot be deleted." ) @@ -1227,85 +1220,48 @@ def delete_stack(self, stack_id: UUID) -> None: session.commit() - def _fail_if_stack_with_name_exists_for_user( + def _fail_if_stack_with_name_exists( self, stack: StackRequestModel, session: Session, ) -> None: - """Raise an exception if a Component with same name exists for user. + """Raise an exception if a stack with same name exists. Args: stack: The Stack session: The Session - Returns: - None - Raises: - StackExistsError: If a Stack with the given name is already - owned by the user + StackExistsError: If a Stack with the given name already exists. """ existing_domain_stack = session.exec( select(StackSchema) .where(StackSchema.name == stack.name) .where(StackSchema.workspace_id == stack.workspace) - .where(StackSchema.user_id == stack.user) ).first() if existing_domain_stack is not None: workspace = self._get_workspace_schema( workspace_name_or_id=stack.workspace, session=session ) - user = self._get_user_schema( - user_name_or_id=stack.user, session=session - ) raise StackExistsError( f"Unable to register stack with name " f"'{stack.name}': Found an existing stack with the same " - f"name in the active workspace, '{workspace.name}', " - f"owned by the same user, '{user.name}'." + f"name in the active workspace '{workspace.name}'." ) - return None - def _fail_if_stack_with_name_already_shared( - self, - stack: StackRequestModel, - session: Session, - ) -> None: - """Raise an exception if a Stack with same name is already shared. + def _fail_if_stack_name_reserved(self, stack_name: str) -> None: + """Raise an exception if the stack name is reserved. Args: - stack: The Stack - session: The Session + stack_name: The stack name. Raises: - StackExistsError: If a stack with the given name is already shared - by a user. + IllegalOperationError: If the stack name is reserved. """ - # Check if component with the same name, type is already shared - # within the workspace - existing_shared_stack = session.exec( - select(StackSchema) - .where(StackSchema.name == stack.name) - .where(StackSchema.workspace_id == stack.workspace) - .where(StackSchema.is_shared == stack.is_shared) - ).first() - if existing_shared_stack is not None: - workspace = self._get_workspace_schema( - workspace_name_or_id=stack.workspace, session=session - ) - error_msg = ( - f"Unable to share stack with name '{stack.name}': Found an " - f"existing shared stack with the same name in workspace " - f"'{workspace.name}'" + if stack_name == DEFAULT_STACK_AND_COMPONENT_NAME_PREFIX: + raise IllegalOperationError( + f"Unable to register stack with reserved name '{stack_name}'." ) - if existing_shared_stack.user_id: - owner_of_shared = self._get_user_schema( - existing_shared_stack.user_id, session=session - ) - error_msg += f" owned by '{owner_of_shared.name}'." - else: - error_msg += ", which is currently not owned by any user." - raise StackExistsError(error_msg) # ---------------- # Stack components @@ -1329,22 +1285,16 @@ def create_stack_component( connector. """ with Session(self.engine) as session: - self._fail_if_component_with_name_type_exists_for_user( + self._fail_if_component_name_reserved( + component_name=component.name + ) + self._fail_if_component_with_name_type_exists( name=component.name, component_type=component.type, - user_id=component.user, workspace_id=component.workspace, session=session, ) - if component.is_shared: - self._fail_if_component_with_name_type_already_shared( - name=component.name, - component_type=component.type, - workspace_id=component.workspace, - session=session, - ) - service_connector: Optional[ServiceConnectorSchema] = None if component.connector: service_connector = session.exec( @@ -1364,7 +1314,6 @@ def create_stack_component( name=component.name, workspace_id=component.workspace, user_id=component.user, - is_shared=component.is_shared, component_spec_path=component.component_spec_path, type=component.type, flavor=component.flavor, @@ -1482,7 +1431,10 @@ def update_stack_component( ) if ( - existing_component.name == DEFAULT_STACK_COMPONENT_NAME + existing_component.name + == self._get_default_stack_and_component_name( + user_id=existing_component.user_id + ) and existing_component.type in [ StackComponentType.ORCHESTRATOR, @@ -1496,28 +1448,12 @@ def update_stack_component( # In case of a renaming update, make sure no component of the same # type already exists with that name if component_update.name: - if ( - existing_component.name != component_update.name - and existing_component.user_id is not None - ): - self._fail_if_component_with_name_type_exists_for_user( - name=component_update.name, - component_type=existing_component.type, - workspace_id=existing_component.workspace_id, - user_id=existing_component.user_id, - session=session, + if existing_component.name != component_update.name: + self._fail_if_component_name_reserved( + component_name=component_update.name ) - - # Check if component update makes the component a shared component, - # In that case check if a component with the same name, type are - # already shared within the workspace - if component_update.is_shared: - if ( - not existing_component.is_shared - and component_update.is_shared - ): - self._fail_if_component_with_name_type_already_shared( - name=component_update.name or existing_component.name, + self._fail_if_component_with_name_type_exists( + name=component_update.name, component_type=existing_component.type, workspace_id=existing_component.workspace_id, session=session, @@ -1569,7 +1505,10 @@ def delete_stack_component(self, component_id: UUID) -> None: if stack_component is None: raise KeyError(f"Stack with ID {component_id} not found.") if ( - stack_component.name == DEFAULT_STACK_COMPONENT_NAME + stack_component.name + == self._get_default_stack_and_component_name( + user_id=stack_component.user_id + ) and stack_component.type in [ StackComponentType.ORCHESTRATOR, @@ -1598,88 +1537,54 @@ def delete_stack_component(self, component_id: UUID) -> None: session.commit() @staticmethod - def _fail_if_component_with_name_type_exists_for_user( + def _fail_if_component_with_name_type_exists( name: str, component_type: StackComponentType, workspace_id: UUID, - user_id: UUID, session: Session, ) -> None: - """Raise an exception if a Component with same name/type exists. + """Raise an exception if a component with same name/type exists. Args: name: The name of the component component_type: The type of the component workspace_id: The ID of the workspace - user_id: The ID of the user session: The Session - Returns: - None - Raises: StackComponentExistsError: If a component with the given name and - type is already owned by the user + type already exists. """ - assert user_id - # Check if component with the same domain key (name, type, workspace, - # owner) already exists + # Check if component with the same domain key (name, type, workspace) + # already exists existing_domain_component = session.exec( select(StackComponentSchema) .where(StackComponentSchema.name == name) .where(StackComponentSchema.workspace_id == workspace_id) - .where(StackComponentSchema.user_id == user_id) .where(StackComponentSchema.type == component_type) ).first() if existing_domain_component is not None: - # Theoretically the user schema is optional, in this case there is - # no way that it will be None - assert existing_domain_component.user raise StackComponentExistsError( f"Unable to register '{component_type.value}' component " f"with name '{name}': Found an existing " f"component with the same name and type in the same " - f" workspace, '{existing_domain_component.workspace.name}', " - f"owned by the same user, " - f"'{existing_domain_component.user.name}'." + f" workspace '{existing_domain_component.workspace.name}'." ) return None - @staticmethod - def _fail_if_component_with_name_type_already_shared( - name: str, - component_type: StackComponentType, - workspace_id: UUID, - session: Session, - ) -> None: - """Raise an exception if a Component with same name/type already shared. + def _fail_if_component_name_reserved(self, component_name: str) -> None: + """Raise an exception if the component name is reserved. Args: - name: The name of the component - component_type: The type of the component - workspace_id: The ID of the workspace - session: The Session + component_name: The component name. Raises: - StackComponentExistsError: If a component with the given name and - type is already shared by a user + IllegalOperationError: If the component name is reserved. """ - # Check if component with the same name, type is already shared - # within the workspace - is_shared = True - existing_shared_component = session.exec( - select(StackComponentSchema) - .where(StackComponentSchema.name == name) - .where(StackComponentSchema.workspace_id == workspace_id) - .where(StackComponentSchema.type == component_type) - .where(StackComponentSchema.is_shared == is_shared) - ).first() - if existing_shared_component is not None: - raise StackComponentExistsError( - f"Unable to shared component of type '{component_type.value}' " - f"with name '{name}': Found an existing shared " - f"component with the same name and type in workspace " - f"'{workspace_id}'." + if component_name == DEFAULT_STACK_AND_COMPONENT_NAME_PREFIX: + raise IllegalOperationError( + f"Unable to register component with reserved name " + f"'{component_name}'." ) # ----------------------- @@ -3510,10 +3415,9 @@ def delete_code_repository(self, code_repository_id: UUID) -> None: # ------------------ @staticmethod - def _fail_if_service_connector_with_name_exists_for_user( + def _fail_if_service_connector_with_name_exists( name: str, workspace_id: UUID, - user_id: UUID, session: Session, ) -> None: """Raise an exception if a service connector with same name exists. @@ -3521,71 +3425,27 @@ def _fail_if_service_connector_with_name_exists_for_user( Args: name: The name of the service connector workspace_id: The ID of the workspace - user_id: The ID of the user session: The Session - Returns: - None - Raises: - EntityExistsError: If a service connector with the given name is - already owned by the user + EntityExistsError: If a service connector with the given name + already exists. """ - assert user_id - # Check if service connector with the same domain key (name, workspace, - # owner) already exists + # Check if service connector with the same domain key (name, workspace) + # already exists existing_domain_connector = session.exec( select(ServiceConnectorSchema) .where(ServiceConnectorSchema.name == name) .where(ServiceConnectorSchema.workspace_id == workspace_id) - .where(ServiceConnectorSchema.user_id == user_id) ).first() if existing_domain_connector is not None: - # Theoretically the user schema is optional, in this case there is - # no way that it will be None - assert existing_domain_connector.user raise EntityExistsError( f"Unable to register service connector with name '{name}': " "Found an existing service connector with the same name in the " - f"same workspace, '{existing_domain_connector.workspace.name}', " - "owned by the same user, " - f"{existing_domain_connector.user.name}'." + f"same workspace '{existing_domain_connector.workspace.name}'." ) return None - @staticmethod - def _fail_if_service_connector_with_name_already_shared( - name: str, - workspace_id: UUID, - session: Session, - ) -> None: - """Raise an exception if a service connector with same name is already shared. - - Args: - name: The name of the service connector - workspace_id: The ID of the workspace - session: The Session - - Raises: - EntityExistsError: If a service connector with the given name is - already shared by another user - """ - # Check if a service connector with the same name is already shared - # within the workspace - is_shared = True - existing_shared_connector = session.exec( - select(ServiceConnectorSchema) - .where(ServiceConnectorSchema.name == name) - .where(ServiceConnectorSchema.workspace_id == workspace_id) - .where(ServiceConnectorSchema.is_shared == is_shared) - ).first() - if existing_shared_connector is not None: - raise EntityExistsError( - f"Unable to share service connector with name '{name}': Found " - "an existing shared service connector with the same name in " - f"workspace '{workspace_id}'." - ) - def _create_connector_secret( self, connector_name: str, @@ -3707,20 +3567,12 @@ def create_service_connector( ) with Session(self.engine) as session: - self._fail_if_service_connector_with_name_exists_for_user( + self._fail_if_service_connector_with_name_exists( name=service_connector.name, - user_id=service_connector.user, workspace_id=service_connector.workspace, session=session, ) - if service_connector.is_shared: - self._fail_if_service_connector_with_name_already_shared( - name=service_connector.name, - workspace_id=service_connector.workspace, - session=session, - ) - # Create the secret secret_id = self._create_connector_secret( connector_name=service_connector.name, From 831ed0c9e55c3e1522492a42378b70a4f10daf63 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 17 Oct 2023 12:16:31 +0200 Subject: [PATCH 005/103] Some fixes --- src/zenml/config/server_config.py | 15 +++++++ src/zenml/models/base_models.py | 41 ++++++++++++++++++- src/zenml/zen_server/auth.py | 36 ++++++---------- src/zenml/zen_server/rbac_interface.py | 13 +++--- .../routers/service_connectors_endpoints.py | 1 - .../routers/stack_components_endpoints.py | 1 - .../zen_server/routers/stacks_endpoints.py | 2 - .../routers/workspaces_endpoints.py | 2 - src/zenml/zen_server/utils.py | 9 ++-- 9 files changed, 82 insertions(+), 38 deletions(-) diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index 0da27209d11..c561d2fabb5 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -106,6 +106,10 @@ class ServerConfiguration(BaseModel): external_server_id: The ID of the ZenML server to use with the `EXTERNAL` authentication scheme. If not specified, the regular ZenML server ID is used. + rbac_implementation_source: Source pointing to a class implementing + the RBAC interface defined by + `zenml.zen_server.rbac_interface.RBACInterface`. If not specified, + RBAC will not be enabled for this server. """ deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER @@ -136,6 +140,8 @@ class ServerConfiguration(BaseModel): external_cookie_name: Optional[str] = None external_server_id: Optional[UUID] = None + rbac_implementation_source: Optional[str] = None + _deployment_id: Optional[UUID] = None @root_validator(pre=True) @@ -197,6 +203,15 @@ def deployment_id(self) -> UUID: return self._deployment_id + @property + def rbac_enabled(self) -> bool: + """Whether RBAC is enabled on the server or not. + + Returns: + Whether RBAC is enabled on the server or not. + """ + return self.rbac_implementation_source is not None + def get_jwt_token_issuer(self) -> str: """Get the JWT token issuer. diff --git a/src/zenml/models/base_models.py b/src/zenml/models/base_models.py index 91b952ec632..caae7a61d1d 100644 --- a/src/zenml/models/base_models.py +++ b/src/zenml/models/base_models.py @@ -13,7 +13,17 @@ # permissions and limitations under the License. """Base domain model definitions.""" from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Set, + Tuple, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field, SecretStr @@ -110,6 +120,35 @@ def get_analytics_metadata(self) -> Dict[str, Any]: metadata["entity_id"] = self.id return metadata + @property + def partial(self) -> bool: + """Returns if this model is incomplete. + + A model is incomplete if the user has no permissions to read the + model itself or any submodel contained in this model. + + Returns: + True if the model is incomplete, False otherwise. + """ + if self.missing_permissions: + return True + + def _helper(value: Any) -> bool: + if isinstance(value, BaseResponseModel): + if value.partial: + return True + elif isinstance(value, Dict): + return any(_helper(v) for v in value.values()) + elif isinstance(value, (List, Set, Tuple)): + return any(_helper(v) for v in value) + + for field_name in self.__fields__.keys(): + value = getattr(self, field_name) + if _helper(value): + return True + + return False + class UserScopedResponseModel(BaseResponseModel): """Base user-owned domain model. diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 58685b9c9c9..14825683a28 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Authentication module for ZenML server.""" -import os from contextvars import ContextVar from datetime import datetime from enum import Enum @@ -681,6 +680,9 @@ def authentication_provider() -> Callable[..., AuthContext]: def verify_read_permissions_and_dehydrate( model: "BaseResponseModel", ) -> "BaseResponseModel": + if not server_config().rbac_enabled: + return model + verify_permissions_for_model(model=model, action="READ") return dehydrate_response_model(model=model) @@ -710,7 +712,7 @@ def _maybe_dehydrate_value(value: Any) -> Any: type_ = type(value) return type_(_maybe_dehydrate_value(v) for v in value) else: - return value, False + return value def has_read_permissions_for_model(model: "BaseResponseModel") -> bool: @@ -760,17 +762,13 @@ def verify_permissions_for_model( model: "BaseResponseModel", action: str, ) -> None: - """Verifies if a user has permissions to perform an action on a resource. + """Verifies if a user has permissions to perform an action on a model. Args: - resource: The resource type the user wants to perform the action on. + model: The model the user wants to perform the action on. action: The action the user wants to perform. - resource_id: ID of the resource the user wants to perform the action on. - - Raises: - HTTPException: If the user is not allowed to perform the action. """ - if "ZENML_CLOUD" not in os.environ: + if not server_config().rbac_enabled: return if ( @@ -807,20 +805,13 @@ def verify_permissions( Raises: HTTPException: If the user is not allowed to perform the action. """ - if "ZENML_CLOUD" not in os.environ: + if not server_config().rbac_enabled: return - if resource_type != "stack": - raise HTTPException(status_code=403) - - return - - user_id = get_auth_context().user.external_user_id - assert user_id resource = Resource(type=resource_type, id=resource_id) if not rbac().has_permission( - user=user_id, resource=resource, action=action + user=get_auth_context().user, resource=resource, action=action ): raise HTTPException(status_code=403) @@ -839,17 +830,16 @@ def get_allowed_resource_ids( A list of resource IDs or `None` if the user has full access to the all instances of the resource. """ - if "ZENML_CLOUD" not in os.environ: - # Full access in any case + if not server_config().rbac_enabled: return None - user_id = get_auth_context().user.external_user_id - assert user_id ( has_full_resource_access, allowed_ids, ) = rbac().list_allowed_resource_ids( - user=user_id, resource=Resource(type=resource_type), action=action + user=get_auth_context().user, + resource=Resource(type=resource_type), + action=action, ) if has_full_resource_access: diff --git a/src/zenml/zen_server/rbac_interface.py b/src/zenml/zen_server/rbac_interface.py index b1cfbec68fa..fb9e0b83247 100644 --- a/src/zenml/zen_server/rbac_interface.py +++ b/src/zenml/zen_server/rbac_interface.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type from uuid import UUID from pydantic import BaseModel @@ -7,6 +7,9 @@ from zenml.enums import StrEnum from zenml.models.base_models import BaseResponseModel +if TYPE_CHECKING: + from zenml.models import UserResponseModel + class Action(StrEnum): CREATE = "create" @@ -65,12 +68,12 @@ class Resource(BaseModel): class RBACInterface(ABC): @abstractmethod def has_permission( - self, user: UUID, resource: Resource, action: str + self, user: "UserResponseModel", resource: Resource, action: str ) -> bool: """Checks if a user has permission to perform an action on a resource. Args: - user: ID of the user which wants to access a resource. + user: User which wants to access a resource. resource: The resource the user wants to access. action: The action that the user wants to perform on the resource. @@ -80,12 +83,12 @@ def has_permission( @abstractmethod def list_allowed_resource_ids( - self, user: UUID, resource: Resource, action: str + self, user: "UserResponseModel", resource: Resource, action: str ) -> Tuple[bool, List[str]]: """Lists all resource IDs of a resource type that a user can access. Args: - user: ID of the user which wants to access a resource. + user: User which wants to access a resource. resource: The resource the user wants to access. action: The action that the user wants to perform on the resource. diff --git a/src/zenml/zen_server/routers/service_connectors_endpoints.py b/src/zenml/zen_server/routers/service_connectors_endpoints.py index 819d0bf14d9..047b5b4fba9 100644 --- a/src/zenml/zen_server/routers/service_connectors_endpoints.py +++ b/src/zenml/zen_server/routers/service_connectors_endpoints.py @@ -82,7 +82,6 @@ def list_service_connectors( Returns: Page with list of service connectors for a specific type. """ - connector_filter_model.set_scope_user(user_id=auth_context.user.id) connectors = zen_store().list_service_connectors( filter_model=connector_filter_model ) diff --git a/src/zenml/zen_server/routers/stack_components_endpoints.py b/src/zenml/zen_server/routers/stack_components_endpoints.py index 325eb887778..fed9b2d2744 100644 --- a/src/zenml/zen_server/routers/stack_components_endpoints.py +++ b/src/zenml/zen_server/routers/stack_components_endpoints.py @@ -70,7 +70,6 @@ def list_stack_components( Returns: List of stack components for a specific type. """ - component_filter_model.set_scope_user(user_id=auth_context.user.id) return zen_store().list_stack_components( component_filter_model=component_filter_model ) diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index 3f077817bc9..f40b1ba0e24 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -64,8 +64,6 @@ def list_stacks( Returns: All stacks. """ - stack_filter_model.set_scope_user(user_id=auth_context.user.id) - return zen_store().list_stacks(stack_filter_model=stack_filter_model) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 6a3249f4964..cd3628b1223 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -252,7 +252,6 @@ def list_workspace_stacks( """ workspace = zen_store().get_workspace(workspace_name_or_id) stack_filter_model.set_scope_workspace(workspace.id) - stack_filter_model.set_scope_user(user_id=auth_context.user.id) return zen_store().list_stacks(stack_filter_model=stack_filter_model) @@ -330,7 +329,6 @@ def list_workspace_stack_components( """ workspace = zen_store().get_workspace(workspace_name_or_id) component_filter_model.set_scope_workspace(workspace.id) - component_filter_model.set_scope_user(user_id=auth_context.user.id) return zen_store().list_stack_components( component_filter_model=component_filter_model ) diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index addd5b6542c..a24003aafef 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -75,10 +75,13 @@ def rbac() -> RBACInterface: def initialize_rbac() -> None: """Initialize the RBAC component.""" - from zenml.zen_server.cloud_rbac import CloudRBAC - global _rbac - _rbac = CloudRBAC() + + if rbac_source := server_config().rbac_implementation_source: + from zenml.utils import source_utils + + implementation_class = source_utils.load(rbac_source) + _rbac = implementation_class() def initialize_zen_store() -> None: From f2f68430606044c2333ff3afad82163c011192c6 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 17 Oct 2023 13:10:20 +0200 Subject: [PATCH 006/103] Remove auth scopes --- src/zenml/models/filter_models.py | 23 ++++ src/zenml/zen_server/auth.py | 12 +- .../zen_server/routers/artifacts_endpoints.py | 11 +- .../zen_server/routers/auth_endpoints.py | 5 +- .../routers/code_repositories_endpoints.py | 9 +- .../zen_server/routers/devices_endpoints.py | 22 +--- .../zen_server/routers/flavors_endpoints.py | 15 +-- .../zen_server/routers/models_endpoints.py | 26 ++-- .../routers/pipeline_builds_endpoints.py | 7 +- .../routers/pipeline_deployments_endpoints.py | 7 +- .../zen_server/routers/pipelines_endpoints.py | 17 +-- .../routers/run_metadata_endpoints.py | 3 +- .../zen_server/routers/runs_endpoints.py | 18 +-- .../zen_server/routers/schedule_endpoints.py | 9 +- .../zen_server/routers/secrets_endpoints.py | 12 +- .../routers/service_connectors_endpoints.py | 30 ++--- .../routers/stack_components_endpoints.py | 14 +-- .../zen_server/routers/stacks_endpoints.py | 26 ++-- .../zen_server/routers/steps_endpoints.py | 16 +-- .../zen_server/routers/users_endpoints.py | 28 ++--- .../routers/workspaces_endpoints.py | 111 ++++++------------ src/zenml/zen_stores/sql_zen_store.py | 5 - 22 files changed, 173 insertions(+), 253 deletions(-) diff --git a/src/zenml/models/filter_models.py b/src/zenml/models/filter_models.py index 6c5dbddb610..5ba2324d4e1 100644 --- a/src/zenml/models/filter_models.py +++ b/src/zenml/models/filter_models.py @@ -299,6 +299,8 @@ class BaseFilterModel(BaseModel): default=None, description="Updated" ) + _allowed_ids: Optional[List[UUID]] = None + @validator("sort_by", pre=True) def validate_sort_by(cls, v: str) -> str: """Validate that the sort_column is a valid column with a valid operand. @@ -390,6 +392,17 @@ def sorting_params(self) -> Tuple[str, SorterOps]: return column, operator + def set_allowed_ids(self, allowed_ids: Optional[List[UUID]]) -> None: + """Set allowed IDs for the query. + + Args: + allowed_ids: List of IDs to limit the query to. If given, the + remaining filters will be applied to entities within this list + only. If `None`, the remaining filters will applied to all + entries in the table. + """ + self._allowed_ids = allowed_ids + @classmethod def _generate_filter_list(cls, values: Dict[str, Any]) -> List[Filter]: """Create a list of filters from a (column, value) dictionary. @@ -754,6 +767,9 @@ def apply_filter( Returns: The query with filter applied. """ + if self._allowed_ids is not None: + query = query.where(table.id.in_(self._allowed_ids)) + filters = self.generate_filter(table=table) if filters is not None: @@ -761,6 +777,13 @@ def apply_filter( return query + class Config: + """Pydantic configuration class.""" + + # all attributes with leading underscore are private and therefore + # are mutable and not included in serialization + underscore_attrs_are_private = True + class WorkspaceScopedFilterModel(BaseFilterModel): """Model to enable advanced scoping with workspace.""" diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 14825683a28..17464dec6e5 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -603,11 +603,6 @@ def oauth2_authentication( token: str = Depends( CookieOAuth2TokenBearer( tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN, - scopes={ - "read": "Read permissions on all entities", - "write": "Write permissions on all entities", - "me": "Editing permissions to own user", - }, ) ), ) -> AuthContext: @@ -623,18 +618,13 @@ def oauth2_authentication( Raises: HTTPException: If the JWT token could not be authorized. """ - if security_scopes.scopes: - pass - else: - authenticate_value = "Bearer" - try: auth_context = authenticate_credentials(access_token=token) except AuthorizationException as e: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e), - headers={"WWW-Authenticate": authenticate_value}, + headers={"WWW-Authenticate": "Bearer"}, ) return auth_context diff --git a/src/zenml/zen_server/routers/artifacts_endpoints.py b/src/zenml/zen_server/routers/artifacts_endpoints.py index 670001648ea..dd4b766d269 100644 --- a/src/zenml/zen_server/routers/artifacts_endpoints.py +++ b/src/zenml/zen_server/routers/artifacts_endpoints.py @@ -18,7 +18,6 @@ from fastapi import APIRouter, Depends, Security from zenml.constants import API, ARTIFACTS, VERSION_1, VISUALIZE -from zenml.enums import PermissionType from zenml.models import ( ArtifactFilterModel, ArtifactRequestModel, @@ -54,7 +53,7 @@ def list_artifacts( artifact_filter_model: ArtifactFilterModel = Depends( make_dependable(ArtifactFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[ArtifactResponseModel]: """Get artifacts according to query filters. @@ -78,7 +77,7 @@ def list_artifacts( @handle_exceptions def create_artifact( artifact: ArtifactRequestModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> ArtifactResponseModel: """Create a new artifact. @@ -99,7 +98,7 @@ def create_artifact( @handle_exceptions def get_artifact( artifact_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> ArtifactResponseModel: """Get an artifact by ID. @@ -119,7 +118,7 @@ def get_artifact( @handle_exceptions def delete_artifact( artifact_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Delete an artifact by ID. @@ -138,7 +137,7 @@ def delete_artifact( def get_artifact_visualization( artifact_id: UUID, index: int = 0, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> LoadedVisualizationModel: """Get the visualization of an artifact. diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index 76420ad8a3f..ab1205b3379 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -43,7 +43,6 @@ AuthScheme, OAuthDeviceStatus, OAuthGrantTypes, - PermissionType, ) from zenml.logger import get_logger from zenml.models import ( @@ -464,9 +463,7 @@ def api_token( pipeline_id: Optional[UUID] = None, schedule_id: Optional[UUID] = None, expires_minutes: Optional[int] = None, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> str: """Get a workload API token for the current user. diff --git a/src/zenml/zen_server/routers/code_repositories_endpoints.py b/src/zenml/zen_server/routers/code_repositories_endpoints.py index 5f816b8b51d..9882bbf1697 100644 --- a/src/zenml/zen_server/routers/code_repositories_endpoints.py +++ b/src/zenml/zen_server/routers/code_repositories_endpoints.py @@ -17,7 +17,6 @@ from fastapi import APIRouter, Depends, Security from zenml.constants import API, CODE_REPOSITORIES, VERSION_1 -from zenml.enums import PermissionType from zenml.models import ( CodeRepositoryFilterModel, CodeRepositoryResponseModel, @@ -49,7 +48,7 @@ def list_code_repositories( filter_model: CodeRepositoryFilterModel = Depends( make_dependable(CodeRepositoryFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[CodeRepositoryResponseModel]: """Gets a page of code repositories. @@ -71,7 +70,7 @@ def list_code_repositories( @handle_exceptions def get_code_repository( code_repository_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> CodeRepositoryResponseModel: """Gets a specific code repository using its unique ID. @@ -95,7 +94,7 @@ def get_code_repository( def update_code_repository( code_repository_id: UUID, update: CodeRepositoryUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> CodeRepositoryResponseModel: """Updates a code repository. @@ -118,7 +117,7 @@ def update_code_repository( @handle_exceptions def delete_code_repository( code_repository_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a specific code repository. diff --git a/src/zenml/zen_server/routers/devices_endpoints.py b/src/zenml/zen_server/routers/devices_endpoints.py index 59e8f4d7575..67a63ba894e 100644 --- a/src/zenml/zen_server/routers/devices_endpoints.py +++ b/src/zenml/zen_server/routers/devices_endpoints.py @@ -24,7 +24,7 @@ DEVICES, VERSION_1, ) -from zenml.enums import OAuthDeviceStatus, PermissionType +from zenml.enums import OAuthDeviceStatus from zenml.models import ( OAuthDeviceFilterModel, OAuthDeviceInternalUpdateModel, @@ -59,9 +59,7 @@ def list_authorized_devices( filter_model: OAuthDeviceFilterModel = Depends( make_dependable(OAuthDeviceFilterModel) ), - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> Page[OAuthDeviceResponseModel]: """Gets a page of OAuth2 authorized devices belonging to the current user. @@ -86,9 +84,7 @@ def list_authorized_devices( def get_authorization_device( device_id: UUID, user_code: Optional[str] = None, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> OAuthDeviceResponseModel: """Gets a specific OAuth2 authorized device using its unique ID. @@ -134,9 +130,7 @@ def get_authorization_device( def update_authorized_device( device_id: UUID, update: OAuthDeviceUpdateModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> OAuthDeviceResponseModel: """Updates a specific OAuth2 authorized device using its unique ID. @@ -173,9 +167,7 @@ def update_authorized_device( def verify_authorized_device( device_id: UUID, request: OAuthDeviceVerificationRequest, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> OAuthDeviceResponseModel: """Verifies a specific OAuth2 authorized device using its unique ID. @@ -274,9 +266,7 @@ def verify_authorized_device( @handle_exceptions def delete_authorized_device( device_id: UUID, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> None: """Deletes a specific OAuth2 authorized device using its unique ID. diff --git a/src/zenml/zen_server/routers/flavors_endpoints.py b/src/zenml/zen_server/routers/flavors_endpoints.py index c4d06958c49..beb66de4882 100644 --- a/src/zenml/zen_server/routers/flavors_endpoints.py +++ b/src/zenml/zen_server/routers/flavors_endpoints.py @@ -18,7 +18,6 @@ from fastapi import APIRouter, Depends, Security from zenml.constants import API, FLAVORS, VERSION_1 -from zenml.enums import PermissionType from zenml.exceptions import IllegalOperationError from zenml.models import ( FlavorFilterModel, @@ -52,7 +51,7 @@ def list_flavors( flavor_filter_model: FlavorFilterModel = Depends( make_dependable(FlavorFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[FlavorResponseModel]: """Returns all flavors. @@ -75,7 +74,7 @@ def list_flavors( @handle_exceptions def get_flavor( flavor_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> FlavorResponseModel: """Returns the requested flavor. @@ -97,9 +96,7 @@ def get_flavor( @handle_exceptions def create_flavor( flavor: FlavorRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> FlavorResponseModel: """Creates a stack component flavor. @@ -136,7 +133,7 @@ def create_flavor( def update_flavor( flavor_id: UUID, flavor_update: FlavorUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> FlavorResponseModel: """Updates a flavor. @@ -161,7 +158,7 @@ def update_flavor( @handle_exceptions def delete_flavor( flavor_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a flavor. @@ -177,7 +174,7 @@ def delete_flavor( ) @handle_exceptions def sync_flavors( - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Purge all in-built and integration flavors from the DB and sync. diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 58273a852b7..a48b16ad5dd 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -27,7 +27,7 @@ RUNS, VERSION_1, ) -from zenml.enums import ModelStages, PermissionType +from zenml.enums import ModelStages from zenml.models import ( ModelFilterModel, ModelResponseModel, @@ -70,7 +70,7 @@ def list_models( model_filter_model: ModelFilterModel = Depends( make_dependable(ModelFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[ModelResponseModel]: """Get models according to query filters. @@ -95,7 +95,7 @@ def list_models( @handle_exceptions def get_model( model_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> ModelResponseModel: """Get a model by name or ID. @@ -117,7 +117,7 @@ def get_model( def update_model( model_id: UUID, model_update: ModelUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> ModelResponseModel: """Updates a model. @@ -141,7 +141,7 @@ def update_model( @handle_exceptions def delete_model( model_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Delete a model by name or ID. @@ -166,7 +166,7 @@ def list_model_versions( model_version_filter_model: ModelVersionFilterModel = Depends( make_dependable(ModelVersionFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[ModelVersionResponseModel]: """Get model versions according to query filters. @@ -196,7 +196,7 @@ def get_model_version( str, int, UUID, ModelStages ] = LATEST_MODEL_VERSION_PLACEHOLDER, is_number: bool = False, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> ModelVersionResponseModel: """Get a model version by name or ID. @@ -226,7 +226,7 @@ def get_model_version( def update_model_version( model_version_id: UUID, model_version_update_model: ModelVersionUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> ModelVersionResponseModel: """Get all model versions by filter. @@ -251,7 +251,7 @@ def update_model_version( def delete_model_version( model_name_or_id: Union[str, UUID], model_version_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Delete a model by name or ID. @@ -282,7 +282,7 @@ def list_model_version_artifact_links( model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( make_dependable(ModelVersionArtifactFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[ModelVersionArtifactResponseModel]: """Get model version to artifact links according to query filters. @@ -311,7 +311,7 @@ def delete_model_version_artifact_link( model_name_or_id: Union[str, UUID], model_version_name_or_id: Union[str, UUID], model_version_artifact_link_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a model version link. @@ -345,7 +345,7 @@ def list_model_version_pipeline_run_links( model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( make_dependable(ModelVersionPipelineRunFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[ModelVersionPipelineRunResponseModel]: """Get model version to pipeline run links according to query filters. @@ -374,7 +374,7 @@ def delete_model_version_pipeline_run_link( model_name_or_id: Union[str, UUID], model_version_name_or_id: Union[str, UUID], model_version_pipeline_run_link_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a model version link. diff --git a/src/zenml/zen_server/routers/pipeline_builds_endpoints.py b/src/zenml/zen_server/routers/pipeline_builds_endpoints.py index 524f6ee20d5..f053fe26c7b 100644 --- a/src/zenml/zen_server/routers/pipeline_builds_endpoints.py +++ b/src/zenml/zen_server/routers/pipeline_builds_endpoints.py @@ -17,7 +17,6 @@ from fastapi import APIRouter, Depends, Security from zenml.constants import API, PIPELINE_BUILDS, VERSION_1 -from zenml.enums import PermissionType from zenml.models import PipelineBuildFilterModel, PipelineBuildResponseModel from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize @@ -45,7 +44,7 @@ def list_builds( build_filter_model: PipelineBuildFilterModel = Depends( make_dependable(PipelineBuildFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[PipelineBuildResponseModel]: """Gets a list of builds. @@ -67,7 +66,7 @@ def list_builds( @handle_exceptions def get_build( build_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> PipelineBuildResponseModel: """Gets a specific build using its unique id. @@ -87,7 +86,7 @@ def get_build( @handle_exceptions def delete_build( build_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a specific build. diff --git a/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py b/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py index 3d4b7b4c244..70fd486a234 100644 --- a/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py +++ b/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py @@ -17,7 +17,6 @@ from fastapi import APIRouter, Depends, Security from zenml.constants import API, PIPELINE_DEPLOYMENTS, VERSION_1 -from zenml.enums import PermissionType from zenml.models import ( PipelineDeploymentFilterModel, PipelineDeploymentResponseModel, @@ -48,7 +47,7 @@ def list_deployments( deployment_filter_model: PipelineDeploymentFilterModel = Depends( make_dependable(PipelineDeploymentFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[PipelineDeploymentResponseModel]: """Gets a list of deployment. @@ -72,7 +71,7 @@ def list_deployments( @handle_exceptions def get_deployment( deployment_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> PipelineDeploymentResponseModel: """Gets a specific deployment using its unique id. @@ -92,7 +91,7 @@ def get_deployment( @handle_exceptions def delete_deployment( deployment_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a specific deployment. diff --git a/src/zenml/zen_server/routers/pipelines_endpoints.py b/src/zenml/zen_server/routers/pipelines_endpoints.py index f90cdd17a22..f033990548e 100644 --- a/src/zenml/zen_server/routers/pipelines_endpoints.py +++ b/src/zenml/zen_server/routers/pipelines_endpoints.py @@ -18,7 +18,6 @@ from zenml.config.pipeline_spec import PipelineSpec from zenml.constants import API, PIPELINE_SPEC, PIPELINES, RUNS, VERSION_1 -from zenml.enums import PermissionType from zenml.models import ( PipelineFilterModel, PipelineResponseModel, @@ -57,9 +56,7 @@ def list_pipelines( pipeline_filter_model: PipelineFilterModel = Depends( make_dependable(PipelineFilterModel) ), - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> Page[PipelineResponseModel]: """Gets a list of pipelines. @@ -88,9 +85,7 @@ def list_pipelines( @handle_exceptions def get_pipeline( pipeline_id: UUID, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> PipelineResponseModel: """Gets a specific pipeline using its unique id. @@ -112,7 +107,7 @@ def get_pipeline( def update_pipeline( pipeline_id: UUID, pipeline_update: PipelineUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> PipelineResponseModel: """Updates the attribute on a specific pipeline using its unique id. @@ -135,7 +130,7 @@ def update_pipeline( @handle_exceptions def delete_pipeline( pipeline_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a specific pipeline. @@ -155,7 +150,7 @@ def list_pipeline_runs( pipeline_run_filter_model: PipelineRunFilterModel = Depends( make_dependable(PipelineRunFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[PipelineRunResponseModel]: """Get pipeline runs according to query filters. @@ -177,7 +172,7 @@ def list_pipeline_runs( @handle_exceptions def get_pipeline_spec( pipeline_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> PipelineSpec: """Gets the spec of a specific pipeline using its unique id. diff --git a/src/zenml/zen_server/routers/run_metadata_endpoints.py b/src/zenml/zen_server/routers/run_metadata_endpoints.py index daafbacffd0..7849ff0ba98 100644 --- a/src/zenml/zen_server/routers/run_metadata_endpoints.py +++ b/src/zenml/zen_server/routers/run_metadata_endpoints.py @@ -17,7 +17,6 @@ from fastapi import APIRouter, Depends, Security from zenml.constants import API, RUN_METADATA, VERSION_1 -from zenml.enums import PermissionType from zenml.models import RunMetadataResponseModel from zenml.models.page_model import Page from zenml.models.run_metadata_models import RunMetadataFilterModel @@ -46,7 +45,7 @@ def list_run_metadata( run_metadata_filter_model: RunMetadataFilterModel = Depends( make_dependable(RunMetadataFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[RunMetadataResponseModel]: """Get run metadata according to query filters. diff --git a/src/zenml/zen_server/routers/runs_endpoints.py b/src/zenml/zen_server/routers/runs_endpoints.py index a1a0dc01a44..eb85a872bca 100644 --- a/src/zenml/zen_server/routers/runs_endpoints.py +++ b/src/zenml/zen_server/routers/runs_endpoints.py @@ -26,7 +26,7 @@ STEPS, VERSION_1, ) -from zenml.enums import ExecutionStatus, PermissionType +from zenml.enums import ExecutionStatus from zenml.lineage_graph.lineage_graph import LineageGraph from zenml.models import ( PipelineRunFilterModel, @@ -61,7 +61,7 @@ def list_runs( runs_filter_model: PipelineRunFilterModel = Depends( make_dependable(PipelineRunFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[PipelineRunResponseModel]: """Get pipeline runs according to query filters. @@ -82,7 +82,7 @@ def list_runs( @handle_exceptions def get_run( run_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> PipelineRunResponseModel: """Get a specific pipeline run using its ID. @@ -104,7 +104,7 @@ def get_run( def update_run( run_id: UUID, run_model: PipelineRunUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> PipelineRunResponseModel: """Updates a run. @@ -125,7 +125,7 @@ def update_run( @handle_exceptions def delete_run( run_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a run. @@ -143,7 +143,7 @@ def delete_run( @handle_exceptions def get_run_dag( run_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> LineageGraph: """Get the DAG for a given pipeline run. @@ -169,7 +169,7 @@ def get_run_steps( step_run_filter_model: StepRunFilterModel = Depends( make_dependable(StepRunFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[StepRunResponseModel]: """Get all steps for a given pipeline run. @@ -191,7 +191,7 @@ def get_run_steps( @handle_exceptions def get_pipeline_configuration( run_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Dict[str, Any]: """Get the pipeline configuration of a specific pipeline run using its ID. @@ -212,7 +212,7 @@ def get_pipeline_configuration( @handle_exceptions def get_run_status( run_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> ExecutionStatus: """Get the status of a specific pipeline run. diff --git a/src/zenml/zen_server/routers/schedule_endpoints.py b/src/zenml/zen_server/routers/schedule_endpoints.py index 13c8d725bce..443df90ded3 100644 --- a/src/zenml/zen_server/routers/schedule_endpoints.py +++ b/src/zenml/zen_server/routers/schedule_endpoints.py @@ -17,7 +17,6 @@ from fastapi import APIRouter, Depends, Security from zenml.constants import API, SCHEDULES, VERSION_1 -from zenml.enums import PermissionType from zenml.models.page_model import Page from zenml.models.schedule_model import ( ScheduleFilterModel, @@ -49,7 +48,7 @@ def list_schedules( schedule_filter_model: ScheduleFilterModel = Depends( make_dependable(ScheduleFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[ScheduleResponseModel]: """Gets a list of schedules. @@ -73,7 +72,7 @@ def list_schedules( @handle_exceptions def get_schedule( schedule_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> ScheduleResponseModel: """Gets a specific schedule using its unique id. @@ -95,7 +94,7 @@ def get_schedule( def update_schedule( schedule_id: UUID, schedule_update: ScheduleUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> ScheduleResponseModel: """Updates the attribute on a specific schedule using its unique id. @@ -118,7 +117,7 @@ def update_schedule( @handle_exceptions def delete_schedule( schedule_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a specific schedule using its unique id. diff --git a/src/zenml/zen_server/routers/secrets_endpoints.py b/src/zenml/zen_server/routers/secrets_endpoints.py index 1b50adbc493..8a155179bd8 100644 --- a/src/zenml/zen_server/routers/secrets_endpoints.py +++ b/src/zenml/zen_server/routers/secrets_endpoints.py @@ -50,9 +50,7 @@ def list_secrets( secret_filter_model: SecretFilterModel = Depends( make_dependable(SecretFilterModel) ), - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> Page[SecretResponseModel]: """Gets a list of secrets. @@ -83,9 +81,7 @@ def list_secrets( @handle_exceptions def get_secret( secret_id: UUID, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> SecretResponseModel: """Gets a specific secret using its unique id. @@ -116,7 +112,7 @@ def update_secret( secret_id: UUID, secret_update: SecretUpdateModel, patch_values: Optional[bool] = False, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> SecretResponseModel: """Updates the attribute on a specific secret using its unique id. @@ -150,7 +146,7 @@ def update_secret( @handle_exceptions def delete_secret( secret_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a specific secret using its unique id. diff --git a/src/zenml/zen_server/routers/service_connectors_endpoints.py b/src/zenml/zen_server/routers/service_connectors_endpoints.py index 047b5b4fba9..086a9a941fc 100644 --- a/src/zenml/zen_server/routers/service_connectors_endpoints.py +++ b/src/zenml/zen_server/routers/service_connectors_endpoints.py @@ -67,9 +67,7 @@ def list_service_connectors( make_dependable(ServiceConnectorFilterModel) ), expand_secrets: bool = True, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> Page[ServiceConnectorResponseModel]: """Get a list of all service connectors for a specific type. @@ -107,9 +105,7 @@ def list_service_connectors( def get_service_connector( connector_id: UUID, expand_secrets: bool = True, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> ServiceConnectorResponseModel: """Returns the requested service connector. @@ -156,9 +152,7 @@ def get_service_connector( def update_service_connector( connector_id: UUID, connector_update: ServiceConnectorUpdateModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> ServiceConnectorResponseModel: """Updates a service connector. @@ -197,9 +191,7 @@ def update_service_connector( @handle_exceptions def delete_service_connector( connector_id: UUID, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> None: """Deletes a service connector. @@ -234,7 +226,7 @@ def delete_service_connector( def validate_and_verify_service_connector_config( connector: ServiceConnectorRequestModel, list_resources: bool = True, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> ServiceConnectorResourcesModel: """Verifies if a service connector configuration has access to resources. @@ -268,9 +260,7 @@ def validate_and_verify_service_connector( resource_type: Optional[str] = None, resource_id: Optional[str] = None, list_resources: bool = True, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> ServiceConnectorResourcesModel: """Verifies if a service connector instance has access to one or more resources. @@ -323,9 +313,7 @@ def get_service_connector_client( connector_id: UUID, resource_type: Optional[str] = None, resource_id: Optional[str] = None, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> ServiceConnectorResponseModel: """Get a service connector client for a service connector and given resource. @@ -374,7 +362,7 @@ def list_service_connector_types( connector_type: Optional[str] = None, resource_type: Optional[str] = None, auth_method: Optional[str] = None, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> List[ServiceConnectorTypeModel]: """Get a list of service connector types. @@ -403,7 +391,7 @@ def list_service_connector_types( @handle_exceptions def get_service_connector_type( connector_type: str, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> ServiceConnectorTypeModel: """Returns the requested service connector type. diff --git a/src/zenml/zen_server/routers/stack_components_endpoints.py b/src/zenml/zen_server/routers/stack_components_endpoints.py index fed9b2d2744..6b2813bf416 100644 --- a/src/zenml/zen_server/routers/stack_components_endpoints.py +++ b/src/zenml/zen_server/routers/stack_components_endpoints.py @@ -18,7 +18,7 @@ from fastapi import APIRouter, Depends, Security from zenml.constants import API, COMPONENT_TYPES, STACK_COMPONENTS, VERSION_1 -from zenml.enums import PermissionType, StackComponentType +from zenml.enums import StackComponentType from zenml.models import ( ComponentFilterModel, ComponentResponseModel, @@ -56,9 +56,7 @@ def list_stack_components( component_filter_model: ComponentFilterModel = Depends( make_dependable(ComponentFilterModel) ), - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> Page[ComponentResponseModel]: """Get a list of all stack components for a specific type. @@ -83,7 +81,7 @@ def list_stack_components( @handle_exceptions def get_stack_component( component_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> ComponentResponseModel: """Returns the requested stack component. @@ -105,7 +103,7 @@ def get_stack_component( def update_stack_component( component_id: UUID, component_update: ComponentUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> ComponentResponseModel: """Updates a stack component. @@ -129,7 +127,7 @@ def update_stack_component( @handle_exceptions def deregister_stack_component( component_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a stack component. @@ -146,7 +144,7 @@ def deregister_stack_component( ) @handle_exceptions def get_stack_component_types( - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]) + _: AuthContext = Security(authorize), ) -> List[str]: """Get a list of all stack component types. diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index f40b1ba0e24..6be82b7eddb 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -18,12 +18,13 @@ from fastapi import APIRouter, Depends, Security from zenml.constants import API, STACKS, VERSION_1 -from zenml.enums import PermissionType from zenml.models import StackFilterModel, StackResponseModel, StackUpdateModel from zenml.models.page_model import Page from zenml.zen_server.auth import ( AuthContext, authorize, + dehydrate_response_model, + get_allowed_resource_ids, verify_permissions_for_model, verify_read_permissions_and_dehydrate, ) @@ -51,9 +52,7 @@ def list_stacks( stack_filter_model: StackFilterModel = Depends( make_dependable(StackFilterModel) ), - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> Page[StackResponseModel]: """Returns all stacks. @@ -64,7 +63,16 @@ def list_stacks( Returns: All stacks. """ - return zen_store().list_stacks(stack_filter_model=stack_filter_model) + allowed_ids = get_allowed_resource_ids( + resource_type="stack", action="read" + ) + print(allowed_ids) + stack_filter_model.set_allowed_ids(allowed_ids) + page = zen_store().list_stacks(stack_filter_model=stack_filter_model) + + # TODO: make this better, this is sending a ton of requests here + page.items = [dehydrate_response_model(model) for model in page.items] + return page @router.get( @@ -80,9 +88,7 @@ def list_stacks( @handle_exceptions def get_stack( stack_id: UUID, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> StackResponseModel: """Returns the requested stack. @@ -105,7 +111,7 @@ def get_stack( def update_stack( stack_id: UUID, stack_update: StackUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> StackResponseModel: """Updates a stack. @@ -132,7 +138,7 @@ def update_stack( @handle_exceptions def delete_stack( stack_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a stack. diff --git a/src/zenml/zen_server/routers/steps_endpoints.py b/src/zenml/zen_server/routers/steps_endpoints.py index 6f0bd787e8a..ab62c1a6d5e 100644 --- a/src/zenml/zen_server/routers/steps_endpoints.py +++ b/src/zenml/zen_server/routers/steps_endpoints.py @@ -26,7 +26,7 @@ STEPS, VERSION_1, ) -from zenml.enums import ExecutionStatus, PermissionType +from zenml.enums import ExecutionStatus from zenml.models import ( StepRunFilterModel, StepRunRequestModel, @@ -63,7 +63,7 @@ def list_run_steps( step_run_filter_model: StepRunFilterModel = Depends( make_dependable(StepRunFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[StepRunResponseModel]: """Get run steps according to query filters. @@ -87,7 +87,7 @@ def list_run_steps( @handle_exceptions def create_run_step( step: StepRunRequestModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> StepRunResponseModel: """Create a run step. @@ -108,7 +108,7 @@ def create_run_step( @handle_exceptions def get_step( step_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> StepRunResponseModel: """Get one specific step. @@ -130,7 +130,7 @@ def get_step( def update_step( step_id: UUID, step_model: StepRunUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> StepRunResponseModel: """Updates a step. @@ -154,7 +154,7 @@ def update_step( @handle_exceptions def get_step_configuration( step_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Dict[str, Any]: """Get the configuration of a specific step. @@ -175,7 +175,7 @@ def get_step_configuration( @handle_exceptions def get_step_status( step_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> ExecutionStatus: """Get the status of a specific step. @@ -196,7 +196,7 @@ def get_step_status( @handle_exceptions def get_step_logs( step_id: UUID, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> str: """Get the logs of a specific step. diff --git a/src/zenml/zen_server/routers/users_endpoints.py b/src/zenml/zen_server/routers/users_endpoints.py index 914a6766c14..bac42fdbec4 100644 --- a/src/zenml/zen_server/routers/users_endpoints.py +++ b/src/zenml/zen_server/routers/users_endpoints.py @@ -27,7 +27,7 @@ USERS, VERSION_1, ) -from zenml.enums import AuthScheme, PermissionType +from zenml.enums import AuthScheme from zenml.exceptions import AuthorizationException, IllegalOperationError from zenml.logger import get_logger from zenml.models import ( @@ -83,7 +83,7 @@ def list_users( user_filter_model: UserFilterModel = Depends( make_dependable(UserFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[UserResponseModel]: """Returns a list of all users. @@ -112,7 +112,7 @@ def list_users( @handle_exceptions def create_user( user: UserRequestModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> UserResponseModel: """Creates a user. @@ -152,7 +152,7 @@ def create_user( @handle_exceptions def get_user( user_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> UserResponseModel: """Returns a specific user. @@ -182,7 +182,7 @@ def get_user( def update_user( user_name_or_id: Union[str, UUID], user_update: UserUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> UserResponseModel: """Updates a specific user. @@ -247,7 +247,7 @@ def activate_user( @handle_exceptions def deactivate_user( user_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> UserResponseModel: """Deactivates a user and generates a new activation token for it. @@ -282,9 +282,7 @@ def deactivate_user( @handle_exceptions def delete_user( user_name_or_id: Union[str, UUID], - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> None: """Deletes a specific user. @@ -319,9 +317,7 @@ def delete_user( def email_opt_in_response( user_name_or_id: Union[str, UUID], user_response: UserUpdateModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.ME] - ), + auth_context: AuthContext = Security(authorize), ) -> UserResponseModel: """Sets the response of the user to the email prompt. @@ -369,9 +365,7 @@ def email_opt_in_response( ) @handle_exceptions def get_current_user( - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> UserResponseModel: """Returns the model of the authenticated user. @@ -400,9 +394,7 @@ def get_current_user( @handle_exceptions def update_myself( user: UserUpdateModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.ME] - ), + auth_context: AuthContext = Security(authorize), ) -> UserResponseModel: """Updates a specific user. diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index cd3628b1223..783135f94b3 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -39,7 +39,6 @@ VERSION_1, WORKSPACES, ) -from zenml.enums import PermissionType from zenml.exceptions import IllegalOperationError from zenml.models import ( CodeRepositoryFilterModel, @@ -116,7 +115,7 @@ def list_workspaces( workspace_filter_model: WorkspaceFilterModel = Depends( make_dependable(WorkspaceFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[WorkspaceResponseModel]: """Lists all workspaces in the organization. @@ -140,7 +139,7 @@ def list_workspaces( @handle_exceptions def create_workspace( workspace: WorkspaceRequestModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> WorkspaceResponseModel: """Creates a workspace based on the requestBody. @@ -163,7 +162,7 @@ def create_workspace( @handle_exceptions def get_workspace( workspace_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> WorkspaceResponseModel: """Get a workspace for given name. @@ -187,7 +186,7 @@ def get_workspace( def update_workspace( workspace_name_or_id: UUID, workspace_update: WorkspaceUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> WorkspaceResponseModel: """Get a workspace for given name. @@ -213,7 +212,7 @@ def update_workspace( @handle_exceptions def delete_workspace( workspace_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes a workspace. @@ -234,9 +233,7 @@ def list_workspace_stacks( stack_filter_model: StackFilterModel = Depends( make_dependable(StackFilterModel) ), - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> Page[StackResponseModel]: """Get stacks that are part of a specific workspace for the user. @@ -264,9 +261,7 @@ def list_workspace_stacks( def create_stack( workspace_name_or_id: Union[str, UUID], stack: StackRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> StackResponseModel: """Creates a stack for a particular workspace. @@ -310,9 +305,7 @@ def list_workspace_stack_components( component_filter_model: ComponentFilterModel = Depends( make_dependable(ComponentFilterModel) ), - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> Page[ComponentResponseModel]: """List stack components that are part of a specific workspace. @@ -343,9 +336,7 @@ def list_workspace_stack_components( def create_stack_component( workspace_name_or_id: Union[str, UUID], component: ComponentRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> ComponentResponseModel: """Creates a stack component. @@ -393,7 +384,7 @@ def list_workspace_pipelines( pipeline_filter_model: PipelineFilterModel = Depends( make_dependable(PipelineFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[PipelineResponseModel]: """Gets pipelines defined for a specific workspace. @@ -423,9 +414,7 @@ def list_workspace_pipelines( def create_pipeline( workspace_name_or_id: Union[str, UUID], pipeline: PipelineRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> PipelineResponseModel: """Creates a pipeline. @@ -469,7 +458,7 @@ def list_workspace_builds( build_filter_model: PipelineBuildFilterModel = Depends( make_dependable(PipelineBuildFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[PipelineBuildResponseModel]: """Gets builds defined for a specific workspace. @@ -497,9 +486,7 @@ def list_workspace_builds( def create_build( workspace_name_or_id: Union[str, UUID], build: PipelineBuildRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> PipelineBuildResponseModel: """Creates a build. @@ -543,7 +530,7 @@ def list_workspace_deployments( deployment_filter_model: PipelineDeploymentFilterModel = Depends( make_dependable(PipelineDeploymentFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[PipelineDeploymentResponseModel]: """Gets deployments defined for a specific workspace. @@ -573,9 +560,7 @@ def list_workspace_deployments( def create_deployment( workspace_name_or_id: Union[str, UUID], deployment: PipelineDeploymentRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> PipelineDeploymentResponseModel: """Creates a deployment. @@ -620,7 +605,7 @@ def list_runs( runs_filter_model: PipelineRunFilterModel = Depends( make_dependable(PipelineRunFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[PipelineRunResponseModel]: """Get pipeline runs according to query filters. @@ -647,9 +632,7 @@ def list_runs( def create_schedule( workspace_name_or_id: Union[str, UUID], schedule: ScheduleRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> ScheduleResponseModel: """Creates a schedule. @@ -690,9 +673,7 @@ def create_schedule( def create_pipeline_run( workspace_name_or_id: Union[str, UUID], pipeline_run: PipelineRunRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), get_if_exists: bool = False, ) -> PipelineRunResponseModel: """Creates a pipeline run. @@ -740,9 +721,7 @@ def create_pipeline_run( def get_or_create_pipeline_run( workspace_name_or_id: Union[str, UUID], pipeline_run: PipelineRunRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> Tuple[PipelineRunResponseModel, bool]: """Get or create a pipeline run. @@ -784,9 +763,7 @@ def get_or_create_pipeline_run( def create_run_metadata( workspace_name_or_id: Union[str, UUID], run_metadata: RunMetadataRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> List[RunMetadataResponseModel]: """Creates run metadata. @@ -829,9 +806,7 @@ def create_run_metadata( def create_secret( workspace_name_or_id: Union[str, UUID], secret: SecretRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> SecretResponseModel: """Creates a secret. @@ -874,7 +849,7 @@ def list_workspace_code_repositories( filter_model: CodeRepositoryFilterModel = Depends( make_dependable(CodeRepositoryFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[CodeRepositoryResponseModel]: """Gets code repositories defined for a specific workspace. @@ -902,9 +877,7 @@ def list_workspace_code_repositories( def create_code_repository( workspace_name_or_id: Union[str, UUID], code_repository: CodeRepositoryRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> CodeRepositoryResponseModel: """Creates a code repository. @@ -946,7 +919,7 @@ def create_code_repository( @handle_exceptions def get_workspace_statistics( workspace_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Dict[str, int]: """Gets statistics of a workspace. @@ -981,9 +954,7 @@ def list_workspace_service_connectors( connector_filter_model: ServiceConnectorFilterModel = Depends( make_dependable(ServiceConnectorFilterModel) ), - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> Page[ServiceConnectorResponseModel]: """List service connectors that are part of a specific workspace. @@ -1015,9 +986,7 @@ def list_workspace_service_connectors( def create_service_connector( workspace_name_or_id: Union[str, UUID], connector: ServiceConnectorRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> ServiceConnectorResponseModel: """Creates a service connector. @@ -1065,9 +1034,7 @@ def list_service_connector_resources( connector_type: Optional[str] = None, resource_type: Optional[str] = None, resource_id: Optional[str] = None, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.READ] - ), + auth_context: AuthContext = Security(authorize), ) -> List[ServiceConnectorResourcesModel]: """List resources that can be accessed by service connectors. @@ -1100,9 +1067,7 @@ def list_service_connector_resources( def create_model( workspace_name_or_id: Union[str, UUID], model: ModelRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> ModelResponseModel: """Create a new model. @@ -1146,7 +1111,7 @@ def list_workspace_models( model_filter_model: ModelFilterModel = Depends( make_dependable(ModelFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[ModelResponseModel]: """Get models according to query filters. @@ -1180,9 +1145,7 @@ def create_model_version( workspace_name_or_id: Union[str, UUID], model_name_or_id: Union[str, UUID], model_version: ModelVersionRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> ModelVersionResponseModel: """Create a new model version. @@ -1228,7 +1191,7 @@ def list_workspace_model_versions( model_version_filter_model: ModelVersionFilterModel = Depends( make_dependable(ModelVersionFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[ModelVersionResponseModel]: """Get model versions according to query filters. @@ -1264,9 +1227,7 @@ def create_model_version_artifact_link( model_name_or_id: Union[str, UUID], model_version_name_or_id: Union[str, UUID], model_version_artifact_link: ModelVersionArtifactRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> ModelVersionArtifactResponseModel: """Create a new model version to artifact link. @@ -1321,7 +1282,7 @@ def list_workspace_model_version_artifact_links( model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( make_dependable(ModelVersionArtifactFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[ModelVersionArtifactResponseModel]: """Get model version to artifact links according to query filters. @@ -1359,9 +1320,7 @@ def create_model_version_pipeline_run_link( model_name_or_id: Union[str, UUID], model_version_name_or_id: Union[str, UUID], model_version_pipeline_run_link: ModelVersionPipelineRunRequestModel, - auth_context: AuthContext = Security( - authorize, scopes=[PermissionType.WRITE] - ), + auth_context: AuthContext = Security(authorize), ) -> ModelVersionPipelineRunResponseModel: """Create a new model version to pipeline run link. @@ -1417,7 +1376,7 @@ def list_workspace_model_version_pipeline_run_links( model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( make_dependable(ModelVersionPipelineRunFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[ModelVersionPipelineRunResponseModel]: """Get model version to pipeline links according to query filters. diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 789be3ea89c..0a70c74ac98 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -29,7 +29,6 @@ Dict, List, Optional, - Set, Tuple, Type, TypeVar, @@ -717,7 +716,6 @@ def filter_and_paginate( List[AnySchema], ] ] = None, - resource_ids: Optional[Set[UUID]] = None, ) -> Page[B]: """Given a query, return a Page instance with a list of filtered Models. @@ -746,9 +744,6 @@ def filter_and_paginate( """ query = filter_model.apply_filter(query=query, table=table) - if resource_ids: - query = query.where(table.id.in_(resource_ids)) - # Get the total amount of items in the database for a given query if custom_fetch: total = len(custom_fetch(session, query, filter_model)) From 41a6f5d7ef7eb06e60df8a723d02b13e122bf3f4 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 17 Oct 2023 13:57:46 +0200 Subject: [PATCH 007/103] Validate source class --- src/zenml/zen_server/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index a24003aafef..9465240125c 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -80,7 +80,9 @@ def initialize_rbac() -> None: if rbac_source := server_config().rbac_implementation_source: from zenml.utils import source_utils - implementation_class = source_utils.load(rbac_source) + implementation_class = source_utils.load_and_validate_class( + rbac_source, expected_class=RBACInterface + ) _rbac = implementation_class() From 738ac04e9fc7b839c2dfa31970c069b7cac92434 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 17 Oct 2023 15:01:19 +0200 Subject: [PATCH 008/103] Some cleanup and docstrings --- src/zenml/zen_server/auth.py | 182 +---------- src/zenml/zen_server/rbac/__init__.py | 14 + src/zenml/zen_server/rbac/models.py | 63 ++++ src/zenml/zen_server/rbac/rbac_interface.py | 61 ++++ src/zenml/zen_server/rbac/utils.py | 287 ++++++++++++++++++ src/zenml/zen_server/rbac_interface.py | 102 ------- .../zen_server/routers/pipelines_endpoints.py | 4 +- .../zen_server/routers/stacks_endpoints.py | 16 +- 8 files changed, 438 insertions(+), 291 deletions(-) create mode 100644 src/zenml/zen_server/rbac/__init__.py create mode 100644 src/zenml/zen_server/rbac/models.py create mode 100644 src/zenml/zen_server/rbac/rbac_interface.py create mode 100644 src/zenml/zen_server/rbac/utils.py delete mode 100644 src/zenml/zen_server/rbac_interface.py diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 17464dec6e5..ab967313239 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -15,8 +15,7 @@ from contextvars import ContextVar from datetime import datetime -from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Callable, List, Optional, Set, Union from urllib.parse import urlencode from uuid import UUID @@ -50,14 +49,8 @@ UserResponseModel, UserUpdateModel, ) -from zenml.models.base_models import BaseResponseModel, UserScopedResponseModel -from zenml.models.user_models import UserAuthModel from zenml.zen_server.jwt import JWTToken -from zenml.zen_server.rbac_interface import ( - Resource, - get_resource_type_for_model, -) -from zenml.zen_server.utils import rbac, server_config, zen_store +from zenml.zen_server.utils import server_config, zen_store from zenml.zen_stores.base_zen_store import DEFAULT_USERNAME logger = get_logger(__name__) @@ -665,174 +658,3 @@ def authentication_provider() -> Callable[..., AuthContext]: authorize = authentication_provider() - - -def verify_read_permissions_and_dehydrate( - model: "BaseResponseModel", -) -> "BaseResponseModel": - if not server_config().rbac_enabled: - return model - - verify_permissions_for_model(model=model, action="READ") - - return dehydrate_response_model(model=model) - - -def dehydrate_response_model( - model: "BaseResponseModel", -) -> "BaseResponseModel": - dehydrated_fields = {} - - for field_name in model.__fields__.keys(): - value = getattr(model, field_name) - dehydrated_fields[field_name] = _maybe_dehydrate_value(value) - - return type(model).parse_obj(dehydrated_fields) - - -def _maybe_dehydrate_value(value: Any) -> Any: - if isinstance(value, BaseResponseModel): - if has_read_permissions_for_model(value): - return dehydrate_response_model(value) - else: - return get_403_model(value) - elif isinstance(value, Dict): - return {k: _maybe_dehydrate_value(v) for k, v in value.items()} - elif isinstance(value, (List, Set, Tuple)): - type_ = type(value) - return type_(_maybe_dehydrate_value(v) for v in value) - else: - return value - - -def has_read_permissions_for_model(model: "BaseResponseModel") -> bool: - try: - verify_permissions_for_model(model=model, action="READ") - return True - except HTTPException: - return False - - -def get_403_model( - model: "BaseResponseModel", keep_name: bool = True -) -> "BaseResponseModel": - values = {} - - for field_name, field in model.__fields__.items(): - value = getattr(model, field_name) - - if keep_name and field_name == "name" and isinstance(value, str): - pass - elif field.allow_none: - value = None - elif isinstance(value, BaseResponseModel): - value = get_403_model(value, keep_name=False) - elif isinstance(value, UUID): - value = UUID(int=0) - elif isinstance(value, datetime): - value = datetime.utcnow() - elif isinstance(value, Enum): - # TODO: handle enums in a more sensible way - value = list(type(value))[0] - else: - type_ = type(value) - # For the remaining cases (dict, list, set, tuple, int, float, str), - # simply return an empty value - value = type_() - - values[field_name] = value - - # TODO: With the new hydration models, make sure we clear metadata here - values["missing_permissions"] = True - - return type(model).parse_obj(values) - - -def verify_permissions_for_model( - model: "BaseResponseModel", - action: str, -) -> None: - """Verifies if a user has permissions to perform an action on a model. - - Args: - model: The model the user wants to perform the action on. - action: The action the user wants to perform. - """ - if not server_config().rbac_enabled: - return - - if ( - isinstance(model, UserScopedResponseModel) - and model.user - and model.user.id == get_auth_context().user.id - ): - # User is the owner of the model - return - - resource_type = get_resource_type_for_model(model) - if not resource_type: - # This model is not tied to any RBAC resource type and therefore doesn't - # require any special permissions - return - - verify_permissions( - resource_type=resource_type, resource_id=model.id, action=action - ) - - -def verify_permissions( - resource_type: str, - action: str, - resource_id: Optional[UUID] = None, -) -> None: - """Verifies if a user has permissions to perform an action on a resource. - - Args: - resource: The resource type the user wants to perform the action on. - action: The action the user wants to perform. - resource_id: ID of the resource the user wants to perform the action on. - - Raises: - HTTPException: If the user is not allowed to perform the action. - """ - if not server_config().rbac_enabled: - return - - resource = Resource(type=resource_type, id=resource_id) - - if not rbac().has_permission( - user=get_auth_context().user, resource=resource, action=action - ): - raise HTTPException(status_code=403) - - -def get_allowed_resource_ids( - resource_type: str, - action: str, -) -> Optional[List[UUID]]: - """Get all resource IDs of a resource type that a user can access. - - Args: - resource_type: The resource type. - action: The action the user wants to perform on the resource. - - Returns: - A list of resource IDs or `None` if the user has full access to the - all instances of the resource. - """ - if not server_config().rbac_enabled: - return None - - ( - has_full_resource_access, - allowed_ids, - ) = rbac().list_allowed_resource_ids( - user=get_auth_context().user, - resource=Resource(type=resource_type), - action=action, - ) - - if has_full_resource_access: - return None - - return [UUID(id) for id in allowed_ids] diff --git a/src/zenml/zen_server/rbac/__init__.py b/src/zenml/zen_server/rbac/__init__.py new file mode 100644 index 00000000000..edbfdff9b40 --- /dev/null +++ b/src/zenml/zen_server/rbac/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""RBAC definitions.""" diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py new file mode 100644 index 00000000000..9bcf3404329 --- /dev/null +++ b/src/zenml/zen_server/rbac/models.py @@ -0,0 +1,63 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""RBAC model classes.""" + +from typing import Optional +from uuid import UUID + +from pydantic import BaseModel + +from zenml.enums import StrEnum + + +class Action(StrEnum): + """RBAC actions.""" + + CREATE = "create" + READ = "read" + UPDATE = "update" + DELETE = "delete" + + +class ResourceType(StrEnum): + """Resource types of the server API.""" + + STACK = "stack" + FLAVOR = "flavor" + STACK_COMPONENT = "stack_component" + PIPELINE = "pipeline" + CODE_REPOSITORY = "code-repository" + MODEL = "model" + SERVICE_CONNECTOR = "service_connector" + ARTIFACT = "artifact" + SECRET = "secret" + + +class Resource(BaseModel): + """RBAC resource model.""" + + type: str + id: Optional[UUID] = None + + def __str__(self) -> str: + """Convert to a string. + + Returns: + Resource string representation. + """ + representation = self.type + if self.id: + representation += f"/{self.id}" + + return representation diff --git a/src/zenml/zen_server/rbac/rbac_interface.py b/src/zenml/zen_server/rbac/rbac_interface.py new file mode 100644 index 00000000000..b71ed7702a5 --- /dev/null +++ b/src/zenml/zen_server/rbac/rbac_interface.py @@ -0,0 +1,61 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""RBAC interface definition.""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Tuple + +from zenml.zen_server.rbac.models import Resource + +if TYPE_CHECKING: + from zenml.models import UserResponseModel + + +class RBACInterface(ABC): + """RBAC interface definition.""" + + @abstractmethod + def has_permission( + self, user: "UserResponseModel", resource: Resource, action: str + ) -> bool: + """Checks if a user has permission to perform an action on a resource. + + Args: + user: User which wants to access a resource. + resource: The resource the user wants to access. + action: The action that the user wants to perform on the resource. + + Returns: + Whether the user has permission to perform an action on a resource. + """ + + @abstractmethod + def list_allowed_resource_ids( + self, user: "UserResponseModel", resource: Resource, action: str + ) -> Tuple[bool, List[str]]: + """Lists all resource IDs of a resource type that a user can access. + + Args: + user: User which wants to access a resource. + resource: The resource the user wants to access. + action: The action that the user wants to perform on the resource. + + Returns: + A tuple (full_resource_access, resource_ids). + `full_resource_access` will be `True` if the user can perform the + given action on any instance of the given resource type, `False` + otherwise. If `full_resource_access` is `False`, `resource_ids` + will contain the list of instance IDs that the user can perform + the action on. + """ diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py new file mode 100644 index 00000000000..f23eb5fb026 --- /dev/null +++ b/src/zenml/zen_server/rbac/utils.py @@ -0,0 +1,287 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""RBAC utility functions.""" + +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional, Set, Tuple, Type +from uuid import UUID + +from fastapi import HTTPException + +from zenml.models.base_models import BaseResponseModel, UserScopedResponseModel +from zenml.zen_server.auth import get_auth_context +from zenml.zen_server.rbac.models import Action, Resource, ResourceType +from zenml.zen_server.utils import rbac, server_config + + +def verify_read_permissions_and_dehydrate( + model: "BaseResponseModel", +) -> "BaseResponseModel": + """Verify read permissions of the model and dehydrate it if necessary. + + Args: + model: The model for which to verify permissions. + + Returns: + The (potentially) dehydrated model. + """ + if not server_config().rbac_enabled: + return model + + verify_permissions_for_model(model=model, action=Action.READ) + + return dehydrate_response_model(model=model) + + +def dehydrate_response_model( + model: "BaseResponseModel", +) -> "BaseResponseModel": + """Dehydrate a model if necessary. + + Args: + model: The model to dehydrate. + + Returns: + The (potentially) dehydrated model. + """ + dehydrated_fields = {} + + for field_name in model.__fields__.keys(): + value = getattr(model, field_name) + dehydrated_fields[field_name] = _maybe_dehydrate_value(value) + + return type(model).parse_obj(dehydrated_fields) + + +def _maybe_dehydrate_value(value: Any) -> Any: + """Helper function to recursive dehydrate any object. + + Args: + value: The value to dehydrate. + + Returns: + The recursively dehydrated value. + """ + if isinstance(value, BaseResponseModel): + if has_permissions_for_model(model=value, action=Action.READ): + return dehydrate_response_model(value) + else: + return get_permission_denied_model(value) + elif isinstance(value, Dict): + return {k: _maybe_dehydrate_value(v) for k, v in value.items()} + elif isinstance(value, (List, Set, Tuple)): + type_ = type(value) + return type_(_maybe_dehydrate_value(v) for v in value) + else: + return value + + +def has_permissions_for_model(model: "BaseResponseModel", action: str) -> bool: + """If the active user has permissions to perform the action on the model. + + Args: + model: The model the user wants to perform the action on. + action: The action the user wants to perform. + + Returns: + If the active user has permissions to perform the action on the model. + """ + try: + verify_permissions_for_model(model=model, action=action) + return True + except HTTPException: + return False + + +def get_permission_denied_model( + model: "BaseResponseModel", keep_name: bool = True +) -> "BaseResponseModel": + """Get a model to return in case of missing read permissions. + + This function replaces all attributes except name and ID in the given model. + + Args: + model: The original model. + keep_name: If `True`, the model name will not be replaced. + + Returns: + The model with attribute values replaced by default values. + """ + values = {} + + for field_name, field in model.__fields__.items(): + value = getattr(model, field_name) + + if field_name == "id" and isinstance(value, UUID): + pass + elif keep_name and field_name == "name" and isinstance(value, str): + pass + elif field.allow_none: + value = None + elif isinstance(value, BaseResponseModel): + value = get_permission_denied_model(value, keep_name=False) + elif isinstance(value, UUID): + value = UUID(int=0) + elif isinstance(value, datetime): + value = datetime.utcnow() + elif isinstance(value, Enum): + # TODO: handle enums in a more sensible way + value = list(type(value))[0] + else: + type_ = type(value) + # For the remaining cases (dict, list, set, tuple, int, float, str), + # simply return an empty value + value = type_() + + values[field_name] = value + + # TODO: With the new hydration models, make sure we clear metadata here + values["missing_permissions"] = True + + return type(model).parse_obj(values) + + +def verify_permissions_for_model( + model: "BaseResponseModel", + action: str, +) -> None: + """Verifies if a user has permissions to perform an action on a model. + + Args: + model: The model the user wants to perform the action on. + action: The action the user wants to perform. + """ + if not server_config().rbac_enabled: + return + + if ( + isinstance(model, UserScopedResponseModel) + and model.user + and model.user.id == get_auth_context().user.id + ): + # User is the owner of the model + return + + resource_type = get_resource_type_for_model(model) + if not resource_type: + # This model is not tied to any RBAC resource type and therefore doesn't + # require any special permissions + return + + verify_permissions( + resource_type=resource_type, resource_id=model.id, action=action + ) + + +def verify_permissions( + resource_type: str, + action: str, + resource_id: Optional[UUID] = None, +) -> None: + """Verifies if a user has permissions to perform an action on a resource. + + Args: + resource_type: The type of resource that the user wants to perform the + action on. + action: The action the user wants to perform. + resource_id: ID of the resource the user wants to perform the action on. + + Raises: + HTTPException: If the user is not allowed to perform the action. + """ + if not server_config().rbac_enabled: + return + + resource = Resource(type=resource_type, id=resource_id) + + if not rbac().has_permission( + user=get_auth_context().user, resource=resource, action=action + ): + raise HTTPException( + status_code=403, + detail=f"Insufficient permissions to {action.upper()} resource " + f"'{resource}'.", + ) + + +def get_allowed_resource_ids( + resource_type: str, + action: str, +) -> Optional[List[UUID]]: + """Get all resource IDs of a resource type that a user can access. + + Args: + resource_type: The resource type. + action: The action the user wants to perform on the resource. + + Returns: + A list of resource IDs or `None` if the user has full access to the + all instances of the resource. + """ + if not server_config().rbac_enabled: + return None + + ( + has_full_resource_access, + allowed_ids, + ) = rbac().list_allowed_resource_ids( + user=get_auth_context().user, + resource=Resource(type=resource_type), + action=action, + ) + + if has_full_resource_access: + return None + + return [UUID(id) for id in allowed_ids] + + +def get_resource_type_for_model( + model: "BaseResponseModel", +) -> Optional[ResourceType]: + """Get the resource type associated with a model object. + + Args: + model: The model for which to get the resource type. + + Returns: + The resource type associated with the model, or `None` if the model + is not associated with any resource type. + """ + from zenml.models import ( + ArtifactResponseModel, + CodeRepositoryResponseModel, + ComponentResponseModel, + FlavorResponseModel, + ModelResponseModel, + PipelineResponseModel, + SecretResponseModel, + ServiceConnectorResponseModel, + StackResponseModel, + ) + + mapping: Dict[Type[BaseResponseModel], ResourceType] = { + FlavorResponseModel: ResourceType.FLAVOR, + ServiceConnectorResponseModel: ResourceType.SERVICE_CONNECTOR, + ComponentResponseModel: ResourceType.STACK_COMPONENT, + StackResponseModel: ResourceType.STACK, + PipelineResponseModel: ResourceType.PIPELINE, + CodeRepositoryResponseModel: ResourceType.CODE_REPOSITORY, + SecretResponseModel: ResourceType.SECRET, + ModelResponseModel: ResourceType.MODEL, + ArtifactResponseModel: ResourceType.ARTIFACT, + } + + return mapping.get(type(model)) diff --git a/src/zenml/zen_server/rbac_interface.py b/src/zenml/zen_server/rbac_interface.py deleted file mode 100644 index fb9e0b83247..00000000000 --- a/src/zenml/zen_server/rbac_interface.py +++ /dev/null @@ -1,102 +0,0 @@ -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type -from uuid import UUID - -from pydantic import BaseModel - -from zenml.enums import StrEnum -from zenml.models.base_models import BaseResponseModel - -if TYPE_CHECKING: - from zenml.models import UserResponseModel - - -class Action(StrEnum): - CREATE = "create" - READ = "read" - UPDATE = "update" - DELETE = "delete" - - -class ResourceType(StrEnum): - STACK = "stack" - FLAVOR = "flavor" - STACK_COMPONENT = "stack_component" - PIPELINE = "pipeline" - CODE_REPOSITORY = "code-repository" - MODEL = "model" - SERVICE_CONNECTOR = "service_connector" - ARTIFACT = "artifact" - SECRET = "secret" - - -def get_resource_type_for_model( - model: "BaseResponseModel", -) -> Optional[ResourceType]: - from zenml.models import ( - ArtifactResponseModel, - CodeRepositoryResponseModel, - ComponentResponseModel, - FlavorResponseModel, - ModelResponseModel, - PipelineResponseModel, - SecretResponseModel, - ServiceConnectorResponseModel, - StackResponseModel, - ) - - mapping: Dict[Type[BaseResponseModel], ResourceType] = { - FlavorResponseModel: ResourceType.FLAVOR, - ServiceConnectorResponseModel: ResourceType.SERVICE_CONNECTOR, - ComponentResponseModel: ResourceType.STACK_COMPONENT, - StackResponseModel: ResourceType.STACK, - PipelineResponseModel: ResourceType.PIPELINE, - CodeRepositoryResponseModel: ResourceType.CODE_REPOSITORY, - SecretResponseModel: ResourceType.SECRET, - ModelResponseModel: ResourceType.MODEL, - ArtifactResponseModel: ResourceType.ARTIFACT, - } - - return mapping.get(type(model)) - - -class Resource(BaseModel): - type: str - id: Optional[UUID] = None - - -class RBACInterface(ABC): - @abstractmethod - def has_permission( - self, user: "UserResponseModel", resource: Resource, action: str - ) -> bool: - """Checks if a user has permission to perform an action on a resource. - - Args: - user: User which wants to access a resource. - resource: The resource the user wants to access. - action: The action that the user wants to perform on the resource. - - Returns: - Whether the user has permission to perform an action on a resource. - """ - - @abstractmethod - def list_allowed_resource_ids( - self, user: "UserResponseModel", resource: Resource, action: str - ) -> Tuple[bool, List[str]]: - """Lists all resource IDs of a resource type that a user can access. - - Args: - user: User which wants to access a resource. - resource: The resource the user wants to access. - action: The action that the user wants to perform on the resource. - - Returns: - A tuple (full_resource_access, resource_ids). - `full_resource_access` will be `True` if the user can perform the - given action on any instance of the given resource type, `False` - otherwise. If `full_resource_access` is `False`, `resource_ids` - will contain the list of instance IDs that the user can perform - the action on. - """ diff --git a/src/zenml/zen_server/routers/pipelines_endpoints.py b/src/zenml/zen_server/routers/pipelines_endpoints.py index f033990548e..d9136f85b4f 100644 --- a/src/zenml/zen_server/routers/pipelines_endpoints.py +++ b/src/zenml/zen_server/routers/pipelines_endpoints.py @@ -56,7 +56,7 @@ def list_pipelines( pipeline_filter_model: PipelineFilterModel = Depends( make_dependable(PipelineFilterModel) ), - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> Page[PipelineResponseModel]: """Gets a list of pipelines. @@ -85,7 +85,7 @@ def list_pipelines( @handle_exceptions def get_pipeline( pipeline_id: UUID, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> PipelineResponseModel: """Gets a specific pipeline using its unique id. diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index 6be82b7eddb..c50f9a2fb85 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -23,12 +23,15 @@ from zenml.zen_server.auth import ( AuthContext, authorize, +) +from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import ( dehydrate_response_model, get_allowed_resource_ids, verify_permissions_for_model, verify_read_permissions_and_dehydrate, ) -from zenml.zen_server.exceptions import error_response from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -38,7 +41,7 @@ router = APIRouter( prefix=API + VERSION_1 + STACKS, tags=["stacks"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) @@ -64,9 +67,8 @@ def list_stacks( All stacks. """ allowed_ids = get_allowed_resource_ids( - resource_type="stack", action="read" + resource_type=ResourceType.STACK, action=Action.READ ) - print(allowed_ids) stack_filter_model.set_allowed_ids(allowed_ids) page = zen_store().list_stacks(stack_filter_model=stack_filter_model) @@ -88,7 +90,7 @@ def list_stacks( @handle_exceptions def get_stack( stack_id: UUID, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> StackResponseModel: """Returns the requested stack. @@ -123,7 +125,7 @@ def update_stack( The updated stack. """ stack = zen_store().get_stack(stack_id) - verify_permissions_for_model(stack, action="update") + verify_permissions_for_model(stack, action=Action.UPDATE) return zen_store().update_stack( stack_id=stack_id, @@ -146,6 +148,6 @@ def delete_stack( stack_id: Name of the stack. """ stack = zen_store().get_stack(stack_id) - verify_permissions_for_model(stack, action="delete") + verify_permissions_for_model(stack, action=Action.DELETE) zen_store().delete_stack(stack_id) From 4e98925a85621a36f9be3f678954c30e7cbef06f Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 17 Oct 2023 17:33:58 +0200 Subject: [PATCH 009/103] More cleanup --- src/zenml/zen_server/rbac/utils.py | 2 +- src/zenml/zen_server/routers/stacks_endpoints.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index f23eb5fb026..e423a972896 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -218,7 +218,7 @@ def verify_permissions( def get_allowed_resource_ids( resource_type: str, - action: str, + action: str = Action.READ, ) -> Optional[List[UUID]]: """Get all resource IDs of a resource type that a user can access. diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index c50f9a2fb85..768187b74a2 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -55,20 +55,17 @@ def list_stacks( stack_filter_model: StackFilterModel = Depends( make_dependable(StackFilterModel) ), - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> Page[StackResponseModel]: """Returns all stacks. Args: stack_filter_model: Filter model used for pagination, sorting, filtering - auth_context: Authentication Context Returns: All stacks. """ - allowed_ids = get_allowed_resource_ids( - resource_type=ResourceType.STACK, action=Action.READ - ) + allowed_ids = get_allowed_resource_ids(resource_type=ResourceType.STACK) stack_filter_model.set_allowed_ids(allowed_ids) page = zen_store().list_stacks(stack_filter_model=stack_filter_model) From d19ea40fb19c0dfe385483551dd7940ff60f668d Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 18 Oct 2023 16:33:11 +0200 Subject: [PATCH 010/103] Fix import --- src/zenml/zen_server/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index 9465240125c..8e5eaf56bc1 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -34,7 +34,7 @@ LocalServerDeploymentConfig, ) from zenml.zen_server.exceptions import http_exception_from_error -from zenml.zen_server.rbac_interface import RBACInterface +from zenml.zen_server.rbac.rbac_interface import RBACInterface from zenml.zen_stores.sql_zen_store import SqlZenStore logger = get_logger(__name__) From 0ff6a8e5c3a00b23e9b6a163f96e8f113e9ed187 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 19 Oct 2023 08:52:25 +0200 Subject: [PATCH 011/103] Fix some mypy issues --- src/zenml/cli/stack_components.py | 2 +- src/zenml/models/base_models.py | 8 ++--- src/zenml/models/filter_models.py | 2 +- src/zenml/zen_server/rbac/models.py | 2 +- src/zenml/zen_server/rbac/utils.py | 33 ++++++++++++------- .../zen_stores/schemas/pipeline_schemas.py | 9 ----- src/zenml/zen_stores/sql_zen_store.py | 17 +++++++--- 7 files changed, 40 insertions(+), 33 deletions(-) diff --git a/src/zenml/cli/stack_components.py b/src/zenml/cli/stack_components.py index 538ec9e8f3a..c6737651aca 100644 --- a/src/zenml/cli/stack_components.py +++ b/src/zenml/cli/stack_components.py @@ -199,7 +199,7 @@ def list_stack_components_command( def generate_stack_component_register_command( component_type: StackComponentType, -) -> Callable[[str, str, bool, List[str]], None]: +) -> Callable[[str, str, List[str]], None]: """Generates a `register` command for the specific stack component type. Args: diff --git a/src/zenml/models/base_models.py b/src/zenml/models/base_models.py index caae7a61d1d..7ab2f9b7d23 100644 --- a/src/zenml/models/base_models.py +++ b/src/zenml/models/base_models.py @@ -19,7 +19,6 @@ Dict, List, Set, - Tuple, Type, TypeVar, Union, @@ -135,12 +134,13 @@ def partial(self) -> bool: def _helper(value: Any) -> bool: if isinstance(value, BaseResponseModel): - if value.partial: - return True + return value.partial elif isinstance(value, Dict): return any(_helper(v) for v in value.values()) - elif isinstance(value, (List, Set, Tuple)): + elif isinstance(value, (List, Set, tuple)): return any(_helper(v) for v in value) + else: + return False for field_name in self.__fields__.keys(): value = getattr(self, field_name) diff --git a/src/zenml/models/filter_models.py b/src/zenml/models/filter_models.py index 5ba2324d4e1..ba2ae3e049d 100644 --- a/src/zenml/models/filter_models.py +++ b/src/zenml/models/filter_models.py @@ -768,7 +768,7 @@ def apply_filter( The query with filter applied. """ if self._allowed_ids is not None: - query = query.where(table.id.in_(self._allowed_ids)) + query = query.where(table.id.in_(self._allowed_ids)) # type: ignore[attr-defined] filters = self.generate_filter(table=table) diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index 9bcf3404329..39bd5634a0a 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -18,7 +18,7 @@ from pydantic import BaseModel -from zenml.enums import StrEnum +from zenml.utils.enum_utils import StrEnum class Action(StrEnum): diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index e423a972896..d0659b12f7f 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -15,7 +15,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Set, Tuple, Type +from typing import Any, Dict, List, Optional, Set, Type, TypeVar from uuid import UUID from fastapi import HTTPException @@ -25,10 +25,12 @@ from zenml.zen_server.rbac.models import Action, Resource, ResourceType from zenml.zen_server.utils import rbac, server_config +M = TypeVar("M", bound=BaseResponseModel) + def verify_read_permissions_and_dehydrate( - model: "BaseResponseModel", -) -> "BaseResponseModel": + model: M, +) -> M: """Verify read permissions of the model and dehydrate it if necessary. Args: @@ -46,8 +48,8 @@ def verify_read_permissions_and_dehydrate( def dehydrate_response_model( - model: "BaseResponseModel", -) -> "BaseResponseModel": + model: M, +) -> M: """Dehydrate a model if necessary. Args: @@ -81,7 +83,7 @@ def _maybe_dehydrate_value(value: Any) -> Any: return get_permission_denied_model(value) elif isinstance(value, Dict): return {k: _maybe_dehydrate_value(v) for k, v in value.items()} - elif isinstance(value, (List, Set, Tuple)): + elif isinstance(value, (List, Set, tuple)): type_ = type(value) return type_(_maybe_dehydrate_value(v) for v in value) else: @@ -105,9 +107,7 @@ def has_permissions_for_model(model: "BaseResponseModel", action: str) -> bool: return False -def get_permission_denied_model( - model: "BaseResponseModel", keep_name: bool = True -) -> "BaseResponseModel": +def get_permission_denied_model(model: M, keep_name: bool = True) -> M: """Get a model to return in case of missing read permissions. This function replaces all attributes except name and ID in the given model. @@ -166,10 +166,13 @@ def verify_permissions_for_model( if not server_config().rbac_enabled: return + auth_context = get_auth_context() + assert auth_context + if ( isinstance(model, UserScopedResponseModel) and model.user - and model.user.id == get_auth_context().user.id + and model.user.id == auth_context.user.id ): # User is the owner of the model return @@ -204,10 +207,13 @@ def verify_permissions( if not server_config().rbac_enabled: return + auth_context = get_auth_context() + assert auth_context + resource = Resource(type=resource_type, id=resource_id) if not rbac().has_permission( - user=get_auth_context().user, resource=resource, action=action + user=auth_context.user, resource=resource, action=action ): raise HTTPException( status_code=403, @@ -230,6 +236,9 @@ def get_allowed_resource_ids( A list of resource IDs or `None` if the user has full access to the all instances of the resource. """ + auth_context = get_auth_context() + assert auth_context + if not server_config().rbac_enabled: return None @@ -237,7 +246,7 @@ def get_allowed_resource_ids( has_full_resource_access, allowed_ids, ) = rbac().list_allowed_resource_ids( - user=get_auth_context().user, + user=auth_context.user, resource=Resource(type=resource_type), action=action, ) diff --git a/src/zenml/zen_stores/schemas/pipeline_schemas.py b/src/zenml/zen_stores/schemas/pipeline_schemas.py index 37966e17102..e04057eb460 100644 --- a/src/zenml/zen_stores/schemas/pipeline_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_schemas.py @@ -154,14 +154,5 @@ def update( Returns: The updated `PipelineSchema`. """ - if pipeline_update.name: - self.name = pipeline_update.name - - if pipeline_update.docstring: - self.docstring = pipeline_update.docstring - - if pipeline_update.spec: - self.spec = pipeline_update.spec.json(sort_keys=True) - self.updated = datetime.utcnow() return self diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 0a70c74ac98..53dd93f1633 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -1144,7 +1144,8 @@ def update_stack( f"existing stack with this id." ) if ( - existing_stack.name + existing_stack.user_id + and existing_stack.name == self._get_default_stack_and_component_name( existing_stack.user_id ) @@ -1203,8 +1204,12 @@ def delete_stack(self, stack_id: UUID) -> None: if stack is None: raise KeyError(f"Stack with ID {stack_id} not found.") - if stack.name == self._get_default_stack_and_component_name( - user_id=stack.user_id + if ( + stack.user_id + and stack.name + == self._get_default_stack_and_component_name( + user_id=stack.user_id + ) ): raise IllegalOperationError( "The default stack cannot be deleted." @@ -1426,7 +1431,8 @@ def update_stack_component( ) if ( - existing_component.name + existing_component.user_id + and existing_component.name == self._get_default_stack_and_component_name( user_id=existing_component.user_id ) @@ -1500,7 +1506,8 @@ def delete_stack_component(self, component_id: UUID) -> None: if stack_component is None: raise KeyError(f"Stack with ID {component_id} not found.") if ( - stack_component.name + stack_component.user_id + and stack_component.name == self._get_default_stack_and_component_name( user_id=stack_component.user_id ) From 88d4494a5e4e833f4c7209451a41177710acfcec Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 19 Oct 2023 11:10:47 +0200 Subject: [PATCH 012/103] Add helm chart option for rbac implementation source --- .../zen_server/deploy/helm/templates/server-deployment.yaml | 4 ++++ src/zenml/zen_server/deploy/helm/values.yaml | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/src/zenml/zen_server/deploy/helm/templates/server-deployment.yaml b/src/zenml/zen_server/deploy/helm/templates/server-deployment.yaml index ccbcf4e3862..2b4a727182f 100644 --- a/src/zenml/zen_server/deploy/helm/templates/server-deployment.yaml +++ b/src/zenml/zen_server/deploy/helm/templates/server-deployment.yaml @@ -216,6 +216,10 @@ spec: - name: ZENML_SERVER_ROOT_URL_PATH value: {{ .Values.zenml.rootUrlPath | quote }} {{- end }} + {{- if .Values.zenml.rbacImplementationSource }} + - name: ZENML_SERVER_RBAC_IMPLEMENTATION_SOURCE + value: {{ .Values.zenml.rbacImplementationSource | quote }} + {{- end }} - name: ZENML_DEFAULT_PROJECT_NAME value: {{ .Values.zenml.defaultProject | quote }} - name: ZENML_DEFAULT_USER_NAME diff --git a/src/zenml/zen_server/deploy/helm/values.yaml b/src/zenml/zen_server/deploy/helm/values.yaml index 3ad334b8543..5c65416bed5 100644 --- a/src/zenml/zen_server/deploy/helm/values.yaml +++ b/src/zenml/zen_server/deploy/helm/values.yaml @@ -347,6 +347,11 @@ zenml: # mounted as environment variables in the ZenML server container. secretEnvironment: {} + # Source pointing to a class implementing the RBAC interface defined by + # `zenml.zen_server.rbac_interface.RBACInterface`. If not specified, + # RBAC will not be enabled for this server. + rbacImplementationSource: + service: type: ClusterIP port: 80 From a00c6661c4b3733ab9bc6d945add3498e00256bc Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 19 Oct 2023 11:13:54 +0200 Subject: [PATCH 013/103] Cleanup --- src/zenml/models/server_models.py | 3 --- src/zenml/models/stack_models.py | 8 +------- src/zenml/zen_server/routers/server_endpoints.py | 7 +------ 3 files changed, 2 insertions(+), 16 deletions(-) diff --git a/src/zenml/models/server_models.py b/src/zenml/models/server_models.py index f2466ff3c3a..6bf06a0530b 100644 --- a/src/zenml/models/server_models.py +++ b/src/zenml/models/server_models.py @@ -54,9 +54,6 @@ class ServerModel(BaseModel): title="The ZenML version that the server is running.", ) - zenml_cloud: bool = Field( - False, title="Flag to indicate whether this is a ZenML cloud server." - ) debug: bool = Field( False, title="Flag to indicate whether ZenML is running on debug mode." ) diff --git a/src/zenml/models/stack_models.py b/src/zenml/models/stack_models.py index 2fc9cc72476..74e26e1852c 100644 --- a/src/zenml/models/stack_models.py +++ b/src/zenml/models/stack_models.py @@ -118,13 +118,7 @@ def to_yaml(self) -> Dict[str, Any]: class StackFilterModel(WorkspaceScopedFilterModel): - """Model to enable advanced filtering of all StackModels. - - The Stack Model needs additional scoping. As such the `_scope_user` field - can be set to the user that is doing the filtering. The - `generate_filter()` method of the baseclass is overwritten to include the - scoping. - """ + """Model to enable advanced filtering of all StackModels.""" # `component_id` refers to a relationship through a link-table # rather than a field in the db, hence it needs to be handled diff --git a/src/zenml/zen_server/routers/server_endpoints.py b/src/zenml/zen_server/routers/server_endpoints.py index 5fe9b3ccab0..7fb1fed878d 100644 --- a/src/zenml/zen_server/routers/server_endpoints.py +++ b/src/zenml/zen_server/routers/server_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for authentication (login).""" -import os from fastapi import APIRouter @@ -52,8 +51,4 @@ def server_info() -> ServerModel: Returns: Information about the server. """ - info = zen_store().get_store_info() - if "ZENML_CLOUD" in os.environ: - info.zenml_cloud = True - - return info + return zen_store().get_store_info() From 0eb831c606887e3cfd91f0d18af81f60818c872f Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 23 Oct 2023 14:16:45 +0200 Subject: [PATCH 014/103] Fix alembic order --- .../migrations/versions/7500f434b71c_remove_shared_columns.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py index f9498806463..8859b58a11a 100644 --- a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -1,7 +1,7 @@ """Remove shared columns [7500f434b71c]. Revision ID: 7500f434b71c -Revises: 0.45.1 +Revises: 0.45.4 Create Date: 2023-10-16 15:15:34.865337 """ @@ -10,7 +10,7 @@ # revision identifiers, used by Alembic. revision = "7500f434b71c" -down_revision = "0.45.1" +down_revision = "0.45.4" branch_labels = None depends_on = None From f552c1c05ff6245ce2ae2678b3e80f78ea88a4b3 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 23 Oct 2023 15:15:58 +0200 Subject: [PATCH 015/103] Add 403 response in OpenAPI --- src/zenml/zen_server/routers/artifacts_endpoints.py | 2 +- src/zenml/zen_server/routers/code_repositories_endpoints.py | 2 +- src/zenml/zen_server/routers/flavors_endpoints.py | 2 +- src/zenml/zen_server/routers/models_endpoints.py | 2 +- src/zenml/zen_server/routers/pipeline_builds_endpoints.py | 2 +- .../zen_server/routers/pipeline_deployments_endpoints.py | 2 +- src/zenml/zen_server/routers/pipelines_endpoints.py | 4 +--- src/zenml/zen_server/routers/run_metadata_endpoints.py | 2 +- src/zenml/zen_server/routers/runs_endpoints.py | 2 +- src/zenml/zen_server/routers/schedule_endpoints.py | 2 +- src/zenml/zen_server/routers/secrets_endpoints.py | 2 +- src/zenml/zen_server/routers/service_connectors_endpoints.py | 4 ++-- src/zenml/zen_server/routers/stack_components_endpoints.py | 4 ++-- src/zenml/zen_server/routers/stacks_endpoints.py | 1 - src/zenml/zen_server/routers/steps_endpoints.py | 2 +- 15 files changed, 16 insertions(+), 19 deletions(-) diff --git a/src/zenml/zen_server/routers/artifacts_endpoints.py b/src/zenml/zen_server/routers/artifacts_endpoints.py index dd4b766d269..6dd289f442a 100644 --- a/src/zenml/zen_server/routers/artifacts_endpoints.py +++ b/src/zenml/zen_server/routers/artifacts_endpoints.py @@ -39,7 +39,7 @@ router = APIRouter( prefix=API + VERSION_1 + ARTIFACTS, tags=["artifacts"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) diff --git a/src/zenml/zen_server/routers/code_repositories_endpoints.py b/src/zenml/zen_server/routers/code_repositories_endpoints.py index 9882bbf1697..c0dd24daad2 100644 --- a/src/zenml/zen_server/routers/code_repositories_endpoints.py +++ b/src/zenml/zen_server/routers/code_repositories_endpoints.py @@ -34,7 +34,7 @@ router = APIRouter( prefix=API + VERSION_1 + CODE_REPOSITORIES, tags=["code_repositories"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) diff --git a/src/zenml/zen_server/routers/flavors_endpoints.py b/src/zenml/zen_server/routers/flavors_endpoints.py index beb66de4882..182c69cbe62 100644 --- a/src/zenml/zen_server/routers/flavors_endpoints.py +++ b/src/zenml/zen_server/routers/flavors_endpoints.py @@ -37,7 +37,7 @@ router = APIRouter( prefix=API + VERSION_1 + FLAVORS, tags=["flavors"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index a48b16ad5dd..9699e8f29e0 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -56,7 +56,7 @@ router = APIRouter( prefix=API + VERSION_1 + MODELS, tags=["models"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) diff --git a/src/zenml/zen_server/routers/pipeline_builds_endpoints.py b/src/zenml/zen_server/routers/pipeline_builds_endpoints.py index f053fe26c7b..90d4e69818e 100644 --- a/src/zenml/zen_server/routers/pipeline_builds_endpoints.py +++ b/src/zenml/zen_server/routers/pipeline_builds_endpoints.py @@ -30,7 +30,7 @@ router = APIRouter( prefix=API + VERSION_1 + PIPELINE_BUILDS, tags=["builds"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) diff --git a/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py b/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py index 70fd486a234..1fb1efff3db 100644 --- a/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py +++ b/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py @@ -33,7 +33,7 @@ router = APIRouter( prefix=API + VERSION_1 + PIPELINE_DEPLOYMENTS, tags=["deployments"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) diff --git a/src/zenml/zen_server/routers/pipelines_endpoints.py b/src/zenml/zen_server/routers/pipelines_endpoints.py index d9136f85b4f..b1870e3ae83 100644 --- a/src/zenml/zen_server/routers/pipelines_endpoints.py +++ b/src/zenml/zen_server/routers/pipelines_endpoints.py @@ -37,7 +37,7 @@ router = APIRouter( prefix=API + VERSION_1 + PIPELINES, tags=["pipelines"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) @@ -46,7 +46,6 @@ response_model=Page[PipelineResponseModel], responses={ 401: error_response, - 403: error_response, 404: error_response, 422: error_response, }, @@ -77,7 +76,6 @@ def list_pipelines( response_model=PipelineResponseModel, responses={ 401: error_response, - 403: error_response, 404: error_response, 422: error_response, }, diff --git a/src/zenml/zen_server/routers/run_metadata_endpoints.py b/src/zenml/zen_server/routers/run_metadata_endpoints.py index 7849ff0ba98..381cec799c6 100644 --- a/src/zenml/zen_server/routers/run_metadata_endpoints.py +++ b/src/zenml/zen_server/routers/run_metadata_endpoints.py @@ -31,7 +31,7 @@ router = APIRouter( prefix=API + VERSION_1 + RUN_METADATA, tags=["run_metadata"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) diff --git a/src/zenml/zen_server/routers/runs_endpoints.py b/src/zenml/zen_server/routers/runs_endpoints.py index eb85a872bca..b44e988fcce 100644 --- a/src/zenml/zen_server/routers/runs_endpoints.py +++ b/src/zenml/zen_server/routers/runs_endpoints.py @@ -47,7 +47,7 @@ router = APIRouter( prefix=API + VERSION_1 + RUNS, tags=["runs"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) diff --git a/src/zenml/zen_server/routers/schedule_endpoints.py b/src/zenml/zen_server/routers/schedule_endpoints.py index 443df90ded3..c4d3c8b2969 100644 --- a/src/zenml/zen_server/routers/schedule_endpoints.py +++ b/src/zenml/zen_server/routers/schedule_endpoints.py @@ -34,7 +34,7 @@ router = APIRouter( prefix=API + VERSION_1 + SCHEDULES, tags=["schedules"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) diff --git a/src/zenml/zen_server/routers/secrets_endpoints.py b/src/zenml/zen_server/routers/secrets_endpoints.py index 8a155179bd8..0390836758e 100644 --- a/src/zenml/zen_server/routers/secrets_endpoints.py +++ b/src/zenml/zen_server/routers/secrets_endpoints.py @@ -36,7 +36,7 @@ router = APIRouter( prefix=API + VERSION_1 + SECRETS, tags=["secrets"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) diff --git a/src/zenml/zen_server/routers/service_connectors_endpoints.py b/src/zenml/zen_server/routers/service_connectors_endpoints.py index 086a9a941fc..c9fc047a68b 100644 --- a/src/zenml/zen_server/routers/service_connectors_endpoints.py +++ b/src/zenml/zen_server/routers/service_connectors_endpoints.py @@ -46,13 +46,13 @@ router = APIRouter( prefix=API + VERSION_1 + SERVICE_CONNECTORS, tags=["service_connectors"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) types_router = APIRouter( prefix=API + VERSION_1 + SERVICE_CONNECTOR_TYPES, tags=["service_connectors"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) diff --git a/src/zenml/zen_server/routers/stack_components_endpoints.py b/src/zenml/zen_server/routers/stack_components_endpoints.py index 6b2813bf416..f8088f9693d 100644 --- a/src/zenml/zen_server/routers/stack_components_endpoints.py +++ b/src/zenml/zen_server/routers/stack_components_endpoints.py @@ -36,13 +36,13 @@ router = APIRouter( prefix=API + VERSION_1 + STACK_COMPONENTS, tags=["stack_components"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) types_router = APIRouter( prefix=API + VERSION_1 + COMPONENT_TYPES, tags=["stack_components"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index 768187b74a2..6ebe6bb1792 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -79,7 +79,6 @@ def list_stacks( response_model=StackResponseModel, responses={ 401: error_response, - 403: error_response, 404: error_response, 422: error_response, }, diff --git a/src/zenml/zen_server/routers/steps_endpoints.py b/src/zenml/zen_server/routers/steps_endpoints.py index ab62c1a6d5e..61972b5ab7d 100644 --- a/src/zenml/zen_server/routers/steps_endpoints.py +++ b/src/zenml/zen_server/routers/steps_endpoints.py @@ -49,7 +49,7 @@ router = APIRouter( prefix=API + VERSION_1 + STEPS, tags=["steps"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) From df77f37600235c9591a60939191fca49b36bc0d5 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 23 Oct 2023 17:05:07 +0200 Subject: [PATCH 016/103] Better page dehydration --- src/zenml/zen_server/rbac/models.py | 5 + src/zenml/zen_server/rbac/rbac_interface.py | 15 +- src/zenml/zen_server/rbac/utils.py | 185 +++++++++++++++--- .../zen_server/routers/stacks_endpoints.py | 9 +- 4 files changed, 180 insertions(+), 34 deletions(-) diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index 39bd5634a0a..4538e963a70 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -61,3 +61,8 @@ def __str__(self) -> str: representation += f"/{self.id}" return representation + + class Config: + """Pydantic configuration class.""" + + frozen = True diff --git a/src/zenml/zen_server/rbac/rbac_interface.py b/src/zenml/zen_server/rbac/rbac_interface.py index b71ed7702a5..99c401a7a9a 100644 --- a/src/zenml/zen_server/rbac/rbac_interface.py +++ b/src/zenml/zen_server/rbac/rbac_interface.py @@ -14,7 +14,7 @@ """RBAC interface definition.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Set, Tuple from zenml.zen_server.rbac.models import Resource @@ -26,18 +26,19 @@ class RBACInterface(ABC): """RBAC interface definition.""" @abstractmethod - def has_permission( - self, user: "UserResponseModel", resource: Resource, action: str - ) -> bool: + def check_permissions( + self, user: "UserResponseModel", resources: Set[Resource], action: str + ) -> Dict[Resource, bool]: """Checks if a user has permission to perform an action on a resource. Args: user: User which wants to access a resource. - resource: The resource the user wants to access. - action: The action that the user wants to perform on the resource. + resources: The resources the user wants to access. + action: The action that the user wants to perform on the resources. Returns: - Whether the user has permission to perform an action on a resource. + A dictionary mapping resources to a boolean which indicates whether + the user has permissions to perform the action on that resource. """ @abstractmethod diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index d0659b12f7f..84fdfd710d6 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -21,6 +21,7 @@ from fastapi import HTTPException from zenml.models.base_models import BaseResponseModel, UserScopedResponseModel +from zenml.models.page_model import Page from zenml.zen_server.auth import get_auth_context from zenml.zen_server.rbac.models import Action, Resource, ResourceType from zenml.zen_server.utils import rbac, server_config @@ -47,13 +48,43 @@ def verify_read_permissions_and_dehydrate( return dehydrate_response_model(model=model) +def dehydrate_page(page: Page[M]) -> Page[M]: + """Dehydrate all items of a page. + + Args: + page: The page to dehydrate. + + Returns: + The page with (potentially) dehydrated items. + """ + auth_context = get_auth_context() + assert auth_context + + resource_list = [get_subresources_for_model(item) for item in page.items] + resources = set.union(*resource_list) if resource_list else set() + permissions = rbac().check_permissions( + user=auth_context.user, resources=resources, action=Action.READ + ) + + new_items = [ + dehydrate_response_model(item, permissions=permissions) + for item in page.items + ] + + return page.copy(update={"items": new_items}) + + def dehydrate_response_model( - model: M, + model: M, permissions: Optional[Dict[Resource, bool]] = None ) -> M: """Dehydrate a model if necessary. Args: model: The model to dehydrate. + permissions: Prefetched permissions that will be used to check whether + sub-models will be included in the model or not. If a sub-model + refers to a resource which is not included in this dictionary, the + permissions will be checked with the RBAC component. Returns: The (potentially) dehydrated model. @@ -62,30 +93,48 @@ def dehydrate_response_model( for field_name in model.__fields__.keys(): value = getattr(model, field_name) - dehydrated_fields[field_name] = _maybe_dehydrate_value(value) + dehydrated_fields[field_name] = _dehydrate_value( + value, permissions=permissions + ) return type(model).parse_obj(dehydrated_fields) -def _maybe_dehydrate_value(value: Any) -> Any: +def _dehydrate_value( + value: Any, permissions: Optional[Dict[Resource, bool]] = None +) -> Any: """Helper function to recursive dehydrate any object. Args: value: The value to dehydrate. + permissions: Prefetched permissions that will be used to check whether + sub-models will be included in the model or not. If a sub-model + refers to a resource which is not included in this dictionary, the + permissions will be checked with the RBAC component. Returns: The recursively dehydrated value. """ if isinstance(value, BaseResponseModel): - if has_permissions_for_model(model=value, action=Action.READ): - return dehydrate_response_model(value) + resource = get_resource_for_model(value) + has_permissions = resource and (permissions or {}).get(resource, False) + + if has_permissions or has_permissions_for_model( + model=value, action=Action.READ + ): + return dehydrate_response_model(value, permissions=permissions) else: return get_permission_denied_model(value) elif isinstance(value, Dict): - return {k: _maybe_dehydrate_value(v) for k, v in value.items()} + return { + k: _dehydrate_value(v, permissions=permissions) + for k, v in value.items() + } elif isinstance(value, (List, Set, tuple)): type_ = type(value) - return type_(_maybe_dehydrate_value(v) for v in value) + return type_( + _dehydrate_value(v, permissions=permissions) for v in value + ) else: return value @@ -166,15 +215,8 @@ def verify_permissions_for_model( if not server_config().rbac_enabled: return - auth_context = get_auth_context() - assert auth_context - - if ( - isinstance(model, UserScopedResponseModel) - and model.user - and model.user.id == auth_context.user.id - ): - # User is the owner of the model + if is_owned_by_authenticated_user(model): + # The model owner always has permissions return resource_type = get_resource_type_for_model(model) @@ -203,6 +245,7 @@ def verify_permissions( Raises: HTTPException: If the user is not allowed to perform the action. + RuntimeError: If the permission verification failed unexpectedly. """ if not server_config().rbac_enabled: return @@ -211,10 +254,19 @@ def verify_permissions( assert auth_context resource = Resource(type=resource_type, id=resource_id) + permissions = rbac().check_permissions( + user=auth_context.user, resources={resource}, action=action + ) - if not rbac().has_permission( - user=auth_context.user, resource=resource, action=action - ): + if resource not in permissions: + # This should never happen if the RBAC implementation is working + # correctly + raise RuntimeError( + f"Failed to verify permissions to {action.upper()} resource " + f"'{resource}'." + ) + + if not permissions[resource]: raise HTTPException( status_code=403, detail=f"Insufficient permissions to {action.upper()} resource " @@ -236,12 +288,12 @@ def get_allowed_resource_ids( A list of resource IDs or `None` if the user has full access to the all instances of the resource. """ - auth_context = get_auth_context() - assert auth_context - if not server_config().rbac_enabled: return None + auth_context = get_auth_context() + assert auth_context + ( has_full_resource_access, allowed_ids, @@ -257,6 +309,24 @@ def get_allowed_resource_ids( return [UUID(id) for id in allowed_ids] +def get_resource_for_model(model: "BaseResponseModel") -> Optional[Resource]: + """Get the resource associated with a model object. + + Args: + model: The model for which to get the resource. + + Returns: + The resource associated with the model, or `None` if the model + is not associated with any resource type. + """ + resource_type = get_resource_type_for_model(model) + if not resource_type: + # This model is not tied to any RBAC resource type + return None + + return Resource(type=resource_type, id=model.id) + + def get_resource_type_for_model( model: "BaseResponseModel", ) -> Optional[ResourceType]: @@ -294,3 +364,74 @@ def get_resource_type_for_model( } return mapping.get(type(model)) + + +def is_owned_by_authenticated_user(model: "BaseResponseModel") -> bool: + """Returns whether the currently authenticated user owns the model. + + Args: + model: The model for which to check the ownership. + + Returns: + Whether the currently authenticated user owns the model. + """ + auth_context = get_auth_context() + assert auth_context + + if ( + isinstance(model, UserScopedResponseModel) + and model.user + and model.user.id == auth_context.user.id + ): + # User is the owner of the model + return True + + return False + + +def get_subresources_for_model( + model: "BaseResponseModel", +) -> Set[Resource]: + """Get all subresources of a model which need permission verification. + + Args: + model: The model for which to get all the resources. + + Returns: + All resources of a model which need permission verification. + """ + resources = set() + + for field_name in model.__fields__.keys(): + value = getattr(model, field_name) + resources.update(_get_subresources_for_value(value)) + + return resources + + +def _get_subresources_for_value(value: Any) -> Set[Resource]: + """Helper function to recursive retrieve resources of any object. + + Args: + value: The value for which to get all the resources. + + Returns: + All resources of the value which need permission verification. + """ + if isinstance(value, BaseResponseModel): + resources = set() + if not is_owned_by_authenticated_user(value): + if resource := get_resource_for_model(value): + resources.add(resource) + + return resources.union(get_subresources_for_model(value)) + elif isinstance(value, Dict): + resources_list = [ + _get_subresources_for_value(v) for v in value.values() + ] + return set.union(*resources_list) if resources_list else set() + elif isinstance(value, (List, Set, tuple)): + resources_list = [_get_subresources_for_value(v) for v in value] + return set.union(*resources_list) if resources_list else set() + else: + return set() diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index 6ebe6bb1792..82bae390dfa 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -27,6 +27,7 @@ from zenml.zen_server.exceptions import error_response from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( + dehydrate_page, dehydrate_response_model, get_allowed_resource_ids, verify_permissions_for_model, @@ -68,10 +69,7 @@ def list_stacks( allowed_ids = get_allowed_resource_ids(resource_type=ResourceType.STACK) stack_filter_model.set_allowed_ids(allowed_ids) page = zen_store().list_stacks(stack_filter_model=stack_filter_model) - - # TODO: make this better, this is sending a ton of requests here - page.items = [dehydrate_response_model(model) for model in page.items] - return page + return dehydrate_page(page) @router.get( @@ -123,10 +121,11 @@ def update_stack( stack = zen_store().get_stack(stack_id) verify_permissions_for_model(stack, action=Action.UPDATE) - return zen_store().update_stack( + updated_stack = zen_store().update_stack( stack_id=stack_id, stack_update=stack_update, ) + return dehydrate_response_model(updated_stack) @router.delete( From 8f251027cfab1477e58389509c0cbb9b05614a61 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 23 Oct 2023 17:08:30 +0200 Subject: [PATCH 017/103] ZenML cloud rbac implementation --- src/zenml/zen_server/rbac/zenml_cloud_rbac.py | 253 ++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100644 src/zenml/zen_server/rbac/zenml_cloud_rbac.py diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py new file mode 100644 index 00000000000..1a9307d915c --- /dev/null +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -0,0 +1,253 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Cloud RBAC implementation.""" +import os +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple + +import requests +from pydantic import BaseModel, validator + +from zenml.zen_server.rbac.models import Resource +from zenml.zen_server.rbac.rbac_interface import RBACInterface +from zenml.zen_server.utils import server_config + +if TYPE_CHECKING: + from zenml.models import UserResponseModel + + +ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_" +PERMISSIONS_ENDPOINT = "/rbac/check_permissions" +ALLOWED_RESOURCE_IDS_ENDPOINT = "/rbac/allowed_resource_ids" + +SERVER_ID = server_config().external_server_id + + +def _convert_to_cloud_resource(resource: Resource) -> str: + """Convert a resource to a ZenML Cloud API resource. + + Args: + resource: The resource to convert. + + Returns: + The converted resource. + """ + resource_string = f"{SERVER_ID}@server:{resource.type}" + + if resource.id: + resource_string += f"/{resource.id}" + + return resource_string + + +class ZenMLCloudRBACConfiguration(BaseModel): + """ZenML Cloud RBAC configuration.""" + + api_url: str + + oauth2_client_id: str + oauth2_client_secret: str + oauth2_audience: str + auth0_domain: str + + @validator("api_url") + def _strip_trailing_slashes_url(cls, url: str) -> str: + """Strip any trailing slashes on the API URL. + + Args: + url: The API URL. + + Returns: + The API URL with potential trailing slashes removed. + """ + return url.rstrip("/") + + @classmethod + def from_environment(cls) -> "ZenMLCloudRBACConfiguration": + """Get the RBAC configuration from environment variables. + + Returns: + The RBAC configuration. + """ + env_config: Dict[str, Any] = {} + for k, v in os.environ.items(): + if v == "": + continue + if k.startswith(ZENML_CLOUD_RBAC_ENV_PREFIX): + env_config[k[len(ZENML_CLOUD_RBAC_ENV_PREFIX) :].lower()] = v + + return ZenMLCloudRBACConfiguration(**env_config) + + class Config: + """Pydantic configuration class.""" + + # Allow extra attributes from configs of previous ZenML versions to + # permit downgrading + extra = "allow" + + +class ZenMLCloudRBAC(RBACInterface): + """RBAC implementation that uses the ZenML Cloud API as a backend.""" + + def __init__(self) -> None: + """Initialize the RBAC component.""" + self._config = ZenMLCloudRBACConfiguration.from_environment() + self._session: Optional[requests.Session] = None + + def check_permissions( + self, user: "UserResponseModel", resources: Set[Resource], action: str + ) -> Dict[Resource, bool]: + """Checks if a user has permission to perform an action on a resource. + + Args: + user: User which wants to access a resource. + resources: The resources the user wants to access. + action: The action that the user wants to perform on the resources. + + Returns: + A dictionary mapping resources to a boolean which indicates whether + the user has permissions to perform the action on that resource. + """ + assert user.external_user_id + + if not resources: + # No need to send a request if there are no resources + return {} + + params = { + "user_id": str(user.external_user_id), + "resources": [ + _convert_to_cloud_resource(resource) for resource in resources + ], + "action": action, + } + response = self._get(endpoint=PERMISSIONS_ENDPOINT, params=params) + value = response.json() + + assert isinstance(value, dict) + return value + + def list_allowed_resource_ids( + self, user: "UserResponseModel", resource: Resource, action: str + ) -> Tuple[bool, List[str]]: + """Lists all resource IDs of a resource type that a user can access. + + Args: + user: User which wants to access a resource. + resource: The resource the user wants to access. + action: The action that the user wants to perform on the resource. + + Returns: + A tuple (full_resource_access, resource_ids). + `full_resource_access` will be `True` if the user can perform the + given action on any instance of the given resource type, `False` + otherwise. If `full_resource_access` is `False`, `resource_ids` + will contain the list of instance IDs that the user can perform + the action on. + """ + assert not resource.id + assert user.external_user_id + params = { + "user_id": str(user.external_user_id), + "resource": _convert_to_cloud_resource(resource), + "action": action, + } + response = self._get( + endpoint=ALLOWED_RESOURCE_IDS_ENDPOINT, params=params + ) + response_json = response.json() + + full_resource_access: bool = response_json["full_access"] + allowed_ids: List[str] = response_json["ids"] + + return full_resource_access, allowed_ids + + def _get(self, endpoint: str, params: Dict[str, Any]) -> requests.Response: + """Send a GET request using the active session. + + Args: + endpoint: The endpoint to send the request to. This will be appended + to the base URL. + params: Parameters to include in the request. + + Raises: + RuntimeError: If the request failed. + + Returns: + The response. + """ + url = self._config.api_url + endpoint + + response = self.session.get(url=url, params=params) + if response.status_code == 401: + # Refresh the auth token and try again + self._clear_session() + response = self.session.get(url=url, params=params) + + try: + response.raise_for_status() + except requests.HTTPError as e: + raise RuntimeError( + "Failed while trying to contact RBAC service." + ) from e + + return response + + @property + def session(self) -> requests.Session: + """Authenticate to the ZenML Cloud API. + + Returns: + A requests session with the authentication token. + """ + if self._session is None: + self._session = requests.Session() + token = self._fetch_auth_token() + self._session.headers.update({"Authorization": "Bearer " + token}) + + return self._session + + def _clear_session(self) -> None: + """Clear the authentication session.""" + self._session = None + + def _fetch_auth_token(self) -> str: + """Fetch an auth token for the Cloud API from auth0. + + Raises: + RuntimeError: If the auth token can't be fetched. + + Returns: + Auth token. + """ + # Get an auth token from auth0 + auth0_url = f"https://{self._config.auth0_domain}/oauth/token" + headers = {"content-type": "application/x-www-form-urlencoded"} + payload = { + "client_id": self._config.oauth2_client_id, + "client_secret": self._config.oauth2_client_secret, + "audience": self._config.oauth2_audience, + "grant_type": "client_credentials", + } + try: + response = requests.post(auth0_url, headers=headers, data=payload) + response.raise_for_status() + except Exception as e: + raise RuntimeError(f"Error fetching auth token from auth0: {e}") + + access_token = response.json().get("access_token", "") + + if not access_token or not isinstance(access_token, str): + raise RuntimeError("Could not fetch auth token from auth0.") + + return access_token From 0b64217ea4d37a435d9a485b5e2aef91d146a2cc Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 23 Oct 2023 17:19:15 +0200 Subject: [PATCH 018/103] Don't include recursive ids --- src/zenml/zen_server/rbac/utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 84fdfd710d6..fc29aeb5db3 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -156,13 +156,16 @@ def has_permissions_for_model(model: "BaseResponseModel", action: str) -> bool: return False -def get_permission_denied_model(model: M, keep_name: bool = True) -> M: +def get_permission_denied_model( + model: M, keep_id: bool = True, keep_name: bool = True +) -> M: """Get a model to return in case of missing read permissions. This function replaces all attributes except name and ID in the given model. Args: model: The original model. + keep_id: If `True`, the model ID will not be replaced. keep_name: If `True`, the model name will not be replaced. Returns: @@ -173,14 +176,16 @@ def get_permission_denied_model(model: M, keep_name: bool = True) -> M: for field_name, field in model.__fields__.items(): value = getattr(model, field_name) - if field_name == "id" and isinstance(value, UUID): + if keep_id and field_name == "id" and isinstance(value, UUID): pass elif keep_name and field_name == "name" and isinstance(value, str): pass elif field.allow_none: value = None elif isinstance(value, BaseResponseModel): - value = get_permission_denied_model(value, keep_name=False) + value = get_permission_denied_model( + value, keep_id=False, keep_name=False + ) elif isinstance(value, UUID): value = UUID(int=0) elif isinstance(value, datetime): @@ -196,7 +201,6 @@ def get_permission_denied_model(model: M, keep_name: bool = True) -> M: values[field_name] = value - # TODO: With the new hydration models, make sure we clear metadata here values["missing_permissions"] = True return type(model).parse_obj(values) From 90d16c8345526bd887ec3f9433cd7ca9e3de263b Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 23 Oct 2023 17:36:31 +0200 Subject: [PATCH 019/103] Update docstring --- src/zenml/zen_server/rbac/rbac_interface.py | 2 +- src/zenml/zen_server/rbac/zenml_cloud_rbac.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_server/rbac/rbac_interface.py b/src/zenml/zen_server/rbac/rbac_interface.py index 99c401a7a9a..8a89629d5e4 100644 --- a/src/zenml/zen_server/rbac/rbac_interface.py +++ b/src/zenml/zen_server/rbac/rbac_interface.py @@ -29,7 +29,7 @@ class RBACInterface(ABC): def check_permissions( self, user: "UserResponseModel", resources: Set[Resource], action: str ) -> Dict[Resource, bool]: - """Checks if a user has permission to perform an action on a resource. + """Checks if a user has permissions to perform an action on resources. Args: user: User which wants to access a resource. diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py index 1a9307d915c..349b8fc734c 100644 --- a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -107,7 +107,7 @@ def __init__(self) -> None: def check_permissions( self, user: "UserResponseModel", resources: Set[Resource], action: str ) -> Dict[Resource, bool]: - """Checks if a user has permission to perform an action on a resource. + """Checks if a user has permissions to perform an action on resources. Args: user: User which wants to access a resource. From 2925d2189d2edcef28cadbd93b15ac611978303a Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 24 Oct 2023 09:48:21 +0200 Subject: [PATCH 020/103] Migrate private components --- .../7500f434b71c_remove_shared_columns.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py index 8859b58a11a..4d82413a260 100644 --- a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -15,9 +15,61 @@ depends_on = None +def _migrate_default_entities(table: sa.Table) -> None: + connection = op.get_bind() + + query = sa.select( + table.c.id, + table.c.user_id, + ).where(table.c.name == "default") + + res = connection.execute(query).fetchall() + for id, owner_id in res: + name = f"default-{owner_id}" + + connection.execute( + sa.update(table).where(table.c.id == id).values(name=name) + ) + + +def resolve_duplicate_names() -> None: + meta = sa.MetaData(bind=op.get_bind()) + meta.reflect(only=("stack", "stack_component", "service_connector")) + + stack_table = sa.Table("stack", meta) + stack_component_table = sa.Table("stack_component", meta) + + _migrate_default_entities(stack_table) + _migrate_default_entities(stack_component_table) + + service_connector_table = sa.Table("service_connector", meta) + query = sa.select( + service_connector_table.c.id, + service_connector_table.c.name, + service_connector_table.c.user_id, + ) + + connection = op.get_bind() + names = set() + for id, name, user_id in connection.execute(query).fetchall(): + if name in names: + name = f"{name}-{user_id}" + # This will never happen, as we had a constraint on unique names + # per user + assert name not in names + connection.execute( + sa.update(service_connector_table) + .where(service_connector_table.c.id == id) + .values(name=name) + ) + + names.add(name) + + def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" # ### commands auto generated by Alembic - please adjust! ### + resolve_duplicate_names() with op.batch_alter_table("service_connector", schema=None) as batch_op: batch_op.drop_column("is_shared") From c82d486ef6f917e1d5594bdd65829a23d51979fb Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 24 Oct 2023 09:59:34 +0200 Subject: [PATCH 021/103] Add docstrings --- .../versions/7500f434b71c_remove_shared_columns.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py index 4d82413a260..70ad3b75cd2 100644 --- a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -15,7 +15,12 @@ depends_on = None -def _migrate_default_entities(table: sa.Table) -> None: +def _rename_default_entities(table: sa.Table) -> None: + """Include owner id in the name of default entities. + + Args: + table: The table in which to rename the default entities. + """ connection = op.get_bind() query = sa.select( @@ -33,14 +38,15 @@ def _migrate_default_entities(table: sa.Table) -> None: def resolve_duplicate_names() -> None: + """Resolve duplicate names for shareable entities.""" meta = sa.MetaData(bind=op.get_bind()) meta.reflect(only=("stack", "stack_component", "service_connector")) stack_table = sa.Table("stack", meta) stack_component_table = sa.Table("stack_component", meta) - _migrate_default_entities(stack_table) - _migrate_default_entities(stack_component_table) + _rename_default_entities(stack_table) + _rename_default_entities(stack_component_table) service_connector_table = sa.Table("service_connector", meta) query = sa.select( From 39a9e43a3541fb9e32c7f7d09a287e44bd5aa816 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 24 Oct 2023 14:26:19 +0200 Subject: [PATCH 022/103] More utility functions --- src/zenml/zen_server/rbac/utils.py | 107 +++++++++++------- src/zenml/zen_server/rbac/zenml_cloud_rbac.py | 4 +- .../zen_server/routers/stacks_endpoints.py | 18 ++- 3 files changed, 85 insertions(+), 44 deletions(-) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index fc29aeb5db3..db3ab81a64d 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -15,7 +15,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Set, Type, TypeVar +from typing import Any, Dict, List, Optional, Sequence, Set, Type, TypeVar from uuid import UUID from fastapi import HTTPException @@ -43,7 +43,7 @@ def verify_read_permissions_and_dehydrate( if not server_config().rbac_enabled: return model - verify_permissions_for_model(model=model, action=Action.READ) + verify_permission_for_model(model=model, action=Action.READ) return dehydrate_response_model(model=model) @@ -57,6 +57,9 @@ def dehydrate_page(page: Page[M]) -> Page[M]: Returns: The page with (potentially) dehydrated items. """ + if not server_config().rbac_enabled: + return page + auth_context = get_auth_context() assert auth_context @@ -150,7 +153,7 @@ def has_permissions_for_model(model: "BaseResponseModel", action: str) -> bool: If the active user has permissions to perform the action on the model. """ try: - verify_permissions_for_model(model=model, action=action) + verify_permission_for_model(model=model, action=action) return True except HTTPException: return False @@ -206,46 +209,53 @@ def get_permission_denied_model( return type(model).parse_obj(values) -def verify_permissions_for_model( - model: "BaseResponseModel", +def batch_verify_permissions_for_models( + models: Sequence["BaseResponseModel"], action: str, ) -> None: - """Verifies if a user has permissions to perform an action on a model. + """Batch permission verification for models. Args: - model: The model the user wants to perform the action on. + models: The models the user wants to perform the action on. action: The action the user wants to perform. """ if not server_config().rbac_enabled: return - if is_owned_by_authenticated_user(model): - # The model owner always has permissions - return + resources = set() + for model in models: + if is_owned_by_authenticated_user(model): + # The model owner always has permissions + continue - resource_type = get_resource_type_for_model(model) - if not resource_type: - # This model is not tied to any RBAC resource type and therefore doesn't - # require any special permissions - return + if resource := get_resource_for_model(model): + resources.add(resource) - verify_permissions( - resource_type=resource_type, resource_id=model.id, action=action - ) + batch_verify_permissions(resources=resources, action=action) -def verify_permissions( - resource_type: str, +def verify_permission_for_model( + model: "BaseResponseModel", action: str, - resource_id: Optional[UUID] = None, ) -> None: - """Verifies if a user has permissions to perform an action on a resource. + """Verifies if a user has permission to perform an action on a model. Args: - resource_type: The type of resource that the user wants to perform the - action on. + model: The model the user wants to perform the action on. + action: The action the user wants to perform. + """ + batch_verify_permissions_for_models(models=[model], action=action) + + +def batch_verify_permissions( + resources: Set[Resource], + action: str, +) -> None: + """Batch permission verification. + + Args: + resources: The resources the user wants to perform the action on. action: The action the user wants to perform. - resource_id: ID of the resource the user wants to perform the action on. Raises: HTTPException: If the user is not allowed to perform the action. @@ -257,25 +267,42 @@ def verify_permissions( auth_context = get_auth_context() assert auth_context - resource = Resource(type=resource_type, id=resource_id) permissions = rbac().check_permissions( - user=auth_context.user, resources={resource}, action=action + user=auth_context.user, resources=resources, action=action ) - if resource not in permissions: - # This should never happen if the RBAC implementation is working - # correctly - raise RuntimeError( - f"Failed to verify permissions to {action.upper()} resource " - f"'{resource}'." - ) + for resource in resources: + if resource not in permissions: + # This should never happen if the RBAC implementation is working + # correctly + raise RuntimeError( + f"Failed to verify permissions to {action.upper()} resource " + f"'{resource}'." + ) - if not permissions[resource]: - raise HTTPException( - status_code=403, - detail=f"Insufficient permissions to {action.upper()} resource " - f"'{resource}'.", - ) + if not permissions[resource]: + raise HTTPException( + status_code=403, + detail=f"Insufficient permissions to {action.upper()} resource " + f"'{resource}'.", + ) + + +def verify_permission( + resource_type: str, + action: str, + resource_id: Optional[UUID] = None, +) -> None: + """Verifies if a user has permission to perform an action on a resource. + + Args: + resource_type: The type of resource that the user wants to perform the + action on. + action: The action the user wants to perform. + resource_id: ID of the resource the user wants to perform the action on. + """ + resource = Resource(type=resource_type, id=resource_id) + batch_verify_permissions(resources={resource}, action=action) def get_allowed_resource_ids( diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py index 349b8fc734c..93b750d7448 100644 --- a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -30,6 +30,8 @@ PERMISSIONS_ENDPOINT = "/rbac/check_permissions" ALLOWED_RESOURCE_IDS_ENDPOINT = "/rbac/allowed_resource_ids" +SERVER_SCOPE_IDENTIFIER = "server" + SERVER_ID = server_config().external_server_id @@ -42,7 +44,7 @@ def _convert_to_cloud_resource(resource: Resource) -> str: Returns: The converted resource. """ - resource_string = f"{SERVER_ID}@server:{resource.type}" + resource_string = f"{SERVER_ID}@{SERVER_SCOPE_IDENTIFIER}:{resource.type}" if resource.id: resource_string += f"/{resource.id}" diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index 82bae390dfa..0e87eac69a9 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -27,10 +27,11 @@ from zenml.zen_server.exceptions import error_response from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( + batch_verify_permissions_for_models, dehydrate_page, dehydrate_response_model, get_allowed_resource_ids, - verify_permissions_for_model, + verify_permission_for_model, verify_read_permissions_and_dehydrate, ) from zenml.zen_server.utils import ( @@ -119,7 +120,18 @@ def update_stack( The updated stack. """ stack = zen_store().get_stack(stack_id) - verify_permissions_for_model(stack, action=Action.UPDATE) + verify_permission_for_model(stack, action=Action.UPDATE) + + if stack_update.components: + updated_components = [ + zen_store().get_stack_component(id) + for ids in stack_update.components.values() + for id in ids + ] + + batch_verify_permissions_for_models( + updated_components, action=Action.READ + ) updated_stack = zen_store().update_stack( stack_id=stack_id, @@ -143,6 +155,6 @@ def delete_stack( stack_id: Name of the stack. """ stack = zen_store().get_stack(stack_id) - verify_permissions_for_model(stack, action=Action.DELETE) + verify_permission_for_model(stack, action=Action.DELETE) zen_store().delete_stack(stack_id) From 437824510ef481cdc8246dabf9450301bbfccac7 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 24 Oct 2023 15:44:37 +0200 Subject: [PATCH 023/103] Pass secretEnvironment as kubernetes secret --- src/zenml/zen_server/deploy/helm/templates/server-secret.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml index 52aa085abb5..f163df2e3e4 100644 --- a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml +++ b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml @@ -65,8 +65,8 @@ data: {{- end }} {{- end }} {{- end }} - {{- if .Values.zenml.environment }} - {{- range $key, $value := .Values.zenml.environment }} + {{- if .Values.zenml.secretEnvironment }} + {{- range $key, $value := .Values.zenml.secretEnvironment }} {{ $key }}: {{ $value | b64enc | quote }} {{- end }} {{- end }} From ed6bb84289323d323f90748b5aaa786e8aeaba8f Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 24 Oct 2023 17:00:18 +0200 Subject: [PATCH 024/103] Remove service connector sharing --- src/zenml/zen_server/auth.py | 22 +-- .../zen_server/routers/secrets_endpoints.py | 30 ++-- .../routers/service_connectors_endpoints.py | 136 ++++-------------- .../routers/workspaces_endpoints.py | 1 - src/zenml/zen_stores/sql_zen_store.py | 57 +------- 5 files changed, 46 insertions(+), 200 deletions(-) diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index ab967313239..558785cf235 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -15,7 +15,7 @@ from contextvars import ContextVar from datetime import datetime -from typing import Callable, List, Optional, Set, Union +from typing import Callable, Optional, Union from urllib.parse import urlencode from uuid import UUID @@ -37,7 +37,7 @@ LOGIN, VERSION_1, ) -from zenml.enums import AuthScheme, OAuthDeviceStatus, PermissionType +from zenml.enums import AuthScheme, OAuthDeviceStatus from zenml.exceptions import AuthorizationException, OAuthError from zenml.logger import get_logger from zenml.models import ( @@ -92,24 +92,6 @@ class AuthContext(BaseModel): encoded_access_token: Optional[str] = None device: Optional[OAuthDeviceInternalResponseModel] = None - @property - def permissions(self) -> Set[PermissionType]: - """Returns the permissions of the user. - - Returns: - The permissions of the user. - """ - if self.user.roles: - # Merge permissions from all roles - permissions: List[PermissionType] = [] - for role in self.user.roles: - permissions.extend(role.permissions) - - # Remove duplicates - return set(permissions) - - return set() - def authenticate_credentials( user_name_or_id: Optional[Union[str, UUID]] = None, diff --git a/src/zenml/zen_server/routers/secrets_endpoints.py b/src/zenml/zen_server/routers/secrets_endpoints.py index 0390836758e..694d4122052 100644 --- a/src/zenml/zen_server/routers/secrets_endpoints.py +++ b/src/zenml/zen_server/routers/secrets_endpoints.py @@ -18,7 +18,6 @@ from fastapi import APIRouter, Depends, Security from zenml.constants import API, SECRETS, VERSION_1 -from zenml.enums import PermissionType from zenml.models.page_model import Page from zenml.models.secret_models import ( SecretFilterModel, @@ -50,27 +49,20 @@ def list_secrets( secret_filter_model: SecretFilterModel = Depends( make_dependable(SecretFilterModel) ), - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> Page[SecretResponseModel]: """Gets a list of secrets. Args: secret_filter_model: Filter model used for pagination, sorting, filtering - auth_context: Authentication context. Returns: List of secret objects. """ - secrets = zen_store().list_secrets(secret_filter_model=secret_filter_model) - - # Remove secrets from the response if the user does not have write - # permissions. - if PermissionType.WRITE not in auth_context.permissions: - for secret in secrets.items: - secret.remove_secrets() - - return secrets + # TODO: we should probably have separate permissions here for reading the + # secret and its content + return zen_store().list_secrets(secret_filter_model=secret_filter_model) @router.get( @@ -81,25 +73,19 @@ def list_secrets( @handle_exceptions def get_secret( secret_id: UUID, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> SecretResponseModel: """Gets a specific secret using its unique id. Args: secret_id: ID of the secret to get. - auth_context: Authentication context. Returns: A specific secret object. """ - secret = zen_store().get_secret(secret_id=secret_id) - - # Remove secrets from the response if the user does not have write - # permissions. - if PermissionType.WRITE not in auth_context.permissions: - secret.remove_secrets() - - return secret + # TODO: we should probably have separate permissions here for reading the + # secret and its content + return zen_store().get_secret(secret_id=secret_id) @router.put( diff --git a/src/zenml/zen_server/routers/service_connectors_endpoints.py b/src/zenml/zen_server/routers/service_connectors_endpoints.py index c9fc047a68b..491db2fd41c 100644 --- a/src/zenml/zen_server/routers/service_connectors_endpoints.py +++ b/src/zenml/zen_server/routers/service_connectors_endpoints.py @@ -25,7 +25,6 @@ SERVICE_CONNECTORS, VERSION_1, ) -from zenml.enums import PermissionType from zenml.models import ( ServiceConnectorFilterModel, ServiceConnectorRequestModel, @@ -67,7 +66,7 @@ def list_service_connectors( make_dependable(ServiceConnectorFilterModel) ), expand_secrets: bool = True, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> Page[ServiceConnectorResponseModel]: """Get a list of all service connectors for a specific type. @@ -75,7 +74,6 @@ def list_service_connectors( connector_filter_model: Filter model used for pagination, sorting, filtering expand_secrets: Whether to expand secrets or not. - auth_context: Authentication Context Returns: Page with list of service connectors for a specific type. @@ -84,7 +82,7 @@ def list_service_connectors( filter_model=connector_filter_model ) - if expand_secrets and PermissionType.WRITE in auth_context.permissions: + if expand_secrets: for connector in connectors.items: if not connector.secret_id: continue @@ -105,42 +103,26 @@ def list_service_connectors( def get_service_connector( connector_id: UUID, expand_secrets: bool = True, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> ServiceConnectorResponseModel: """Returns the requested service connector. Args: connector_id: ID of the service connector. expand_secrets: Whether to expand secrets or not. - auth_context: Authentication context. Returns: The requested service connector. - - Raises: - KeyError: If the service connector does not exist or is not accessible. """ connector = zen_store().get_service_connector(connector_id) - # Don't allow users to access service connectors that don't belong to them - # unless they are shared. - if ( - connector.user - and connector.user.id == auth_context.user.id - or connector.is_shared - ): - if PermissionType.WRITE not in auth_context.permissions: - return connector - - if expand_secrets and connector.secret_id: - secret = zen_store().get_secret(secret_id=connector.secret_id) + if expand_secrets and connector.secret_id: + secret = zen_store().get_secret(secret_id=connector.secret_id) - # Update the connector configuration with the secret. - connector.configuration.update(secret.secret_values) - - return connector + # Update the connector configuration with the secret. + connector.configuration.update(secret.secret_values) - raise KeyError(f"Service connector with ID {connector_id} not found.") + return connector @router.put( @@ -152,36 +134,21 @@ def get_service_connector( def update_service_connector( connector_id: UUID, connector_update: ServiceConnectorUpdateModel, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> ServiceConnectorResponseModel: """Updates a service connector. Args: connector_id: ID of the service connector. connector_update: Service connector to use to update. - auth_context: Authentication context. Returns: Updated service connector. - - Raises: - KeyError: If the service connector does not exist or is not accessible. """ - connector = zen_store().get_service_connector(connector_id) - - # Don't allow users to access service connectors that don't belong to them - # unless they are shared. - if ( - connector.user - and connector.user.id == auth_context.user.id - or connector.is_shared - ): - return zen_store().update_service_connector( - service_connector_id=connector_id, - update=connector_update, - ) - - raise KeyError(f"Service connector with ID {connector_id} not found.") + return zen_store().update_service_connector( + service_connector_id=connector_id, + update=connector_update, + ) @router.delete( @@ -191,30 +158,17 @@ def update_service_connector( @handle_exceptions def delete_service_connector( connector_id: UUID, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> None: """Deletes a service connector. Args: connector_id: ID of the service connector. - auth_context: Authentication context. Raises: KeyError: If the service connector does not exist or is not accessible. """ - connector = zen_store().get_service_connector(connector_id) - - # Don't allow users to access service connectors that don't belong to them - # unless they are shared. - if ( - connector.user - and connector.user.id == auth_context.user.id - or connector.is_shared - ): - zen_store().delete_service_connector(connector_id) - return - - raise KeyError(f"Service connector with ID {connector_id} not found.") + zen_store().delete_service_connector(connector_id) @router.post( @@ -260,7 +214,7 @@ def validate_and_verify_service_connector( resource_type: Optional[str] = None, resource_id: Optional[str] = None, list_resources: bool = True, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> ServiceConnectorResourcesModel: """Verifies if a service connector instance has access to one or more resources. @@ -275,32 +229,17 @@ def validate_and_verify_service_connector( list_resources: If True, the list of all resources accessible through the service connector and matching the supplied resource type and ID are returned. - auth_context: Authentication context. Returns: The list of resources that the service connector has access to, scoped to the supplied resource type and ID, if provided. - - Raises: - KeyError: If the service connector does not exist or is not accessible. """ - connector = zen_store().get_service_connector(connector_id) - - # Don't allow users to access service connectors that don't belong to them - # unless they are shared. - if ( - connector.user - and connector.user.id == auth_context.user.id - or connector.is_shared - ): - return zen_store().verify_service_connector( - service_connector_id=connector_id, - resource_type=resource_type, - resource_id=resource_id, - list_resources=list_resources, - ) - - raise KeyError(f"Service connector with ID {connector_id} not found.") + return zen_store().verify_service_connector( + service_connector_id=connector_id, + resource_type=resource_type, + resource_id=resource_id, + list_resources=list_resources, + ) @router.get( @@ -313,7 +252,7 @@ def get_service_connector_client( connector_id: UUID, resource_type: Optional[str] = None, resource_id: Optional[str] = None, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> ServiceConnectorResponseModel: """Get a service connector client for a service connector and given resource. @@ -325,31 +264,16 @@ def get_service_connector_client( connector_id: ID of the service connector. resource_type: Type of the resource to list. resource_id: ID of the resource to list. - auth_context: Authentication context. Returns: A service connector client that can be used to access the given resource. - - Raises: - KeyError: If the service connector does not exist or is not accessible. """ - connector = zen_store().get_service_connector(connector_id) - - # Don't allow users to access service connectors that don't belong to them - # unless they are shared. - if ( - connector.user - and connector.user.id == auth_context.user.id - or connector.is_shared - ): - return zen_store().get_service_connector_client( - service_connector_id=connector_id, - resource_type=resource_type, - resource_id=resource_id, - ) - - raise KeyError(f"Service connector with ID {connector_id} not found.") + return zen_store().get_service_connector_client( + service_connector_id=connector_id, + resource_type=resource_type, + resource_id=resource_id, + ) @types_router.get( @@ -401,6 +325,4 @@ def get_service_connector_type( Returns: The requested service connector type. """ - c = zen_store().get_service_connector_type(connector_type) - - return c + return zen_store().get_service_connector_type(connector_type) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 783135f94b3..1c705c9e66c 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -971,7 +971,6 @@ def list_workspace_service_connectors( """ workspace = zen_store().get_workspace(workspace_name_or_id) connector_filter_model.set_scope_workspace(workspace.id) - connector_filter_model.set_scope_user(user_id=auth_context.user.id) return zen_store().list_service_connectors( filter_model=connector_filter_model ) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 53dd93f1633..0f028ae8272 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -140,7 +140,6 @@ ScheduleUpdateModel, SecretFilterModel, SecretRequestModel, - SecretUpdateModel, ServerDatabaseType, ServerModel, ServiceConnectorFilterModel, @@ -3453,7 +3452,6 @@ def _create_connector_secret( connector_name: str, user: UUID, workspace: UUID, - is_shared: bool, secrets: Optional[Dict[str, Optional[SecretStr]]], ) -> Optional[UUID]: """Creates a new secret to store the service connector secret credentials. @@ -3464,7 +3462,6 @@ def _create_connector_secret( user: The ID of the user who owns the service connector. workspace: The ID of the workspace in which the service connector is registered. - is_shared: Whether the service connector is shared. secrets: The secret credentials to store. Returns: @@ -3504,9 +3501,7 @@ def _create_connector_secret( name=secret_name, user=user, workspace=workspace, - scope=SecretScope.WORKSPACE - if is_shared - else SecretScope.USER, + scope=SecretScope.WORKSPACE, values=secrets, ) ).id @@ -3580,7 +3575,6 @@ def create_service_connector( connector_name=service_connector.name, user=service_connector.user, workspace=service_connector.workspace, - is_shared=service_connector.is_shared, secrets=service_connector.secrets, ) try: @@ -3766,25 +3760,7 @@ def _update_connector_secret( "A secrets store is not configured or supported." ) - is_shared = ( - existing_connector.is_shared - if updated_connector.is_shared is None - else updated_connector.is_shared - ) - scope_changed = is_shared != existing_connector.is_shared - if updated_connector.secrets is None: - if scope_changed and existing_connector.secret_id: - # Update the scope of the existing secret - self.secrets_store.update_secret( - secret_id=existing_connector.secret_id, - secret_update=SecretUpdateModel( # type: ignore[call-arg] - scope=SecretScope.WORKSPACE - if is_shared - else SecretScope.USER, - ), - ) - # If the connector update does not contain a secrets update, keep # the existing secret (if any) return existing_connector.secret_id @@ -3808,7 +3784,6 @@ def _update_connector_secret( connector_name=updated_connector.name or existing_connector.name, user=existing_connector.user.id, workspace=existing_connector.workspace.id, - is_shared=is_shared, secrets=updated_connector.secrets, ) @@ -3864,29 +3839,12 @@ def update_service_connector( # In case of a renaming update, make sure no service connector uses # that name already - if update.name: - if ( - existing_connector.name != update.name - and existing_connector.user_id is not None - ): - self._fail_if_service_connector_with_name_exists_for_user( - name=update.name, - workspace_id=existing_connector.workspace_id, - user_id=existing_connector.user_id, - session=session, - ) - - # Check if service connector update makes the service connector a - # shared service connector - # In that case, check if a service connector with the same name is - # already shared within the workspace - if update.is_shared is not None: - if not existing_connector.is_shared and update.is_shared: - self._fail_if_service_connector_with_name_already_shared( - name=update.name or existing_connector.name, - workspace_id=existing_connector.workspace_id, - session=session, - ) + if update.name and existing_connector.name != update.name: + self._fail_if_service_connector_with_name_exists( + name=update.name, + workspace_id=existing_connector.workspace_id, + session=session, + ) existing_connector_model = existing_connector.to_model() @@ -4119,7 +4077,6 @@ def get_service_connector_client( connector = connector_client.to_response_model( user=connector.user, workspace=connector.workspace, - is_shared=connector.is_shared, description=connector.description, labels=connector.labels, ) From bc9229d1f44dbec9a0cd6650284d78b003c6d67e Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 24 Oct 2023 17:29:48 +0200 Subject: [PATCH 025/103] Fix docstrings --- src/zenml/zen_server/routers/service_connectors_endpoints.py | 3 --- src/zenml/zen_stores/sql_zen_store.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/src/zenml/zen_server/routers/service_connectors_endpoints.py b/src/zenml/zen_server/routers/service_connectors_endpoints.py index 491db2fd41c..b739427c99c 100644 --- a/src/zenml/zen_server/routers/service_connectors_endpoints.py +++ b/src/zenml/zen_server/routers/service_connectors_endpoints.py @@ -164,9 +164,6 @@ def delete_service_connector( Args: connector_id: ID of the service connector. - - Raises: - KeyError: If the service connector does not exist or is not accessible. """ zen_store().delete_service_connector(connector_id) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 0f028ae8272..e7b09b9dbf6 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -1571,7 +1571,6 @@ def _fail_if_component_with_name_type_exists( f"component with the same name and type in the same " f" workspace '{existing_domain_component.workspace.name}'." ) - return None def _fail_if_component_name_reserved(self, component_name: str) -> None: """Raise an exception if the component name is reserved. @@ -3445,7 +3444,6 @@ def _fail_if_service_connector_with_name_exists( "Found an existing service connector with the same name in the " f"same workspace '{existing_domain_connector.workspace.name}'." ) - return None def _create_connector_secret( self, From f8bbb37524d4ad257d9fc51b6f8ceb0bf4958dd9 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 24 Oct 2023 18:04:45 +0200 Subject: [PATCH 026/103] Fix most tests --- tests/integration/functional/cli/test_role.py | 177 --------- .../integration/functional/cli/test_stack.py | 175 +-------- .../functional/cli/test_user_management.py | 121 ------- tests/integration/functional/cli/utils.py | 42 +-- tests/integration/functional/test_client.py | 13 - .../functional/zen_stores/test_zen_store.py | 335 +----------------- .../functional/zen_stores/utils.py | 66 ---- tests/unit/utils/test_analytics_utils.py | 2 - 8 files changed, 5 insertions(+), 926 deletions(-) delete mode 100644 tests/integration/functional/cli/test_role.py diff --git a/tests/integration/functional/cli/test_role.py b/tests/integration/functional/cli/test_role.py deleted file mode 100644 index d01ec68ccd2..00000000000 --- a/tests/integration/functional/cli/test_role.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -from click.testing import CliRunner - -from tests.integration.functional.cli.utils import ( - create_sample_role, - create_sample_user, - sample_role_name, -) -from zenml.cli.cli import cli -from zenml.client import Client -from zenml.enums import PermissionType -from zenml.zen_stores.base_zen_store import DEFAULT_ADMIN_ROLE - - -def test_create_role_succeeds() -> None: - """Test that creating a new role succeeds.""" - role_create_command = cli.commands["role"].commands["create"] - runner = CliRunner() - result = runner.invoke( - role_create_command, - [sample_role_name(), f"--permissions={PermissionType.READ}"], - ) - assert result.exit_code == 0 - - -def test_create_existing_role_fails() -> None: - """Test that creating a role that exists fails.""" - r = create_sample_role() - role_create_command = cli.commands["role"].commands["create"] - runner = CliRunner() - result = runner.invoke( - role_create_command, - [r.name, "--permissions=read"], - ) - assert result.exit_code == 1 - - -def test_update_role_permissions_succeeds() -> None: - """Test that updating a role succeeds.""" - r = create_sample_role() - role_update_command = cli.commands["role"].commands["update"] - runner = CliRunner() - result = runner.invoke( - role_update_command, - [r.name, f"--add-permission={PermissionType.WRITE.value}"], - ) - assert result.exit_code == 0 - - -def test_rename_role_succeeds() -> None: - """Test that updating a role succeeds.""" - r = create_sample_role() - role_update_command = cli.commands["role"].commands["update"] - runner = CliRunner() - result = runner.invoke( - role_update_command, - [r.name, "--name='cat_groomer'"], - ) - assert result.exit_code == 0 - - -def test_update_role_conflicting_permissions_fails() -> None: - """Test that updating a role succeeds.""" - r = create_sample_role() - role_update_command = cli.commands["role"].commands["update"] - runner = CliRunner() - result = runner.invoke( - role_update_command, - [ - r.name, - f"--add-permission={PermissionType.WRITE.value}", - f"--remove-permission={PermissionType.WRITE.value}", - ], - ) - assert result.exit_code == 1 - - -def test_update_default_role_fails() -> None: - """Test that updating the default role fails.""" - role_update_command = cli.commands["role"].commands["update"] - runner = CliRunner() - result = runner.invoke( - role_update_command, - [ - DEFAULT_ADMIN_ROLE, - f"--remove-permission={PermissionType.WRITE.value}", - ], - ) - assert result.exit_code == 1 - - -def test_delete_role_succeeds() -> None: - """Test that deleting a role succeeds.""" - r = create_sample_role() - role_update_command = cli.commands["role"].commands["delete"] - runner = CliRunner() - result = runner.invoke( - role_update_command, - [r.name], - ) - assert result.exit_code == 0 - - -def test_delete_default_role_fails() -> None: - """Test that deleting a role succeeds.""" - role_delete_command = cli.commands["role"].commands["delete"] - runner = CliRunner() - result = runner.invoke( - role_delete_command, - [DEFAULT_ADMIN_ROLE], - ) - assert result.exit_code == 1 - - -def test_assign_default_role_to_new_user_succeeds() -> None: - """Test that deleting a role succeeds.""" - user = create_sample_user() - role_assign_command = cli.commands["role"].commands["assign"] - runner = CliRunner() - result = runner.invoke( - role_assign_command, [DEFAULT_ADMIN_ROLE, f"--user={user.id}"] - ) - assert result.exit_code == 0 - assigned_roles = Client().get_user(user.id).roles - assert len(assigned_roles) == 1 - assert assigned_roles[0].name == DEFAULT_ADMIN_ROLE - - -def test_assign_role_to_user_twice_fails() -> None: - """Test that deleting a role succeeds.""" - user = create_sample_user() - role_assign_command = cli.commands["role"].commands["assign"] - runner = CliRunner() - Client().create_user_role_assignment( - role_name_or_id=DEFAULT_ADMIN_ROLE, - user_name_or_id=str(user.id), - ) - - result = runner.invoke( - role_assign_command, [DEFAULT_ADMIN_ROLE, f"--user={user.name}"] - ) - assert result.exit_code == 1 - - -def test_revoke_role_from_new_user_succeeds() -> None: - """Test that deleting a role assignment succeeds.""" - user = create_sample_user() - role_assign_command = cli.commands["role"].commands["assign"] - runner = CliRunner() - result = runner.invoke( - role_assign_command, [DEFAULT_ADMIN_ROLE, f"--user={user.name}"] - ) - assert result.exit_code == 0 - assigned_roles = Client().get_user(user.id).roles - assert len(assigned_roles) == 1 - assert assigned_roles[0].name == DEFAULT_ADMIN_ROLE - - role_revoke_command = cli.commands["role"].commands["revoke"] - runner = CliRunner() - result = runner.invoke( - role_revoke_command, [DEFAULT_ADMIN_ROLE, f"--user={user.name}"] - ) - assert result.exit_code == 0 - assigned_roles = Client().get_user(user.id).roles - assert len(assigned_roles) == 0 diff --git a/tests/integration/functional/cli/test_stack.py b/tests/integration/functional/cli/test_stack.py index 7681bbfac2f..0f6f68f941d 100644 --- a/tests/integration/functional/cli/test_stack.py +++ b/tests/integration/functional/cli/test_stack.py @@ -20,16 +20,13 @@ import pytest from click.testing import CliRunner -from tests.integration.functional.cli.utils import ( - create_sample_user_and_login, -) from zenml.artifact_stores.local_artifact_store import ( LocalArtifactStore, LocalArtifactStoreConfig, ) from zenml.cli.cli import cli from zenml.client import Client -from zenml.enums import StackComponentType, StoreType +from zenml.enums import StackComponentType from zenml.orchestrators.base_orchestrator import BaseOrchestratorConfig from zenml.orchestrators.local.local_orchestrator import LocalOrchestrator from zenml.secrets_managers.local.local_secrets_manager import ( @@ -421,176 +418,6 @@ def test_rename_stack_non_active_stack_succeeds(clean_workspace) -> None: assert clean_workspace.get_stack(new_stack.id).name == "axls_stack" -def test_sharing_nonexistent_stack_fails(clean_workspace: Client) -> None: - """Test stack rename of nonexistent stack fails.""" - runner = CliRunner() - share_command = cli.commands["stack"].commands["share"] - result = runner.invoke(share_command, ["not_a_stack"]) - assert result.exit_code == 1 - - -def test_sharing_default_stack_fails(clean_workspace: Client) -> None: - runner = CliRunner() - share_command = cli.commands["stack"].commands["share"] - result = runner.invoke(share_command, ["default"]) - assert result.exit_code == 1 - - default_stack = clean_workspace.get_stack("default") - assert default_stack.is_shared is False - - -def test_share_stack_that_is_already_shared_fails( - clean_workspace: Client, -) -> None: - new_artifact_store = _create_local_artifact_store(clean_workspace) - - new_artifact_store_model = clean_workspace.create_stack_component( - name=new_artifact_store.name, - flavor=new_artifact_store.flavor, - component_type=new_artifact_store.type, - configuration=new_artifact_store.config.dict(), - is_shared=True, - ) - - new_orchestrator = _create_local_orchestrator(clean_workspace) - - new_orchestrator_model = clean_workspace.create_stack_component( - name=new_orchestrator.name, - flavor=new_orchestrator.flavor, - component_type=new_orchestrator.type, - configuration=new_orchestrator.config.dict(), - is_shared=True, - ) - - new_stack = clean_workspace.create_stack( - name="arias_new_stack", - components={ - StackComponentType.ARTIFACT_STORE: new_artifact_store_model.name, - StackComponentType.ORCHESTRATOR: new_orchestrator_model.name, - }, - is_shared=True, - ) - - runner = CliRunner() - share_command = cli.commands["stack"].commands["share"] - result = runner.invoke(share_command, ["arias_new_stack"]) - assert result.exit_code == 1 - - arias_stack = clean_workspace.get_stack(new_stack.name) - assert arias_stack.is_shared is True - - -def test_create_shared_stack_when_component_is_private_fails( - clean_workspace: Client, -) -> None: - """When creating a shared stack all the components should also be shared, so if a component is not shared this should fail.""" - runner = CliRunner() - register_command = cli.commands["stack"].commands["register"] - result = runner.invoke( - register_command, - ["default2", "-o", "default", "-a", "default", "--share"], - ) - assert result.exit_code == 1 - - -def test_share_stack_when_component_is_already_shared_by_other_user_fails( - clean_workspace: Client, -) -> None: - """When sharing a stack all the components are also shared, so if a component with the same name is already shared this should fail.""" - if clean_workspace.zen_store.type != StoreType.REST: - pytest.skip("Only supported on ZenML server") - - # Shared component - shared_artifact_store = _create_local_artifact_store(clean_workspace) - - with create_sample_user_and_login( - prefix="Arias_Evil_Twin", initial_role="admin" - ) as ( - other_user, - other_client, - ): - other_client.create_stack_component( - name=shared_artifact_store.name, - is_shared=True, - component_type=StackComponentType.ARTIFACT_STORE, - flavor="local", - configuration={}, - ) - - # Non-shared components - new_artifact_store = _create_local_artifact_store(clean_workspace) - - new_artifact_store_model = clean_workspace.create_stack_component( - name=new_artifact_store.name, - flavor=new_artifact_store.flavor, - component_type=new_artifact_store.type, - configuration=new_artifact_store.config.dict(), - ) - - new_orchestrator = _create_local_orchestrator(clean_workspace) - - new_orchestrator_model = clean_workspace.create_stack_component( - name=new_orchestrator.name, - flavor=new_orchestrator.flavor, - component_type=new_orchestrator.type, - configuration=new_orchestrator.config.dict(), - ) - - # Register non-shared stack with non-shared components - new_stack = clean_workspace.create_stack( - name="arias_new_stack", - components={ - StackComponentType.ARTIFACT_STORE: new_artifact_store_model.id, - StackComponentType.ORCHESTRATOR: new_orchestrator_model.id, - }, - ) - - # Share stack where the shared versions of components already exists - runner = CliRunner() - share_command = cli.commands["stack"].commands["share"] - result = runner.invoke(share_command, [new_stack.name, "-r"]) - assert result.exit_code == 1 - - -def test_share_stack_when_component_is_private_fails( - clean_workspace: Client, -) -> None: - """When sharing a stack all the components are also shared, so if a component with the same name is already shared this should fail.""" - # Non-shared components - new_artifact_store = _create_local_artifact_store(clean_workspace) - - new_artifact_store_model = clean_workspace.create_stack_component( - name=new_artifact_store.name, - flavor=new_artifact_store.flavor, - component_type=new_artifact_store.type, - configuration=new_artifact_store.config.dict(), - ) - - new_orchestrator = _create_local_orchestrator(clean_workspace) - - new_orchestrator_model = clean_workspace.create_stack_component( - name=new_orchestrator.name, - flavor=new_orchestrator.flavor, - component_type=new_orchestrator.type, - configuration=new_orchestrator.config.dict(), - ) - - # Register non-shared stack with non-shared components - new_stack = clean_workspace.create_stack( - name="arias_new_stack", - components={ - StackComponentType.ARTIFACT_STORE: new_artifact_store_model.name, - StackComponentType.ORCHESTRATOR: new_orchestrator_model.name, - }, - ) - - # Share stack where the shared versions of components already exists - runner = CliRunner() - share_command = cli.commands["stack"].commands["share"] - result = runner.invoke(share_command, [new_stack.name]) - assert result.exit_code == 1 - - def test_remove_component_from_nonexistent_stack_fails( clean_workspace, ) -> None: diff --git a/tests/integration/functional/cli/test_user_management.py b/tests/integration/functional/cli/test_user_management.py index 0bdbd56285e..930f8ce1fef 100644 --- a/tests/integration/functional/cli/test_user_management.py +++ b/tests/integration/functional/cli/test_user_management.py @@ -14,22 +14,13 @@ from click.testing import CliRunner from tests.integration.functional.cli.utils import ( - create_sample_team, create_sample_user, sample_name, - sample_team_name, - team_create_command, - team_delete_command, - team_describe_command, - team_list_command, - team_update_command, user_create_command, user_delete_command, user_update_command, ) -from zenml.client import Client from zenml.zen_stores.base_zen_store import ( - DEFAULT_ADMIN_ROLE, DEFAULT_USERNAME, ) @@ -59,20 +50,6 @@ def test_create_user_that_exists_fails() -> None: assert result.exit_code == 1 -def test_create_user_with_initial_role_succeeds() -> None: - """Test that creating a new user succeeds.""" - runner = CliRunner() - result = runner.invoke( - user_create_command, - [ - sample_name(), - "--password=thesupercat", - f"--role={DEFAULT_ADMIN_ROLE}", - ], - ) - assert result.exit_code == 0 - - def test_update_user_with_new_name_succeeds() -> None: """Test that creating a new user succeeds.""" u = create_sample_user() @@ -151,101 +128,3 @@ def test_delete_user_succeeds() -> None: ) assert result.exit_code == 0 - - -# ----- # -# TEAMS # -# ----- # - - -def test_create_team_succeeds() -> None: - """Test that creating a new team with users succeeds.""" - u = create_sample_user() - runner = CliRunner() - result = runner.invoke( - team_create_command, - [sample_team_name(), f"--user={DEFAULT_USERNAME}", f"--user={u.name}"], - ) - assert result.exit_code == 0 - - -def test_create_team_without_users_succeeds() -> None: - """Test that creating a new team with users succeeds.""" - runner = CliRunner() - result = runner.invoke( - team_create_command, - [sample_team_name()], - ) - assert result.exit_code == 0 - - -def test_describe_team_succeeds() -> None: - """Test that creating a new team with users succeeds.""" - team = create_sample_team() - runner = CliRunner() - result = runner.invoke( - team_describe_command, - [team.name], - ) - assert result.exit_code == 0 - - -def test_list_team_succeeds() -> None: - """Test that creating a new team with users succeeds.""" - create_sample_team() - runner = CliRunner() - result = runner.invoke( - team_list_command, - ) - assert result.exit_code == 0 - - -def test_update_team_name_succeeds() -> None: - """Test that creating a new team with users succeeds.""" - team = create_sample_team() - new_name = sample_team_name() - runner = CliRunner() - result = runner.invoke( - team_update_command, [team.name, f"--name={new_name}"] - ) - assert result.exit_code == 0 - assert Client().get_team(new_name) - - -def test_update_team_members_succeeds() -> None: - """Test that updating a new team with new users succeeds.""" - team = create_sample_team() - user = create_sample_user() - runner = CliRunner() - result = runner.invoke( - team_update_command, [team.name, f"--add-user={user.name}"] - ) - assert result.exit_code == 0 - updated_team = Client().get_team(str(team.id)) - assert user.id in updated_team.user_ids - - -def test_update_team_members_ambiguously_fails() -> None: - """Test that updating a team with ambiguous instructions fails.""" - team = create_sample_team() - user = create_sample_user() - runner = CliRunner() - result = runner.invoke( - team_update_command, - [team.name, f"--add-user={user.name}" f"--remove-user={user.name}"], - ) - assert result.exit_code == 1 - updated_team = Client().get_team(str(team.id)) - assert set(team.user_ids) == set(updated_team.user_ids) - - -def test_delete_team_succeeds() -> None: - """Test that deleting a user succeeds.""" - team = create_sample_team() - runner = CliRunner() - result = runner.invoke( - team_delete_command, - [team.name], - ) - - assert result.exit_code == 0 diff --git a/tests/integration/functional/cli/utils.py b/tests/integration/functional/cli/utils.py index 942d2df86a4..e22da618556 100644 --- a/tests/integration/functional/cli/utils.py +++ b/tests/integration/functional/cli/utils.py @@ -21,10 +21,7 @@ temporary_active_stack, ) from zenml.client import Client -from zenml.enums import PermissionType from zenml.models import ( - RoleResponseModel, - TeamResponseModel, UserResponseModel, WorkspaceResponseModel, ) @@ -54,24 +51,21 @@ def sample_name(prefix: str = "aria") -> str: def create_sample_user( prefix: Optional[str] = None, password: Optional[str] = None, - initial_role: Optional[str] = None, ) -> UserResponseModel: """Function to create a sample user.""" return Client().create_user( name=sample_name(prefix), password=password if password is not None else random_str(16), - initial_role=initial_role, ) @contextmanager def create_sample_user_and_login( prefix: Optional[str] = None, - initial_role: Optional[str] = None, ) -> Generator[Tuple[UserResponseModel, Client], None, None]: """Context manager to create a sample user and login with it.""" password = random_str(16) - user = create_sample_user(prefix, password, initial_role) + user = create_sample_user(prefix, password) deployment = TestHarness().active_deployment with deployment.connect( @@ -81,26 +75,6 @@ def create_sample_user_and_login( yield user, client -# ----- # -# TEAMS # -# ----- # -team_create_command = cli.commands["team"].commands["create"] -team_update_command = cli.commands["team"].commands["update"] -team_list_command = cli.commands["team"].commands["list"] -team_describe_command = cli.commands["team"].commands["describe"] -team_delete_command = cli.commands["team"].commands["delete"] - - -def sample_team_name() -> str: - """Function to get random team name.""" - return f"felines_{random_str(4)}" - - -def create_sample_team() -> TeamResponseModel: - """Fixture to get a clean global configuration and repository for an individual test.""" - return Client().create_team(name=sample_team_name()) - - def test_parse_name_and_extra_arguments_returns_a_dict_of_known_options() -> ( None ): @@ -115,25 +89,13 @@ def test_parse_name_and_extra_arguments_returns_a_dict_of_known_options() -> ( assert name == "axl" -def sample_role_name() -> str: - """Function to get random role name.""" - return f"cat_feeder_{random_str(4)}" - - -def create_sample_role() -> RoleResponseModel: - """Fixture to get a global configuration with a role.""" - return Client().create_role( - name=sample_role_name(), permissions_list=[PermissionType.READ] - ) - - def sample_workspace_name() -> str: """Function to get random workspace name.""" return f"cat_prj_{random_str(4)}" def create_sample_workspace() -> WorkspaceResponseModel: - """Fixture to get a global configuration with a role.""" + """Fixture to get a workspace.""" return Client().create_workspace( name=sample_workspace_name(), description="This workspace aims to ensure world domination for all " diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index 31f151743c0..10d9e278618 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -60,7 +60,6 @@ def _create_local_orchestrator( flavor="local", component_type=StackComponentType.ORCHESTRATOR, configuration={}, - is_shared=False, ) @@ -73,7 +72,6 @@ def _create_local_artifact_store( flavor="local", component_type=StackComponentType.ARTIFACT_STORE, configuration={}, - is_shared=False, ) @@ -379,7 +377,6 @@ def test_registering_a_stack_component_with_existing_name(clean_client): flavor="local", component_type=StackComponentType.ORCHESTRATOR, configuration={}, - is_shared=False, ) @@ -1021,16 +1018,6 @@ class ClientCrudTestConfig(BaseModel): create_args={"name": sample_name("user_name")}, update_args={"updated_name": sample_name("updated_user_name")}, ), - ClientCrudTestConfig( - entity_name="team", - create_args={"name": sample_name("team_name")}, - update_args={"new_name": sample_name("updated_team_name")}, - ), - ClientCrudTestConfig( - entity_name="role", - create_args={"name": sample_name("role_name"), "permissions_list": []}, - update_args={"new_name": sample_name("updated_role_name")}, - ), ClientCrudTestConfig( entity_name="workspace", create_args={"name": sample_name("workspace_name"), "description": ""}, diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 9cfd1af2e75..44b89fb4826 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -29,11 +29,9 @@ CrudTestConfig, ModelVersionContext, PipelineRunContext, - RoleContext, ServiceConnectorContext, ServiceConnectorTypeContext, StackContext, - TeamContext, UserContext, list_of_entities, ) @@ -64,18 +62,12 @@ ModelVersionUpdateModel, PipelineRunFilterModel, PipelineRunResponseModel, - RoleFilterModel, - RoleRequestModel, - RoleUpdateModel, ServiceConnectorFilterModel, ServiceConnectorUpdateModel, StackFilterModel, StackRequestModel, StackUpdateModel, StepRunFilterModel, - TeamRoleAssignmentRequestModel, - TeamUpdateModel, - UserRoleAssignmentRequestModel, UserUpdateModel, WorkspaceFilterModel, WorkspaceUpdateModel, @@ -90,9 +82,6 @@ _load_file_from_artifact_store, ) from zenml.zen_stores.base_zen_store import ( - DEFAULT_ADMIN_ROLE, - DEFAULT_GUEST_ROLE, - DEFAULT_STACK_NAME, DEFAULT_USERNAME, DEFAULT_WORKSPACE_NAME, ) @@ -288,81 +277,6 @@ def test_deleting_default_workspace_fails(): client.zen_store.delete_workspace(DEFAULT_NAME) -# .-------. -# | TEAMS | -# '-------' - - -def test_adding_user_to_team(): - """Tests adding a user to a team.""" - zen_store = Client().zen_store - with UserContext() as created_user: - with TeamContext() as created_team: - team_update = TeamUpdateModel(users=[created_user.id]) - team_update = zen_store.update_team( - team_id=created_team.id, team_update=team_update - ) - - assert created_user.id in team_update.user_ids - assert len(team_update.users) == 1 - - # Make sure the team name has not been inadvertently changed - assert ( - zen_store.get_team(created_team.id).name == created_team.name - ) - - -def test_adding_nonexistent_user_to_real_team_raises_error(): - """Tests adding a nonexistent user to a team raises an error.""" - zen_store = Client().zen_store - with TeamContext() as created_team: - nonexistent_id = uuid.uuid4() - - team_update = TeamUpdateModel(users=[nonexistent_id]) - with pytest.raises(KeyError): - zen_store.update_team( - team_id=created_team.id, team_update=team_update - ) - - -def test_removing_user_from_team_succeeds(): - """Tests removing a user from a team.""" - - zen_store = Client().zen_store - sample_name("arias_team") - - with UserContext() as created_user: - with TeamContext() as created_team: - team_update = TeamUpdateModel(users=[created_user.id]) - team_update = zen_store.update_team( - team_id=created_team.id, team_update=team_update - ) - - assert created_user.id in team_update.user_ids - - team_update = TeamUpdateModel(users=[]) - team_update = zen_store.update_team( - team_id=created_team.id, team_update=team_update - ) - - assert created_user.id not in team_update.user_ids - - -def test_access_user_in_team_succeeds(): - """Tests accessing a users in a team.""" - - zen_store = Client().zen_store - sample_name("arias_team") - - with UserContext() as created_user: - with TeamContext() as created_team: - team_update = TeamUpdateModel(users=[created_user.id]) - team_update = zen_store.update_team( - team_id=created_team.id, team_update=team_update - ) - assert created_user in team_update.users - - # .------. # | USERS | # '-------' @@ -400,220 +314,10 @@ def test_deleting_default_user_fails(): zen_store.delete_user("default") -def test_getting_team_for_user_succeeds(): - pass - - -def test_team_for_user_succeeds(): - """Tests accessing a users in a team.""" - - zen_store = Client().zen_store - sample_name("arias_team") - - with UserContext() as created_user: - with TeamContext() as created_team: - team_update = TeamUpdateModel(users=[created_user.id]) - team_update = zen_store.update_team( - team_id=created_team.id, team_update=team_update - ) - - updated_user_response = zen_store.get_user(created_user.id) - - assert team_update in updated_user_response.teams - - -# .-------. -# | ROLES | -# '-------' - - -def test_creating_role_with_empty_permissions_succeeds(): - """Tests creating a role.""" - zen_store = Client().zen_store - - with RoleContext() as created_role: - new_role = RoleRequestModel(name=sample_name("cat"), permissions=set()) - created_role = zen_store.create_role(new_role) - with does_not_raise(): - zen_store.get_role(role_name_or_id=created_role.name) - list_of_roles = zen_store.list_roles( - RoleFilterModel(name=created_role.name) - ) - assert list_of_roles.total > 0 - - -def test_deleting_builtin_role_fails(): - """Tests deleting a built-in role fails.""" - zen_store = Client().zen_store - - with pytest.raises(IllegalOperationError): - zen_store.delete_role(DEFAULT_ADMIN_ROLE) - - with pytest.raises(IllegalOperationError): - zen_store.delete_role(DEFAULT_GUEST_ROLE) - - -def test_updating_builtin_role_fails(): - """Tests updating a built-in role fails.""" - zen_store = Client().zen_store - - role = zen_store.get_role(DEFAULT_ADMIN_ROLE) - role_update = RoleUpdateModel(name="cat_feeder") - - with pytest.raises(IllegalOperationError): - zen_store.update_role(role_id=role.id, role_update=role_update) - - role = zen_store.get_role(DEFAULT_GUEST_ROLE) - with pytest.raises(IllegalOperationError): - zen_store.update_role(role_id=role.id, role_update=role_update) - - -def test_deleting_assigned_role_fails(): - """Tests assigning a role to a user.""" - zen_store = Client().zen_store - with RoleContext() as created_role: - with UserContext() as created_user: - role_assignment = UserRoleAssignmentRequestModel( - role=created_role.id, - user=created_user.id, - workspace=None, - ) - with does_not_raise(): - (zen_store.create_user_role_assignment(role_assignment)) - with pytest.raises(IllegalOperationError): - zen_store.delete_role(created_role.id) - - -# .------------------. -# | ROLE ASSIGNMENTS | -# '------------------' - - -def test_assigning_role_to_user_succeeds(): - """Tests assigning a role to a user.""" - zen_store = Client().zen_store - - with RoleContext() as created_role: - with UserContext() as created_user: - role_assignment = UserRoleAssignmentRequestModel( - role=created_role.id, - user=created_user.id, - workspace=None, - ) - with does_not_raise(): - assignment = zen_store.create_user_role_assignment( - role_assignment - ) - - # With user and role deleted the assignment should be deleted as well - with pytest.raises(KeyError): - zen_store.delete_user_role_assignment(assignment.id) - - -def test_assigning_role_to_team_succeeds(): - """Tests assigning a role to a user.""" - zen_store = Client().zen_store - - with RoleContext() as created_role: - with TeamContext() as created_team: - role_assignment = TeamRoleAssignmentRequestModel( - role=created_role.id, - team=created_team.id, - workspace=None, - ) - with does_not_raise(): - assignment = zen_store.create_team_role_assignment( - role_assignment - ) - # With user and role deleted the assignment should be deleted as well - with pytest.raises(KeyError): - zen_store.get_team_role_assignment(assignment.id) - - -def test_assigning_role_if_assignment_already_exists_fails(): - """Tests assigning a role to a user if the assignment already exists.""" - zen_store = Client().zen_store - - with RoleContext() as created_role: - with UserContext() as created_user: - role_assignment = UserRoleAssignmentRequestModel( - role=created_role.id, - user=created_user.id, - workspace=None, - ) - with does_not_raise(): - (zen_store.create_user_role_assignment(role_assignment)) - with pytest.raises(EntityExistsError): - (zen_store.create_user_role_assignment(role_assignment)) - - -def test_revoking_role_for_user_succeeds(): - """Tests revoking a role for a user.""" - zen_store = Client().zen_store - - with RoleContext() as created_role: - with UserContext() as created_user: - role_assignment = UserRoleAssignmentRequestModel( - role=created_role.id, - user=created_user.id, - workspace=None, - ) - with does_not_raise(): - role_assignment = zen_store.create_user_role_assignment( - role_assignment - ) - zen_store.delete_user_role_assignment( - user_role_assignment_id=role_assignment.id - ) - with pytest.raises(KeyError): - zen_store.get_user_role_assignment( - user_role_assignment_id=role_assignment.id - ) - - -def test_revoking_role_for_team_succeeds(): - """Tests revoking a role for a team.""" - zen_store = Client().zen_store - - with RoleContext() as created_role: - with TeamContext() as created_team: - role_assignment = TeamRoleAssignmentRequestModel( - role=created_role.id, - team=created_team.id, - workspace=None, - ) - with does_not_raise(): - role_assignment = zen_store.create_team_role_assignment( - role_assignment - ) - zen_store.delete_team_role_assignment( - team_role_assignment_id=role_assignment.id - ) - with pytest.raises(KeyError): - zen_store.get_team_role_assignment( - team_role_assignment_id=role_assignment.id - ) - - -def test_revoking_nonexistent_role_fails(): - """Tests revoking a nonexistent role fails.""" - zen_store = Client().zen_store - with pytest.raises(KeyError): - zen_store.delete_team_role_assignment( - team_role_assignment_id=uuid.uuid4() - ) - with pytest.raises(KeyError): - zen_store.delete_user_role_assignment( - user_role_assignment_id=uuid.uuid4() - ) - - # .------------------. # | Stack components | # '------------------' -# TODO: tests regarding sharing of components missing - def test_update_default_stack_component_fails(): """Tests that updating default stack components fails.""" @@ -716,7 +420,7 @@ def test_updating_default_stack_fails(): """Tests that updating the default stack is prohibited.""" client = Client() - default_stack = client.get_stack(DEFAULT_STACK_NAME) + default_stack = client.get_stack("default") assert default_stack.name == DEFAULT_WORKSPACE_NAME stack_update = StackUpdateModel(name="axls_stack") with pytest.raises(IllegalOperationError): @@ -729,7 +433,7 @@ def test_deleting_default_stack_fails(): """Tests that deleting the default stack is prohibited.""" client = Client() - default_stack = client.get_stack(DEFAULT_STACK_NAME) + default_stack = client.get_stack("default") with pytest.raises(IllegalOperationError): client.zen_store.delete_stack(default_stack.id) @@ -942,41 +646,6 @@ def test_deleting_a_stack_recursively_with_some_stack_components_present_in_anot store.get_stack_component(secret.id) -def test_private_stacks_are_inaccessible(): - """Tests stack scoping via sharing on rest zen stores.""" - if Client().zen_store.type == StoreType.SQL: - pytest.skip("SQL Zen Stores do not support stack scoping") - - default_user_id = Client().active_user.id - with ComponentContext( - c_type=StackComponentType.ORCHESTRATOR, - flavor="local", - config={}, - user_id=default_user_id, - ) as orchestrator: - with ComponentContext( - c_type=StackComponentType.ARTIFACT_STORE, - flavor="local", - config={}, - user_id=default_user_id, - ) as artifact_store: - components = { - StackComponentType.ORCHESTRATOR: [orchestrator.id], - StackComponentType.ARTIFACT_STORE: [artifact_store.id], - } - with StackContext( - components=components, user_id=default_user_id - ) as stack: - with UserContext(login=True): - # Unshared stack should be invisible to the current user - # Client() needs to be instantiated here with the new - # logged-in user - filtered_stacks = Client().zen_store.list_stacks( - StackFilterModel(name=stack.name) - ) - assert len(filtered_stacks) == 0 - - def test_public_stacks_are_accessible(): """Tests stack scoping via sharing on rest zen stores.""" client = Client() diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index 2d21acf2d41..d5b0b488347 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -55,9 +55,6 @@ PipelineRunRequestModel, PipelineUpdateModel, ResourceTypeModel, - RoleFilterModel, - RoleRequestModel, - RoleUpdateModel, SecretFilterModel, SecretRequestModel, ServiceConnectorFilterModel, @@ -66,9 +63,6 @@ ServiceConnectorUpdateModel, StackRequestModel, StepRunFilterModel, - TeamFilterModel, - TeamRequestModel, - TeamUpdateModel, UserFilterModel, UserRequestModel, UserUpdateModel, @@ -188,11 +182,6 @@ def __enter__(self): self.created_user = self.store.get_user(self.user_name) if self.login or self.existing_user: - if not self.existing_user: - self.client.create_user_role_assignment( - role_name_or_id="admin", - user_name_or_id=self.created_user.id, - ) self.original_config = GlobalConfiguration.get_instance() self.original_client = Client.get_instance() @@ -289,42 +278,6 @@ def __exit__(self, exc_type, exc_value, exc_traceback): pass -class TeamContext: - def __init__(self, team_name: str = "arias_fanclub"): - self.team_name = sample_name(team_name) - self.client = Client() - self.store = self.client.zen_store - - def __enter__(self): - new_team = TeamRequestModel(name=self.team_name) - self.created_team = self.store.create_team(new_team) - return self.created_team - - def __exit__(self, exc_type, exc_value, exc_traceback): - try: - self.store.delete_team(self.created_team.id), - except KeyError: - pass - - -class RoleContext: - def __init__(self, role_name: str = "aria_tamer"): - self.role_name = sample_name(role_name) - self.client = Client() - self.store = self.client.zen_store - - def __enter__(self): - new_role = RoleRequestModel(name=self.role_name, permissions=set()) - self.created_role = self.store.create_role(new_role) - return self.created_role - - def __exit__(self, exc_type, exc_value, exc_traceback): - try: - self.store.delete_role(self.created_role.id) - except KeyError: - pass - - class WorkspaceContext: def __init__( self, @@ -458,7 +411,6 @@ def __init__( expiration_seconds: Optional[int] = None, user_id: Optional[uuid.UUID] = None, workspace_id: Optional[uuid.UUID] = None, - is_shared: bool = False, labels: Optional[Dict[str, str]] = None, client: Optional[Client] = None, delete: bool = True, @@ -474,7 +426,6 @@ def __init__( self.expiration_seconds = expiration_seconds self.user_id = user_id self.workspace_id = workspace_id - self.is_shared = is_shared self.labels = labels self.client = client or Client() self.store = self.client.zen_store @@ -491,7 +442,6 @@ def __enter__(self): secrets=self.secrets or {}, expires_at=self.expires_at, expiration_seconds=self.expiration_seconds, - is_shared=self.is_shared, labels=self.labels or {}, user=self.user_id or self.client.active_user.id, workspace=self.workspace_id or self.client.active_workspace.id, @@ -784,20 +734,6 @@ def update_method( filter_model=UserFilterModel, entity_name="user", ) -role_crud_test_config = CrudTestConfig( - create_model=RoleRequestModel( - name=sample_name("sample_role"), permissions=set() - ), - update_model=RoleUpdateModel(name=sample_name("updated_sample_role")), - filter_model=RoleFilterModel, - entity_name="role", -) -team_crud_test_config = CrudTestConfig( - create_model=TeamRequestModel(name=sample_name("sample_team")), - update_model=TeamUpdateModel(name=sample_name("updated_sample_team")), - filter_model=TeamFilterModel, - entity_name="team", -) flavor_crud_test_config = CrudTestConfig( create_model=FlavorRequestModel( name=sample_name("sample_flavor"), @@ -977,8 +913,6 @@ def update_method( list_of_entities = [ workspace_crud_test_config, user_crud_test_config, - role_crud_test_config, - team_crud_test_config, flavor_crud_test_config, component_crud_test_config, pipeline_crud_test_config, diff --git a/tests/unit/utils/test_analytics_utils.py b/tests/unit/utils/test_analytics_utils.py index 9161e0a01dc..b49fa7c24a1 100644 --- a/tests/unit/utils/test_analytics_utils.py +++ b/tests/unit/utils/test_analytics_utils.py @@ -74,7 +74,6 @@ def event_check( assert StackComponentType.ARTIFACT_STORE in properties assert StackComponentType.ORCHESTRATOR in properties - assert "is_shared" in properties if event == AnalyticsEvent.REGISTERED_STACK_COMPONENT: assert "type" in properties @@ -82,7 +81,6 @@ def event_check( assert "flavor" in properties assert "entity_id" in properties - assert "is_shared" in properties if event == AnalyticsEvent.RUN_PIPELINE: assert "store_type" in properties From af56cf1640b1c9cda63a9eb8fc665adbc56c6d3e Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 25 Oct 2023 10:47:31 +0200 Subject: [PATCH 027/103] Fix more tests --- tests/integration/functional/cli/test_stack.py | 2 +- .../functional/zen_stores/test_zen_store.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/integration/functional/cli/test_stack.py b/tests/integration/functional/cli/test_stack.py index 0f6f68f941d..c50cab2a431 100644 --- a/tests/integration/functional/cli/test_stack.py +++ b/tests/integration/functional/cli/test_stack.py @@ -359,7 +359,7 @@ def test_rename_stack_default_stack_fails(clean_workspace) -> None: rename_command = cli.commands["stack"].commands["rename"] result = runner.invoke(rename_command, ["default", "axls_new_stack"]) assert result.exit_code == 1 - assert len(clean_workspace.list_stacks(name="default")) == 1 + assert clean_workspace.get_stack("default") def test_rename_stack_active_stack_succeeds(clean_workspace) -> None: diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 44b89fb4826..34d0fc3dfc9 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -323,11 +323,14 @@ def test_update_default_stack_component_fails(): """Tests that updating default stack components fails.""" client = Client() store = client.zen_store + default_component_name = store._get_default_stack_and_component_name( + client.active_user.id + ) default_artifact_store = store.list_stack_components( ComponentFilterModel( workspace_id=client.active_workspace.id, type=StackComponentType.ARTIFACT_STORE, - name="default", + name=default_component_name, ) )[0] @@ -335,7 +338,7 @@ def test_update_default_stack_component_fails(): ComponentFilterModel( workspace_id=client.active_workspace.id, type=StackComponentType.ORCHESTRATOR, - name="default", + name=default_component_name, ) )[0] @@ -358,11 +361,15 @@ def test_delete_default_stack_component_fails(): """Tests that deleting default stack components is prohibited.""" client = Client() store = client.zen_store + default_component_name = store._get_default_stack_and_component_name( + client.active_user.id + ) + default_artifact_store = store.list_stack_components( ComponentFilterModel( workspace_id=client.active_workspace.id, type=StackComponentType.ARTIFACT_STORE, - name="default", + name=default_component_name, ) )[0] @@ -370,7 +377,7 @@ def test_delete_default_stack_component_fails(): ComponentFilterModel( workspace_id=client.active_workspace.id, type=StackComponentType.ORCHESTRATOR, - name="default", + name=default_component_name, ) )[0] @@ -421,7 +428,6 @@ def test_updating_default_stack_fails(): client = Client() default_stack = client.get_stack("default") - assert default_stack.name == DEFAULT_WORKSPACE_NAME stack_update = StackUpdateModel(name="axls_stack") with pytest.raises(IllegalOperationError): client.zen_store.update_stack( From 10fe4104e8c364f00008ae5601ee7b41980d2587 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 25 Oct 2023 11:23:38 +0200 Subject: [PATCH 028/103] Remove old tests --- .../zen_stores/test_secrets_store.py | 214 +--------------- .../functional/zen_stores/test_zen_store.py | 241 +----------------- 2 files changed, 6 insertions(+), 449 deletions(-) diff --git a/tests/integration/functional/zen_stores/test_secrets_store.py b/tests/integration/functional/zen_stores/test_secrets_store.py index f52956e2bf9..253a5a8eda2 100644 --- a/tests/integration/functional/zen_stores/test_secrets_store.py +++ b/tests/integration/functional/zen_stores/test_secrets_store.py @@ -26,9 +26,8 @@ ) from zenml.client import Client from zenml.enums import SecretScope, SecretsStoreType, StoreType -from zenml.exceptions import EntityExistsError, IllegalOperationError +from zenml.exceptions import EntityExistsError from zenml.models.secret_models import SecretFilterModel, SecretUpdateModel -from zenml.utils.string_utils import random_str # The AWS secrets store takes some time to reflect new and updated secrets in # the `list_secrets` API. This is the number of seconds to wait before making @@ -1573,217 +1572,6 @@ def test_list_secrets_pagination_and_sorting(): assert len(secrets.items) == 0 -def test_secret_values_cannot_be_accessed_by_readonly_user(): - """Tests that secret values cannot be retrieved by read-only users.""" - if Client().zen_store.type == StoreType.SQL: - pytest.skip("SQL Zen Stores do not support user switching.") - - # Switch to a different user with read-write access - password = random_str(32) - with UserContext(password=password, login=True) as user: - client = Client() - store = client.zen_store - - # Create a workspace-scoped and user-scoped secret - with SecretContext() as secret, SecretContext( - scope=SecretScope.USER - ) as user_secret: - all_secrets = store.list_secrets(SecretFilterModel()).items - assert len(all_secrets) >= 2 - assert secret.id in [s.id for s in all_secrets] - assert user_secret.id in [s.id for s in all_secrets] - workspace_secrets = store.list_secrets( - SecretFilterModel( - scope=SecretScope.WORKSPACE, - workspace_id=client.active_workspace.id, - ) - ).items - assert len(workspace_secrets) >= 1 - assert secret.id in [s.id for s in workspace_secrets] - - saved_secret = store.get_secret(secret.id) - assert saved_secret.secret_values == secret.secret_values - assert all(v is not None for v in saved_secret.values.values()) - - user_secrets = store.list_secrets( - SecretFilterModel( - scope=SecretScope.USER, - user_id=client.active_user.id, - workspace_id=client.active_workspace.id, - ) - ).items - assert len(user_secrets) == 1 - assert user_secret.id == user_secrets[0].id - - saved_secret = store.get_secret(user_secrets[0].id) - assert saved_secret.secret_values == user_secret.secret_values - assert all(v is not None for v in saved_secret.values.values()) - - # Remove the user's write access - client.create_user_role_assignment( - role_name_or_id="guest", user_name_or_id=user.id - ) - admin_role = client.get_role("admin") - role_assignment = client.list_user_role_assignment( - user_id=user.id, - role_id=admin_role.id, - ) - assert len(role_assignment) == 1 - client.delete_user_role_assignment(role_assignment[0].id) - - # Re-authenticate with the updated user - with UserContext( - user_name=user.name, password=password, existing_user=True - ): - all_secrets = store.list_secrets(SecretFilterModel()).items - assert len(all_secrets) >= 2 - assert secret.id in [s.id for s in all_secrets] - assert user_secret.id in [s.id for s in all_secrets] - workspace_secrets = store.list_secrets( - SecretFilterModel( - scope=SecretScope.WORKSPACE, - workspace_id=client.active_workspace.id, - ) - ).items - assert len(workspace_secrets) >= 1 - assert secret.id in [s.id for s in workspace_secrets] - - saved_secret = store.get_secret(secret.id) - assert set(saved_secret.values.keys()) == set( - secret.values.keys() - ) - assert set(saved_secret.values.values()) == {None} - - user_secrets = store.list_secrets( - SecretFilterModel( - scope=SecretScope.USER, - user_id=client.active_user.id, - workspace_id=client.active_workspace.id, - ) - ).items - assert len(user_secrets) == 1 - assert user_secret.id == user_secrets[0].id - - saved_secret = store.get_secret(user_secrets[0].id) - assert set(saved_secret.values.values()) == {None} - assert set(saved_secret.values.keys()) == set( - user_secret.values.keys() - ) - - -def test_secrets_cannot_be_created_or_updated_by_readonly_user(): - """Tests that secret values cannot be created or updated by read-only users.""" - if Client().zen_store.type == StoreType.SQL: - pytest.skip("SQL Zen Stores do not support user switching.") - - # Switch to a different user with read-write access - password = random_str(32) - with UserContext(password=password, login=True) as user: - client = Client() - store = client.zen_store - - # Create a workspace-scoped and user-scoped secret - with SecretContext() as secret, SecretContext( - scope=SecretScope.USER - ) as user_secret: - with does_not_raise(): - store.update_secret( - secret.id, SecretUpdateModel(name=f"{secret.name}-new") - ) - - if _get_secrets_store_type() == SecretsStoreType.AWS: - # The AWS secrets store returns before the secret is actually - # updated in the backend, so we need to wait a bit before - # running `list_secrets`. - time.sleep(AWS_SECRET_REFRESH_SLEEP) - - old_secrets = store.list_secrets( - SecretFilterModel(name=secret.name) - ).items - assert len(old_secrets) == 0 - - new_secrets = store.list_secrets( - SecretFilterModel(name=f"{secret.name}-new") - ).items - assert len(new_secrets) == 1 - assert new_secrets[0].id == secret.id - - with does_not_raise(): - store.update_secret( - user_secret.id, - SecretUpdateModel(name=f"{user_secret.name}-new"), - ) - - if _get_secrets_store_type() == SecretsStoreType.AWS: - # The AWS secrets store returns before the secret is actually - # updated in the backend, so we need to wait a bit before - # running `list_secrets`. - time.sleep(AWS_SECRET_REFRESH_SLEEP) - - old_secrets = store.list_secrets( - SecretFilterModel(name=user_secret.name) - ).items - assert len(old_secrets) == 0 - - new_secrets = store.list_secrets( - SecretFilterModel(name=f"{user_secret.name}-new") - ).items - assert len(new_secrets) == 1 - assert new_secrets[0].id == user_secret.id - - # Remove the user's write access - client.create_user_role_assignment( - role_name_or_id="guest", user_name_or_id=user.id - ) - admin_role = client.get_role("admin") - role_assignment = client.list_user_role_assignment( - user_id=user.id, - role_id=admin_role.id, - ) - assert len(role_assignment) == 1 - client.delete_user_role_assignment(role_assignment[0].id) - - # Re-authenticate with the updated user - with UserContext( - user_name=user.name, password=password, existing_user=True - ): - new_client = Client() - new_store = new_client.zen_store - - with pytest.raises(IllegalOperationError): - new_store.update_secret( - secret.id, SecretUpdateModel(name=f"{secret.name}") - ) - - old_secrets = new_store.list_secrets( - SecretFilterModel(name=secret.name) - ).items - assert len(old_secrets) == 0 - - new_secrets = new_store.list_secrets( - SecretFilterModel(name=f"{secret.name}-new") - ).items - assert len(new_secrets) == 1 - assert new_secrets[0].id == secret.id - - with pytest.raises(IllegalOperationError): - new_store.update_secret( - user_secret.id, - SecretUpdateModel(name=f"{user_secret.name}-new"), - ) - - old_secrets = new_store.list_secrets( - SecretFilterModel(name=user_secret.name) - ).items - assert len(old_secrets) == 0 - - new_secrets = new_store.list_secrets( - SecretFilterModel(name=f"{user_secret.name}-new") - ).items - assert len(new_secrets) == 1 - assert new_secrets[0].id == user_secret.id - - def test_secret_is_deleted_with_workspace(): """Tests that deleting a workspace automatically deletes all its secrets.""" client = Client() diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 34d0fc3dfc9..4021259be64 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -40,7 +40,7 @@ ) from zenml.client import Client from zenml.constants import RUNNING_MODEL_VERSION -from zenml.enums import ModelStages, SecretScope, StackComponentType, StoreType +from zenml.enums import ModelStages, StackComponentType, StoreType from zenml.exceptions import ( DoesNotExistException, EntityExistsError, @@ -652,8 +652,8 @@ def test_deleting_a_stack_recursively_with_some_stack_components_present_in_anot store.get_stack_component(secret.id) -def test_public_stacks_are_accessible(): - """Tests stack scoping via sharing on rest zen stores.""" +def test_stacks_are_accessible_by_other_users(): + """Tests accessing stack on rest zen stores.""" client = Client() store = client.zen_store if store.type == StoreType.SQL: @@ -679,12 +679,6 @@ def test_public_stacks_are_accessible(): with StackContext( components=components, user_id=default_user_id ) as stack: - # Update - stack_update = StackUpdateModel(is_shared=True) - store.update_stack( - stack_id=stack.id, stack_update=stack_update - ) - with UserContext(login=True): # Client() needs to be instantiated here with the new # logged-in user @@ -1172,70 +1166,13 @@ def test_connector_name_reuse_for_same_user_fails(): pass -def test_connector_same_name_different_users(): - """Tests that a connector's name can be used if another user has it.""" - - if Client().zen_store.type == StoreType.SQL: - pytest.skip("SQL Zen Stores do not support user switching.") - - with ServiceConnectorContext( - connector_type="cat'o'matic", - auth_method="paw-print", - resource_types=["cat"], - ) as connector_one: - with UserContext(login=True): - # Client() needs to be instantiated here with the new - # logged-in user - other_client = Client() - - with ServiceConnectorContext( - name=connector_one.name, - connector_type="cat'o'matic", - auth_method="paw-print", - resource_types=["cat"], - client=other_client, - ): - pass - - -def test_connector_same_name_different_users_shared(): - """Tests that a connector's name can be used even if another user has it shared.""" - - if Client().zen_store.type == StoreType.SQL: - pytest.skip("SQL Zen Stores do not support user switching.") - - with ServiceConnectorContext( - connector_type="cat'o'matic", - auth_method="paw-print", - resource_types=["cat"], - is_shared=True, - ) as connector_one: - with UserContext(login=True): - # Client() needs to be instantiated here with the new - # logged-in user - other_client = Client() - - with ServiceConnectorContext( - name=connector_one.name, - connector_type="cat'o'matic", - auth_method="paw-print", - resource_types=["cat"], - client=other_client, - ): - pass - - -def test_connector_same_name_different_users_both_shared(): - """Tests that a shared connector's name cannot be used if another user also has it shared.""" - - if Client().zen_store.type == StoreType.SQL: - pytest.skip("SQL Zen Stores do not support user switching.") +def test_connector_name_reuse_for_different_user_fails(): + """Tests that a connector's name cannot be re-used by another user.""" with ServiceConnectorContext( connector_type="cat'o'matic", auth_method="paw-print", resource_types=["cat"], - is_shared=True, ) as connector_one: with UserContext(login=True): # Client() needs to be instantiated here with the new @@ -1249,7 +1186,6 @@ def test_connector_same_name_different_users_both_shared(): auth_method="paw-print", resource_types=["cat"], client=other_client, - is_shared=True, ): pass @@ -1458,61 +1394,6 @@ def test_connector_list(): assert rodent_connector.id not in [c.id for c in connectors] -def test_private_connector_not_visible_to_other_user(): - """Tests that a private connector is not visible to another user.""" - - if Client().zen_store.type == StoreType.SQL: - pytest.skip("SQL Zen Stores do not support user switching.") - - with ServiceConnectorContext( - connector_type="cat'o'matic", - auth_method="paw-print", - resource_types=["cat"], - is_shared=False, - ) as connector: - with UserContext(login=True): - # Client() needs to be instantiated here with the new - # logged-in user - other_client = Client() - other_store = other_client.zen_store - - with pytest.raises(KeyError): - other_store.get_service_connector(connector.id) - - connectors = other_store.list_service_connectors( - ServiceConnectorFilterModel() - ).items - - assert connector.id not in [c.id for c in connectors] - - -def test_shared_connector_is_visible_to_other_user(): - """Tests that a shared connector is visible to another user.""" - - if Client().zen_store.type == StoreType.SQL: - pytest.skip("SQL Zen Stores do not support user switching.") - - with ServiceConnectorContext( - connector_type="cat'o'matic", - auth_method="paw-print", - resource_types=["cat"], - is_shared=True, - ) as connector: - with UserContext(login=True): - # Client() needs to be instantiated here with the new - # logged-in user - other_client = Client() - other_store = other_client.zen_store - - other_store.get_service_connector(connector.id) - - connectors = other_store.list_service_connectors( - ServiceConnectorFilterModel() - ).items - - assert connector.id in [c.id for c in connectors] - - def _update_connector_and_test( new_name: Optional[str] = None, new_connector_type: Optional[str] = None, @@ -1775,118 +1656,6 @@ def test_connector_name_update_fails_if_exists(): ) -def test_connector_sharing(): - """Tests that a connector can be shared.""" - - client = Client() - store = client.zen_store - - if client.zen_store.type == StoreType.SQL: - pytest.skip("SQL Zen Stores do not support user switching.") - - config = { - "language": "meow", - "foods": "tuna", - } - secrets = { - "hiding-place": SecretStr("thatsformetoknowandyouneverfindout"), - "dreams": SecretStr("notyourbusiness"), - } - - with ServiceConnectorContext( - connector_type="cat'o'matic", - auth_method="paw-print", - resource_types=["cat"], - configuration=config, - secrets=secrets, - is_shared=False, - ) as connector: - assert connector.secret_id is not None - secret = store.get_secret(connector.secret_id) - assert secret.scope == SecretScope.USER - - with UserContext(login=True): - # Client() needs to be instantiated here with the new - # logged-in user - other_client = Client() - other_store = other_client.zen_store - - with pytest.raises(KeyError): - other_store.get_service_connector(connector.id) - - connectors = other_store.list_service_connectors( - ServiceConnectorFilterModel() - ).items - - assert connector.id not in [c.id for c in connectors] - - updated_connector = store.update_service_connector( - connector.id, - update=ServiceConnectorUpdateModel(is_shared=True), - ) - - assert updated_connector.secret_id is not None - assert updated_connector.secret_id == connector.secret_id - secret = store.get_secret(updated_connector.secret_id) - assert secret.scope == SecretScope.WORKSPACE - - with UserContext(login=True): - # Client() needs to be instantiated here with the new - # logged-in user - other_client = Client() - other_store = other_client.zen_store - - other_store.get_service_connector(connector.id) - - connectors = other_store.list_service_connectors( - ServiceConnectorFilterModel() - ).items - - assert connector.id in [c.id for c in connectors] - - -def test_connector_sharing_fails_if_name_shared(): - """Tests that a connector cannot be shared if the name is already shared.""" - - client = Client() - - if client.zen_store.type == StoreType.SQL: - pytest.skip("SQL Zen Stores do not support user switching.") - - with ServiceConnectorContext( - connector_type="cat'o'matic", - auth_method="paw-print", - resource_types=["cat"], - is_shared=True, - ) as connector: - with UserContext(login=True): - # Client() needs to be instantiated here with the new - # logged-in user - other_client = Client() - other_store = other_client.zen_store - - other_store.get_service_connector(connector.id) - - connectors = other_store.list_service_connectors( - ServiceConnectorFilterModel() - ).items - - assert connector.id in [c.id for c in connectors] - - with ServiceConnectorContext( - name=connector.name, - connector_type="cat'o'matic", - auth_method="paw-print", - resource_types=["cat"], - is_shared=False, - ) as other_connector: - with pytest.raises(EntityExistsError): - other_store.update_service_connector( - other_connector.id, - update=ServiceConnectorUpdateModel(is_shared=True), - ) - - # .-------------------------. # | Service Connector Types | # '-------------------------' From 11985f03298896217502ca9b35588d17368256f7 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 25 Oct 2023 17:46:50 +0200 Subject: [PATCH 029/103] Fix alembic order --- .../migrations/versions/7500f434b71c_remove_shared_columns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py index 70ad3b75cd2..0353e4798af 100644 --- a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -10,7 +10,7 @@ # revision identifiers, used by Alembic. revision = "7500f434b71c" -down_revision = "0.45.4" +down_revision = "0.45.5" branch_labels = None depends_on = None From a82d93cdb803195791c2f24bb42e2c8396b72b9e Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 26 Oct 2023 15:56:17 +0200 Subject: [PATCH 030/103] Higher level helper functions for rbac endpoints --- src/zenml/zen_server/rbac/endpoint_utils.py | 139 ++++++++++++++++++ src/zenml/zen_server/rbac/models.py | 2 +- src/zenml/zen_server/rbac/utils.py | 19 --- src/zenml/zen_server/rbac/zenml_cloud_rbac.py | 26 +++- .../zen_server/routers/artifacts_endpoints.py | 33 ++++- .../routers/code_repositories_endpoints.py | 30 +++- .../zen_server/routers/flavors_endpoints.py | 56 ++++--- .../zen_server/routers/pipelines_endpoints.py | 35 ++++- .../routers/stack_components_endpoints.py | 41 ++++-- .../zen_server/routers/stacks_endpoints.py | 50 +++---- 10 files changed, 334 insertions(+), 97 deletions(-) create mode 100644 src/zenml/zen_server/rbac/endpoint_utils.py diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py new file mode 100644 index 00000000000..081f1717893 --- /dev/null +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -0,0 +1,139 @@ +"""High-level helper functions to write endpoints with RBAC.""" +from typing import Callable, TypeVar +from uuid import UUID + +from pydantic import BaseModel + +from zenml.exceptions import IllegalOperationError +from zenml.models.base_models import ( + BaseRequestModel, + BaseResponseModel, + UserScopedRequestModel, +) +from zenml.models.filter_models import BaseFilterModel +from zenml.models.page_model import Page +from zenml.zen_server.auth import get_auth_context +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import ( + dehydrate_page, + dehydrate_response_model, + get_allowed_resource_ids, + verify_permission, + verify_permission_for_model, +) + +AnyRequestModel = TypeVar("AnyRequestModel", bound=BaseRequestModel) +AnyResponseModel = TypeVar("AnyResponseModel", bound=BaseResponseModel) +AnyFilterModel = TypeVar("AnyFilterModel", bound=BaseFilterModel) +AnyUpdateModel = TypeVar("AnyUpdateModel", bound=BaseModel) + + +def verify_permissions_and_create_entity( + request_model: AnyRequestModel, + resource_type: ResourceType, + create_method: Callable[[AnyRequestModel], AnyResponseModel], +) -> AnyResponseModel: + """Verify permissions and create the entity if authorized. + + Args: + request_model: The entity request model. + resource_type: The resource type of the entity to create. + create_method: The method to create the entity. + + Raises: + IllegalOperationError: If the request model has a different owner then + the currently authenticated user. + + Returns: + A model of the created entity. + """ + if isinstance(request_model, UserScopedRequestModel): + auth_context = get_auth_context() + assert auth_context + + if request_model.user != auth_context.user.id: + raise IllegalOperationError( + f"Not allowed to create resource '{resource_type}' for a " + "different user." + ) + + verify_permission(resource_type=resource_type, action=Action.CREATE) + return create_method(request_model) + + +def verify_permissions_and_get_entity( + id: UUID, get_method: Callable[[UUID], AnyResponseModel] +) -> AnyResponseModel: + """Verify permissions and fetch an entity. + + Args: + id: The ID of the entity to fetch. + get_method: The method to fetch the entity. + + Returns: + A model of the fetched entity. + """ + model = get_method(id) + verify_permission_for_model(model, action=Action.READ) + return dehydrate_response_model(model) + + +def verify_permissions_and_list_entities( + filter_model: AnyFilterModel, + resource_type: ResourceType, + list_method: Callable[[AnyFilterModel], Page[AnyResponseModel]], +) -> Page[AnyResponseModel]: + """Verify permissions and list entities. + + Args: + filter_model: The entity filter model. + resource_type: The resource type of the entities to list. + list_method: The method to list the entities. + + Returns: + A page of entity models. + """ + allowed_ids = get_allowed_resource_ids(resource_type=resource_type) + filter_model.set_allowed_ids(allowed_ids) + page = list_method(filter_model) + return dehydrate_page(page) + + +def verify_permissions_and_update_entity( + id: UUID, + update_model: AnyUpdateModel, + get_method: Callable[[UUID], AnyResponseModel], + update_method: Callable[[UUID, AnyUpdateModel], AnyResponseModel], +) -> AnyResponseModel: + """Verify permissions and update an entity. + + Args: + id: The ID of the entity to update. + update_model: The entity update model. + get_method: The method to fetch the entity. + update_method: The method to update the entity. + + Returns: + A model of the updated entity. + """ + model = get_method(id) + verify_permission_for_model(model, action=Action.UPDATE) + updated_model = update_method(id, update_model) + return dehydrate_response_model(updated_model) + + +def verify_permissions_and_delete_entity( + id: UUID, + get_method: Callable[[UUID], AnyResponseModel], + delete_method: Callable[[UUID], None], +) -> None: + """Verify permissions and delete an entity. + + Args: + id: The ID of the entity to delete. + get_method: The method to fetch the entity. + delete_method: The method to delete the entity. + """ + model = get_method(id) + verify_permission_for_model(model, action=Action.DELETE) + delete_method(id) diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index 4538e963a70..57d10c2ae7d 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -37,7 +37,7 @@ class ResourceType(StrEnum): FLAVOR = "flavor" STACK_COMPONENT = "stack_component" PIPELINE = "pipeline" - CODE_REPOSITORY = "code-repository" + CODE_REPOSITORY = "code_repository" MODEL = "model" SERVICE_CONNECTOR = "service_connector" ARTIFACT = "artifact" diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index db3ab81a64d..b13ff45ec99 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -29,25 +29,6 @@ M = TypeVar("M", bound=BaseResponseModel) -def verify_read_permissions_and_dehydrate( - model: M, -) -> M: - """Verify read permissions of the model and dehydrate it if necessary. - - Args: - model: The model for which to verify permissions. - - Returns: - The (potentially) dehydrated model. - """ - if not server_config().rbac_enabled: - return model - - verify_permission_for_model(model=model, action=Action.READ) - - return dehydrate_response_model(model=model) - - def dehydrate_page(page: Page[M]) -> Page[M]: """Dehydrate all items of a page. diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py index 93b750d7448..f4ee5655b58 100644 --- a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -52,6 +52,30 @@ def _convert_to_cloud_resource(resource: Resource) -> str: return resource_string +def _convert_from_cloud_resource(cloud_resource: str) -> Resource: + """Convert a cloud resource to a ZenML server resource. + + Args: + cloud_resource: The cloud resource to convert. + + Raises: + ValueError: If the cloud resource is invalid for this server. + + Returns: + The converted resource. + """ + scope, resource_type_and_id = cloud_resource.rsplit(":", maxsplit=1) + + if scope != f"{SERVER_ID}@{SERVER_SCOPE_IDENTIFIER}": + raise ValueError("Invalid scope for server resource.") + + if "/" in resource_type_and_id: + resource_type, resource_id = resource_type_and_id.split("/") + return Resource(type=resource_type, id=resource_id) + else: + return Resource(type=resource_type_and_id) + + class ZenMLCloudRBACConfiguration(BaseModel): """ZenML Cloud RBAC configuration.""" @@ -137,7 +161,7 @@ def check_permissions( value = response.json() assert isinstance(value, dict) - return value + return {_convert_from_cloud_resource(k): v for k, v in value.items()} def list_allowed_resource_ids( self, user: "UserResponseModel", resource: Resource, action: str diff --git a/src/zenml/zen_server/routers/artifacts_endpoints.py b/src/zenml/zen_server/routers/artifacts_endpoints.py index 6dd289f442a..87d0e025212 100644 --- a/src/zenml/zen_server/routers/artifacts_endpoints.py +++ b/src/zenml/zen_server/routers/artifacts_endpoints.py @@ -30,6 +30,13 @@ from zenml.utils.artifact_utils import load_artifact_visualization from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_create_entity, + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, +) +from zenml.zen_server.rbac.models import ResourceType from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -64,8 +71,10 @@ def list_artifacts( Returns: The artifacts according to query filters. """ - return zen_store().list_artifacts( - artifact_filter_model=artifact_filter_model + return verify_permissions_and_list_entities( + filter_model=artifact_filter_model, + resource_type=ResourceType.ARTIFACT, + list_method=zen_store().list_artifacts, ) @@ -87,7 +96,11 @@ def create_artifact( Returns: The created artifact. """ - return zen_store().create_artifact(artifact) + return verify_permissions_and_create_entity( + request_model=artifact, + resource_type=ResourceType.ARTIFACT, + create_method=zen_store().create_artifact, + ) @router.get( @@ -108,7 +121,9 @@ def get_artifact( Returns: The artifact with the given ID. """ - return zen_store().get_artifact(artifact_id) + return verify_permissions_and_get_entity( + id=artifact_id, get_method=zen_store().get_artifact + ) @router.delete( @@ -125,7 +140,11 @@ def delete_artifact( Args: artifact_id: The ID of the artifact to delete. """ - zen_store().delete_artifact(artifact_id) + verify_permissions_and_delete_entity( + id=artifact_id, + get_method=zen_store().get_artifact, + delete_method=zen_store().delete_artifact, + ) @router.get( @@ -149,7 +168,9 @@ def get_artifact_visualization( The visualization of the artifact. """ store = zen_store() - artifact = store.get_artifact(artifact_id) + artifact = verify_permissions_and_get_entity( + id=artifact_id, get_method=store.get_artifact + ) return load_artifact_visualization( artifact=artifact, index=index, zen_store=store, encode_image=True ) diff --git a/src/zenml/zen_server/routers/code_repositories_endpoints.py b/src/zenml/zen_server/routers/code_repositories_endpoints.py index c0dd24daad2..da2df1d9583 100644 --- a/src/zenml/zen_server/routers/code_repositories_endpoints.py +++ b/src/zenml/zen_server/routers/code_repositories_endpoints.py @@ -25,6 +25,13 @@ from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import ResourceType from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -59,7 +66,11 @@ def list_code_repositories( Returns: Page of code repository objects. """ - return zen_store().list_code_repositories(filter_model=filter_model) + return verify_permissions_and_list_entities( + filter_model=filter_model, + resource_type=ResourceType.CODE_REPOSITORY, + list_method=zen_store().list_code_repositories, + ) @router.get( @@ -80,8 +91,8 @@ def get_code_repository( Returns: A specific code repository object. """ - return zen_store().get_code_repository( - code_repository_id=code_repository_id + return verify_permissions_and_get_entity( + id=code_repository_id, get_method=zen_store().get_code_repository ) @@ -105,8 +116,11 @@ def update_code_repository( Returns: The updated code repository object. """ - return zen_store().update_code_repository( - code_repository_id=code_repository_id, update=update + return verify_permissions_and_update_entity( + id=code_repository_id, + update_model=update, + get_method=zen_store().get_code_repository, + update_method=zen_store().update_code_repository, ) @@ -124,4 +138,8 @@ def delete_code_repository( Args: code_repository_id: The ID of the code repository to delete. """ - zen_store().delete_code_repository(code_repository_id=code_repository_id) + verify_permissions_and_delete_entity( + id=code_repository_id, + get_method=zen_store().get_code_repository, + delete_method=zen_store().delete_code_repository, + ) diff --git a/src/zenml/zen_server/routers/flavors_endpoints.py b/src/zenml/zen_server/routers/flavors_endpoints.py index 182c69cbe62..a2f55fac90c 100644 --- a/src/zenml/zen_server/routers/flavors_endpoints.py +++ b/src/zenml/zen_server/routers/flavors_endpoints.py @@ -18,7 +18,6 @@ from fastapi import APIRouter, Depends, Security from zenml.constants import API, FLAVORS, VERSION_1 -from zenml.exceptions import IllegalOperationError from zenml.models import ( FlavorFilterModel, FlavorRequestModel, @@ -28,6 +27,15 @@ from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_create_entity, + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import verify_permission from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -63,7 +71,11 @@ def list_flavors( Returns: All flavors. """ - return zen_store().list_flavors(flavor_filter_model=flavor_filter_model) + return verify_permissions_and_list_entities( + filter_model=flavor_filter_model, + resource_type=ResourceType.FLAVOR, + list_method=zen_store().list_flavors, + ) @router.get( @@ -84,8 +96,9 @@ def get_flavor( Returns: The requested stack. """ - flavor = zen_store().get_flavor(flavor_id) - return flavor + return verify_permissions_and_get_entity( + id=flavor_id, get_method=zen_store().get_flavor + ) @router.post( @@ -96,32 +109,21 @@ def get_flavor( @handle_exceptions def create_flavor( flavor: FlavorRequestModel, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> FlavorResponseModel: """Creates a stack component flavor. Args: flavor: Stack component flavor to register. - auth_context: Authentication context. Returns: The created stack component flavor. - - Raises: - IllegalOperationError: If the workspace or user specified in the stack - component flavor does not match the current workspace or authenticated - user. """ - if flavor.user != auth_context.user.id: - raise IllegalOperationError( - "Creating flavors for a user other than yourself " - "is not supported." - ) - - created_flavor = zen_store().create_flavor( - flavor=flavor, + return verify_permissions_and_create_entity( + request_model=flavor, + resource_type=ResourceType.FLAVOR, + create_method=zen_store().create_flavor, ) - return created_flavor @router.put( @@ -146,8 +148,11 @@ def update_flavor( Returns: The updated flavor. """ - return zen_store().update_flavor( - flavor_id=flavor_id, flavor_update=flavor_update + return verify_permissions_and_update_entity( + id=flavor_id, + update_model=flavor_update, + get_method=zen_store().get_flavor, + update_method=zen_store().update_flavor, ) @@ -165,7 +170,11 @@ def delete_flavor( Args: flavor_id: ID of the flavor. """ - zen_store().delete_flavor(flavor_id) + verify_permissions_and_delete_entity( + id=flavor_id, + get_method=zen_store().get_flavor, + delete_method=zen_store().delete_flavor, + ) @router.patch( @@ -181,4 +190,5 @@ def sync_flavors( Returns: None if successful. Raises an exception otherwise. """ + verify_permission(resource_type=ResourceType.FLAVOR, action=Action.UPDATE) return zen_store()._sync_flavors() diff --git a/src/zenml/zen_server/routers/pipelines_endpoints.py b/src/zenml/zen_server/routers/pipelines_endpoints.py index b1870e3ae83..25a4e6835f8 100644 --- a/src/zenml/zen_server/routers/pipelines_endpoints.py +++ b/src/zenml/zen_server/routers/pipelines_endpoints.py @@ -28,6 +28,13 @@ from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import ResourceType from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -66,8 +73,10 @@ def list_pipelines( Returns: List of pipeline objects. """ - return zen_store().list_pipelines( - pipeline_filter_model=pipeline_filter_model + return verify_permissions_and_list_entities( + filter_model=pipeline_filter_model, + resource_type=ResourceType.PIPELINE, + list_method=zen_store().list_pipelines, ) @@ -93,7 +102,9 @@ def get_pipeline( Returns: A specific pipeline object. """ - return zen_store().get_pipeline(pipeline_id=pipeline_id) + return verify_permissions_and_get_entity( + id=pipeline_id, get_method=zen_store().get_pipeline + ) @router.put( @@ -116,8 +127,11 @@ def update_pipeline( Returns: The updated pipeline object. """ - return zen_store().update_pipeline( - pipeline_id=pipeline_id, pipeline_update=pipeline_update + return verify_permissions_and_update_entity( + id=pipeline_id, + update_model=pipeline_update, + get_method=zen_store().get_pipeline, + update_method=zen_store().update_pipeline, ) @@ -135,7 +149,11 @@ def delete_pipeline( Args: pipeline_id: ID of the pipeline to delete. """ - zen_store().delete_pipeline(pipeline_id=pipeline_id) + verify_permissions_and_delete_entity( + id=pipeline_id, + get_method=zen_store().get_pipeline, + delete_method=zen_store().delete_pipeline, + ) @router.get( @@ -180,4 +198,7 @@ def get_pipeline_spec( Returns: The spec of the pipeline. """ - return zen_store().get_pipeline(pipeline_id).spec + pipeline = verify_permissions_and_get_entity( + id=pipeline_id, get_method=zen_store().get_pipeline + ) + return pipeline.spec diff --git a/src/zenml/zen_server/routers/stack_components_endpoints.py b/src/zenml/zen_server/routers/stack_components_endpoints.py index f8088f9693d..39a73efe334 100644 --- a/src/zenml/zen_server/routers/stack_components_endpoints.py +++ b/src/zenml/zen_server/routers/stack_components_endpoints.py @@ -27,6 +27,14 @@ from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import verify_permission_for_model from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -56,20 +64,21 @@ def list_stack_components( component_filter_model: ComponentFilterModel = Depends( make_dependable(ComponentFilterModel) ), - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> Page[ComponentResponseModel]: """Get a list of all stack components for a specific type. Args: component_filter_model: Filter model used for pagination, sorting, filtering - auth_context: Authentication Context Returns: List of stack components for a specific type. """ - return zen_store().list_stack_components( - component_filter_model=component_filter_model + return verify_permissions_and_list_entities( + filter_model=component_filter_model, + resource_type=ResourceType.STACK_COMPONENT, + list_method=zen_store().list_stack_components, ) @@ -91,7 +100,9 @@ def get_stack_component( Returns: The requested stack component. """ - return zen_store().get_stack_component(component_id) + return verify_permissions_and_get_entity( + id=component_id, get_method=zen_store().get_stack_component + ) @router.put( @@ -114,9 +125,17 @@ def update_stack_component( Returns: Updated stack component. """ - return zen_store().update_stack_component( - component_id=component_id, - component_update=component_update, + if component_update.connector: + service_connector = zen_store().get_service_connector( + component_update.connector + ) + verify_permission_for_model(service_connector, action=Action.READ) + + return verify_permissions_and_update_entity( + id=component_id, + update_model=component_update, + get_method=zen_store().get_stack_component, + update_method=zen_store().update_stack_component, ) @@ -134,7 +153,11 @@ def deregister_stack_component( Args: component_id: ID of the stack component. """ - zen_store().delete_stack_component(component_id) + verify_permissions_and_delete_entity( + id=component_id, + get_method=zen_store().get_stack_component, + delete_method=zen_store().delete_stack_component, + ) @types_router.get( diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index 0e87eac69a9..da2f7317755 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -25,15 +25,14 @@ authorize, ) from zenml.zen_server.exceptions import error_response -from zenml.zen_server.rbac.models import Action, ResourceType -from zenml.zen_server.rbac.utils import ( - batch_verify_permissions_for_models, - dehydrate_page, - dehydrate_response_model, - get_allowed_resource_ids, - verify_permission_for_model, - verify_read_permissions_and_dehydrate, +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, ) +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import batch_verify_permissions_for_models from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -67,10 +66,11 @@ def list_stacks( Returns: All stacks. """ - allowed_ids = get_allowed_resource_ids(resource_type=ResourceType.STACK) - stack_filter_model.set_allowed_ids(allowed_ids) - page = zen_store().list_stacks(stack_filter_model=stack_filter_model) - return dehydrate_page(page) + return verify_permissions_and_list_entities( + filter_model=stack_filter_model, + resource_type=ResourceType.STACK, + list_method=zen_store().list_stacks, + ) @router.get( @@ -95,8 +95,9 @@ def get_stack( Returns: The requested stack. """ - stack = zen_store().get_stack(stack_id) - return verify_read_permissions_and_dehydrate(stack) + return verify_permissions_and_get_entity( + id=stack_id, get_method=zen_store().get_stack + ) @router.put( @@ -119,9 +120,6 @@ def update_stack( Returns: The updated stack. """ - stack = zen_store().get_stack(stack_id) - verify_permission_for_model(stack, action=Action.UPDATE) - if stack_update.components: updated_components = [ zen_store().get_stack_component(id) @@ -133,11 +131,12 @@ def update_stack( updated_components, action=Action.READ ) - updated_stack = zen_store().update_stack( - stack_id=stack_id, - stack_update=stack_update, + return verify_permissions_and_update_entity( + id=stack_id, + update_model=stack_update, + get_method=zen_store().get_stack, + update_method=zen_store().update_stack, ) - return dehydrate_response_model(updated_stack) @router.delete( @@ -154,7 +153,8 @@ def delete_stack( Args: stack_id: Name of the stack. """ - stack = zen_store().get_stack(stack_id) - verify_permission_for_model(stack, action=Action.DELETE) - - zen_store().delete_stack(stack_id) + verify_permissions_and_delete_entity( + id=stack_id, + get_method=zen_store().get_stack, + delete_method=zen_store().delete_stack, + ) From 822fd22759b1499506e6b510547e07d638b69b0e Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 26 Oct 2023 16:22:02 +0200 Subject: [PATCH 031/103] More endpoints implemented --- .../routers/workspaces_endpoints.py | 236 ++++++++++-------- 1 file changed, 131 insertions(+), 105 deletions(-) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 1c705c9e66c..037bc00374c 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -85,13 +85,20 @@ StackRequestModel, StackResponseModel, WorkspaceFilterModel, - WorkspaceRequestModel, WorkspaceResponseModel, - WorkspaceUpdateModel, ) from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_create_entity, + verify_permissions_and_list_entities, +) +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import ( + batch_verify_permissions_for_models, + verify_permission_for_model, +) from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -131,27 +138,27 @@ def list_workspaces( ) -@router.post( - WORKSPACES, - response_model=WorkspaceResponseModel, - responses={401: error_response, 409: error_response, 422: error_response}, -) -@handle_exceptions -def create_workspace( - workspace: WorkspaceRequestModel, - _: AuthContext = Security(authorize), -) -> WorkspaceResponseModel: - """Creates a workspace based on the requestBody. +# @router.post( +# WORKSPACES, +# response_model=WorkspaceResponseModel, +# responses={401: error_response, 409: error_response, 422: error_response}, +# ) +# @handle_exceptions +# def create_workspace( +# workspace: WorkspaceRequestModel, +# _: AuthContext = Security(authorize), +# ) -> WorkspaceResponseModel: +# """Creates a workspace based on the requestBody. - # noqa: DAR401 +# # noqa: DAR401 - Args: - workspace: Workspace to create. +# Args: +# workspace: Workspace to create. - Returns: - The created workspace. - """ - return zen_store().create_workspace(workspace=workspace) +# Returns: +# The created workspace. +# """ +# return zen_store().create_workspace(workspace=workspace) @router.get( @@ -177,49 +184,49 @@ def get_workspace( return zen_store().get_workspace(workspace_name_or_id=workspace_name_or_id) -@router.put( - WORKSPACES + "/{workspace_name_or_id}", - response_model=WorkspaceResponseModel, - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def update_workspace( - workspace_name_or_id: UUID, - workspace_update: WorkspaceUpdateModel, - _: AuthContext = Security(authorize), -) -> WorkspaceResponseModel: - """Get a workspace for given name. - - # noqa: DAR401 - - Args: - workspace_name_or_id: Name or ID of the workspace to update. - workspace_update: the workspace to use to update - - Returns: - The updated workspace. - """ - return zen_store().update_workspace( - workspace_id=workspace_name_or_id, - workspace_update=workspace_update, - ) - - -@router.delete( - WORKSPACES + "/{workspace_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_workspace( - workspace_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize), -) -> None: - """Deletes a workspace. - - Args: - workspace_name_or_id: Name or ID of the workspace. - """ - zen_store().delete_workspace(workspace_name_or_id=workspace_name_or_id) +# @router.put( +# WORKSPACES + "/{workspace_name_or_id}", +# response_model=WorkspaceResponseModel, +# responses={401: error_response, 404: error_response, 422: error_response}, +# ) +# @handle_exceptions +# def update_workspace( +# workspace_name_or_id: UUID, +# workspace_update: WorkspaceUpdateModel, +# _: AuthContext = Security(authorize), +# ) -> WorkspaceResponseModel: +# """Get a workspace for given name. + +# # noqa: DAR401 + +# Args: +# workspace_name_or_id: Name or ID of the workspace to update. +# workspace_update: the workspace to use to update + +# Returns: +# The updated workspace. +# """ +# return zen_store().update_workspace( +# workspace_id=workspace_name_or_id, +# workspace_update=workspace_update, +# ) + + +# @router.delete( +# WORKSPACES + "/{workspace_name_or_id}", +# responses={401: error_response, 404: error_response, 422: error_response}, +# ) +# @handle_exceptions +# def delete_workspace( +# workspace_name_or_id: Union[str, UUID], +# _: AuthContext = Security(authorize), +# ) -> None: +# """Deletes a workspace. + +# Args: +# workspace_name_or_id: Name or ID of the workspace. +# """ +# zen_store().delete_workspace(workspace_name_or_id=workspace_name_or_id) @router.get( @@ -233,7 +240,7 @@ def list_workspace_stacks( stack_filter_model: StackFilterModel = Depends( make_dependable(StackFilterModel) ), - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> Page[StackResponseModel]: """Get stacks that are part of a specific workspace for the user. @@ -242,14 +249,18 @@ def list_workspace_stacks( Args: workspace_name_or_id: Name or ID of the workspace. stack_filter_model: Filter model used for pagination, sorting, filtering - auth_context: Authentication Context Returns: All stacks part of the specified workspace. """ workspace = zen_store().get_workspace(workspace_name_or_id) stack_filter_model.set_scope_workspace(workspace.id) - return zen_store().list_stacks(stack_filter_model=stack_filter_model) + + return verify_permissions_and_list_entities( + filter_model=stack_filter_model, + resource_type=ResourceType.STACK, + list_method=zen_store().list_stacks, + ) @router.post( @@ -261,14 +272,13 @@ def list_workspace_stacks( def create_stack( workspace_name_or_id: Union[str, UUID], stack: StackRequestModel, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> StackResponseModel: """Creates a stack for a particular workspace. Args: workspace_name_or_id: Name or ID of the workspace. stack: Stack to register. - auth_context: The authentication context. Returns: The created stack. @@ -285,13 +295,21 @@ def create_stack( f"of this endpoint `{workspace_name_or_id}` is " f"not supported." ) - if stack.user != auth_context.user.id: - raise IllegalOperationError( - "Creating stacks for a user other than yourself " - "is not supported." - ) - return zen_store().create_stack(stack=stack) + if stack.components: + components = [ + zen_store().get_stack_component(id) + for ids in stack.components.values() + for id in ids + ] + + batch_verify_permissions_for_models(components, action=Action.READ) + + return verify_permissions_and_create_entity( + request_model=stack, + resource_type=ResourceType.STACK, + create_method=zen_store().create_stack, + ) @router.get( @@ -305,7 +323,7 @@ def list_workspace_stack_components( component_filter_model: ComponentFilterModel = Depends( make_dependable(ComponentFilterModel) ), - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> Page[ComponentResponseModel]: """List stack components that are part of a specific workspace. @@ -315,15 +333,17 @@ def list_workspace_stack_components( workspace_name_or_id: Name or ID of the workspace. component_filter_model: Filter model used for pagination, sorting, filtering - auth_context: Authentication Context Returns: All stack components part of the specified workspace. """ workspace = zen_store().get_workspace(workspace_name_or_id) component_filter_model.set_scope_workspace(workspace.id) - return zen_store().list_stack_components( - component_filter_model=component_filter_model + + return verify_permissions_and_list_entities( + filter_model=component_filter_model, + resource_type=ResourceType.STACK_COMPONENT, + list_method=zen_store().list_stack_components, ) @@ -336,14 +356,13 @@ def list_workspace_stack_components( def create_stack_component( workspace_name_or_id: Union[str, UUID], component: ComponentRequestModel, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> ComponentResponseModel: """Creates a stack component. Args: workspace_name_or_id: Name or ID of the workspace. component: Stack component to register. - auth_context: Authentication context. Returns: The created stack component. @@ -361,16 +380,18 @@ def create_stack_component( f"of this endpoint `{workspace_name_or_id}` is " f"not supported." ) - if component.user != auth_context.user.id: - raise IllegalOperationError( - "Creating components for a user other than yourself " - "is not supported." - ) - # TODO: [server] if possible it should validate here that the configuration - # conforms to the flavor + if component.connector: + service_connector = zen_store().get_service_connector( + component.connector + ) + verify_permission_for_model(service_connector, action=Action.READ) - return zen_store().create_stack_component(component=component) + return verify_permissions_and_create_entity( + request_model=component, + resource_type=ResourceType.STACK_COMPONENT, + create_method=zen_store().create_stack_component, + ) @router.get( @@ -400,8 +421,11 @@ def list_workspace_pipelines( """ workspace = zen_store().get_workspace(workspace_name_or_id) pipeline_filter_model.set_scope_workspace(workspace.id) - return zen_store().list_pipelines( - pipeline_filter_model=pipeline_filter_model + + return verify_permissions_and_list_entities( + filter_model=pipeline_filter_model, + resource_type=ResourceType.PIPELINE, + list_method=zen_store().list_pipelines, ) @@ -438,13 +462,12 @@ def create_pipeline( f"of this endpoint `{workspace_name_or_id}` is " f"not supported." ) - if pipeline.user != auth_context.user.id: - raise IllegalOperationError( - "Creating pipelines for a user other than yourself " - "is not supported." - ) - return zen_store().create_pipeline(pipeline=pipeline) + return verify_permissions_and_create_entity( + request_model=pipeline, + resource_type=ResourceType.PIPELINE, + create_method=zen_store().create_pipeline, + ) @router.get( @@ -865,7 +888,12 @@ def list_workspace_code_repositories( """ workspace = zen_store().get_workspace(workspace_name_or_id) filter_model.set_scope_workspace(workspace.id) - return zen_store().list_code_repositories(filter_model=filter_model) + + return verify_permissions_and_list_entities( + filter_model=filter_model, + resource_type=ResourceType.CODE_REPOSITORY, + list_method=zen_store().list_code_repositories, + ) @router.post( @@ -877,14 +905,13 @@ def list_workspace_code_repositories( def create_code_repository( workspace_name_or_id: Union[str, UUID], code_repository: CodeRepositoryRequestModel, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> CodeRepositoryResponseModel: """Creates a code repository. Args: workspace_name_or_id: Name or ID of the workspace. code_repository: Code repository to create. - auth_context: Authentication context. Returns: The created code repository. @@ -902,13 +929,12 @@ def create_code_repository( f"of this endpoint `{workspace_name_or_id}` is " f"not supported." ) - if code_repository.user != auth_context.user.id: - raise IllegalOperationError( - "Creating code repositories for a user other than yourself " - "is not supported." - ) - return zen_store().create_code_repository(code_repository=code_repository) + return verify_permissions_and_create_entity( + request_model=code_repository, + resource_type=ResourceType.CODE_REPOSITORY, + create_method=zen_store().create_code_repository, + ) @router.get( From 5d72cc60eb1d124daf6d633350d02e0498efb428 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 27 Oct 2023 15:30:31 +0200 Subject: [PATCH 032/103] More endpoints implemented --- src/zenml/models/filter_models.py | 9 +-- src/zenml/zen_server/rbac/endpoint_utils.py | 21 +++--- src/zenml/zen_server/rbac/models.py | 4 ++ src/zenml/zen_server/rbac/utils.py | 4 +- .../zen_server/routers/models_endpoints.py | 31 +++++++-- .../routers/service_connectors_endpoints.py | 68 ++++++++++++++++--- .../routers/workspaces_endpoints.py | 43 ++++++------ 7 files changed, 129 insertions(+), 51 deletions(-) diff --git a/src/zenml/models/filter_models.py b/src/zenml/models/filter_models.py index ba2ae3e049d..4e418cc4547 100644 --- a/src/zenml/models/filter_models.py +++ b/src/zenml/models/filter_models.py @@ -22,6 +22,7 @@ Dict, List, Optional, + Set, Tuple, Type, TypeVar, @@ -299,7 +300,7 @@ class BaseFilterModel(BaseModel): default=None, description="Updated" ) - _allowed_ids: Optional[List[UUID]] = None + _allowed_ids: Optional[Set[UUID]] = None @validator("sort_by", pre=True) def validate_sort_by(cls, v: str) -> str: @@ -392,12 +393,12 @@ def sorting_params(self) -> Tuple[str, SorterOps]: return column, operator - def set_allowed_ids(self, allowed_ids: Optional[List[UUID]]) -> None: + def set_allowed_ids(self, allowed_ids: Optional[Set[UUID]]) -> None: """Set allowed IDs for the query. Args: - allowed_ids: List of IDs to limit the query to. If given, the - remaining filters will be applied to entities within this list + allowed_ids: Set of IDs to limit the query to. If given, the + remaining filters will be applied to entities within this set only. If `None`, the remaining filters will applied to all entries in the table. """ diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py index 081f1717893..00af2fabba2 100644 --- a/src/zenml/zen_server/rbac/endpoint_utils.py +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -1,5 +1,5 @@ """High-level helper functions to write endpoints with RBAC.""" -from typing import Callable, TypeVar +from typing import Callable, TypeVar, Union from uuid import UUID from pydantic import BaseModel @@ -26,6 +26,7 @@ AnyResponseModel = TypeVar("AnyResponseModel", bound=BaseResponseModel) AnyFilterModel = TypeVar("AnyFilterModel", bound=BaseFilterModel) AnyUpdateModel = TypeVar("AnyUpdateModel", bound=BaseModel) +UUIDOrStr = TypeVar("UUIDOrStr", UUID, Union[UUID, str]) def verify_permissions_and_create_entity( @@ -62,7 +63,7 @@ def verify_permissions_and_create_entity( def verify_permissions_and_get_entity( - id: UUID, get_method: Callable[[UUID], AnyResponseModel] + id: UUIDOrStr, get_method: Callable[[UUIDOrStr], AnyResponseModel] ) -> AnyResponseModel: """Verify permissions and fetch an entity. @@ -100,10 +101,10 @@ def verify_permissions_and_list_entities( def verify_permissions_and_update_entity( - id: UUID, + id: UUIDOrStr, update_model: AnyUpdateModel, - get_method: Callable[[UUID], AnyResponseModel], - update_method: Callable[[UUID, AnyUpdateModel], AnyResponseModel], + get_method: Callable[[UUIDOrStr], AnyResponseModel], + update_method: Callable[[UUIDOrStr, AnyUpdateModel], AnyResponseModel], ) -> AnyResponseModel: """Verify permissions and update an entity. @@ -118,14 +119,14 @@ def verify_permissions_and_update_entity( """ model = get_method(id) verify_permission_for_model(model, action=Action.UPDATE) - updated_model = update_method(id, update_model) + updated_model = update_method(model.id, update_model) return dehydrate_response_model(updated_model) def verify_permissions_and_delete_entity( - id: UUID, - get_method: Callable[[UUID], AnyResponseModel], - delete_method: Callable[[UUID], None], + id: UUIDOrStr, + get_method: Callable[[UUIDOrStr], AnyResponseModel], + delete_method: Callable[[UUIDOrStr], None], ) -> None: """Verify permissions and delete an entity. @@ -136,4 +137,4 @@ def verify_permissions_and_delete_entity( """ model = get_method(id) verify_permission_for_model(model, action=Action.DELETE) - delete_method(id) + delete_method(model.id) diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index 57d10c2ae7d..aab752d7e14 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -29,6 +29,10 @@ class Action(StrEnum): UPDATE = "update" DELETE = "delete" + # Service connectors + CLIENT = "client" # TODO: rename + READ_SECRET_VALUE = "read_secret_value" + class ResourceType(StrEnum): """Resource types of the server API.""" diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index b13ff45ec99..53842d6a645 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -289,7 +289,7 @@ def verify_permission( def get_allowed_resource_ids( resource_type: str, action: str = Action.READ, -) -> Optional[List[UUID]]: +) -> Optional[Set[UUID]]: """Get all resource IDs of a resource type that a user can access. Args: @@ -318,7 +318,7 @@ def get_allowed_resource_ids( if has_full_resource_access: return None - return [UUID(id) for id in allowed_ids] + return {UUID(id) for id in allowed_ids} def get_resource_for_model(model: "BaseResponseModel") -> Optional[Resource]: diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 9699e8f29e0..5d99580d837 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -43,6 +43,13 @@ from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import ResourceType from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -82,8 +89,10 @@ def list_models( Returns: The models according to query filters. """ - return zen_store().list_models( - model_filter_model=model_filter_model, + return verify_permissions_and_list_entities( + filter_model=model_filter_model, + resource_type=ResourceType.MODEL, + list_method=zen_store().list_models, ) @@ -105,7 +114,9 @@ def get_model( Returns: The model with the given name or ID. """ - return zen_store().get_model(model_name_or_id) + return verify_permissions_and_get_entity( + id=model_name_or_id, get_method=zen_store().get_model + ) @router.put( @@ -128,9 +139,11 @@ def update_model( Returns: The updated model. """ - return zen_store().update_model( - model_id=model_id, - model_update=model_update, + return verify_permissions_and_update_entity( + id=model_id, + update_model=model_update, + get_method=zen_store().get_model, + update_method=zen_store().update_model, ) @@ -148,7 +161,11 @@ def delete_model( Args: model_name_or_id: The name or ID of the model to delete. """ - zen_store().delete_model(model_name_or_id) + verify_permissions_and_delete_entity( + id=model_name_or_id, + get_method=zen_store().get_model, + delete_method=zen_store().delete_model, + ) ################# diff --git a/src/zenml/zen_server/routers/service_connectors_endpoints.py b/src/zenml/zen_server/routers/service_connectors_endpoints.py index b739427c99c..4eac0f10200 100644 --- a/src/zenml/zen_server/routers/service_connectors_endpoints.py +++ b/src/zenml/zen_server/routers/service_connectors_endpoints.py @@ -36,6 +36,18 @@ from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_delete_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import ( + get_allowed_resource_ids, + has_permissions_for_model, + verify_permission, + verify_permission_for_model, +) from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -78,14 +90,30 @@ def list_service_connectors( Returns: Page with list of service connectors for a specific type. """ - connectors = zen_store().list_service_connectors( - filter_model=connector_filter_model + connectors = verify_permissions_and_list_entities( + filter_model=connector_filter_model, + resource_type=ResourceType.SERVICE_CONNECTOR, + list_method=zen_store().list_service_connectors, ) if expand_secrets: + # This will be `None` if the user is allowed to read secret values + # for all service connectors + allowed_ids = get_allowed_resource_ids( + resource_type=ResourceType.SERVICE_CONNECTOR, + action=Action.READ_SECRET_VALUE, + ) + for connector in connectors.items: if not connector.secret_id: continue + + if allowed_ids and connector.id not in allowed_ids: + # The user is not allowed to read secret values for this + # connector. We don't raise an exception here but don't include + # the secret values + continue + secret = zen_store().get_secret(secret_id=connector.secret_id) # Update the connector configuration with the secret. @@ -115,8 +143,15 @@ def get_service_connector( The requested service connector. """ connector = zen_store().get_service_connector(connector_id) - - if expand_secrets and connector.secret_id: + verify_permission_for_model(connector, action=Action.READ) + + if ( + expand_secrets + and connector.secret_id + and has_permissions_for_model( + connector, action=Action.READ_SECRET_VALUE + ) + ): secret = zen_store().get_secret(secret_id=connector.secret_id) # Update the connector configuration with the secret. @@ -145,9 +180,11 @@ def update_service_connector( Returns: Updated service connector. """ - return zen_store().update_service_connector( - service_connector_id=connector_id, - update=connector_update, + return verify_permissions_and_update_entity( + id=connector_id, + update_model=connector_update, + get_method=zen_store().get_service_connector, + update_method=zen_store().update_service_connector, ) @@ -165,7 +202,11 @@ def delete_service_connector( Args: connector_id: ID of the service connector. """ - zen_store().delete_service_connector(connector_id) + verify_permissions_and_delete_entity( + id=connector_id, + get_method=zen_store().get_service_connector, + delete_method=zen_store().delete_service_connector, + ) @router.post( @@ -194,6 +235,10 @@ def validate_and_verify_service_connector_config( The list of resources that the service connector configuration has access to. """ + verify_permission( + resource_type=ResourceType.SERVICE_CONNECTOR, action=Action.CREATE + ) + return zen_store().verify_service_connector_config( service_connector=connector, list_resources=list_resources, @@ -231,6 +276,9 @@ def validate_and_verify_service_connector( The list of resources that the service connector has access to, scoped to the supplied resource type and ID, if provided. """ + connector = zen_store().get_service_connector(connector_id) + verify_permission_for_model(model=connector, action=Action.READ) + return zen_store().verify_service_connector( service_connector_id=connector_id, resource_type=resource_type, @@ -266,6 +314,10 @@ def get_service_connector_client( A service connector client that can be used to access the given resource. """ + connector = zen_store().get_service_connector(connector_id) + verify_permission_for_model(model=connector, action=Action.READ) + verify_permission_for_model(model=connector, action=Action.CLIENT) + return zen_store().get_service_connector_client( service_connector_id=connector_id, resource_type=resource_type, diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 037bc00374c..83fd7db5f9c 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -980,7 +980,7 @@ def list_workspace_service_connectors( connector_filter_model: ServiceConnectorFilterModel = Depends( make_dependable(ServiceConnectorFilterModel) ), - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> Page[ServiceConnectorResponseModel]: """List service connectors that are part of a specific workspace. @@ -990,15 +990,17 @@ def list_workspace_service_connectors( workspace_name_or_id: Name or ID of the workspace. connector_filter_model: Filter model used for pagination, sorting, filtering - auth_context: Authentication Context Returns: All service connectors part of the specified workspace. """ workspace = zen_store().get_workspace(workspace_name_or_id) connector_filter_model.set_scope_workspace(workspace.id) - return zen_store().list_service_connectors( - filter_model=connector_filter_model + + return verify_permissions_and_list_entities( + filter_model=connector_filter_model, + resource_type=ResourceType.SERVICE_CONNECTOR, + list_method=zen_store().list_service_connectors, ) @@ -1036,13 +1038,12 @@ def create_service_connector( f"of this endpoint `{workspace_name_or_id}` is " f"not supported." ) - if connector.user != auth_context.user.id: - raise IllegalOperationError( - "Creating connectors for a user other than yourself " - "is not supported." - ) - return zen_store().create_service_connector(service_connector=connector) + return verify_permissions_and_create_entity( + request_model=connector, + resource_type=ResourceType.SERVICE_CONNECTOR, + create_method=zen_store().create_service_connector, + ) @router.get( @@ -1092,14 +1093,13 @@ def list_service_connector_resources( def create_model( workspace_name_or_id: Union[str, UUID], model: ModelRequestModel, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> ModelResponseModel: """Create a new model. Args: workspace_name_or_id: Name or ID of the workspace. model: The model to create. - auth_context: Authentication context. Returns: The created model. @@ -1117,12 +1117,12 @@ def create_model( f"of this endpoint `{workspace_name_or_id}` is " f"not supported." ) - if model.user != auth_context.user.id: - raise IllegalOperationError( - "Creating models for a user other than yourself " - "is not supported." - ) - return zen_store().create_model(model) + + return verify_permissions_and_create_entity( + request_model=model, + resource_type=ResourceType.MODEL, + create_method=zen_store().create_model, + ) @router.get( @@ -1151,8 +1151,11 @@ def list_workspace_models( """ workspace_id = zen_store().get_workspace(workspace_name_or_id).id model_filter_model.set_scope_workspace(workspace_id) - return zen_store().list_models( - model_filter_model=model_filter_model, + + return verify_permissions_and_list_entities( + filter_model=model_filter_model, + resource_type=ResourceType.MODEL, + list_method=zen_store().list_models, ) From 6aac95785395de6e6ab856721b0bfa08e95afe35 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 27 Oct 2023 15:52:04 +0200 Subject: [PATCH 033/103] Secrets endpoints --- .../zen_server/routers/secrets_endpoints.py | 56 ++++++++++++++++--- 1 file changed, 47 insertions(+), 9 deletions(-) diff --git a/src/zenml/zen_server/routers/secrets_endpoints.py b/src/zenml/zen_server/routers/secrets_endpoints.py index 694d4122052..e4a9e611b98 100644 --- a/src/zenml/zen_server/routers/secrets_endpoints.py +++ b/src/zenml/zen_server/routers/secrets_endpoints.py @@ -26,6 +26,17 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import ( + get_allowed_resource_ids, + has_permissions_for_model, +) from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -60,9 +71,25 @@ def list_secrets( Returns: List of secret objects. """ - # TODO: we should probably have separate permissions here for reading the - # secret and its content - return zen_store().list_secrets(secret_filter_model=secret_filter_model) + secrets = verify_permissions_and_list_entities( + filter_model=secret_filter_model, + resource_type=ResourceType.SECRET, + list_method=zen_store().list_secrets, + ) + + # This will be `None` if the user is allowed to read secret values + # for all secrets + allowed_ids = get_allowed_resource_ids( + resource_type=ResourceType.SECRET, + action=Action.READ_SECRET_VALUE, + ) + + if allowed_ids is not None: + for secret in secrets.items: + if secret.id not in allowed_ids: + secret.remove_secrets() + + return secrets @router.get( @@ -83,9 +110,13 @@ def get_secret( Returns: A specific secret object. """ - # TODO: we should probably have separate permissions here for reading the - # secret and its content - return zen_store().get_secret(secret_id=secret_id) + secret = verify_permissions_and_get_entity( + id=secret_id, get_method=zen_store().get_secret + ) + if not has_permissions_for_model(secret, action=Action.READ_SECRET_VALUE): + secret.remove_secrets() + + return secret @router.put( @@ -120,8 +151,11 @@ def update_secret( if key not in secret_update.values: secret_update.values[key] = None - return zen_store().update_secret( - secret_id=secret_id, secret_update=secret_update + return verify_permissions_and_update_entity( + id=secret_id, + update_model=secret_update, + get_method=zen_store().get_secret, + update_method=zen_store().update_secret, ) @@ -139,4 +173,8 @@ def delete_secret( Args: secret_id: ID of the secret to delete. """ - zen_store().delete_secret(secret_id=secret_id) + verify_permissions_and_delete_entity( + id=secret_id, + get_method=zen_store().get_secret, + delete_method=zen_store().delete_secret, + ) From af7b0ca088b625c81befb56b7ec23b23fc245d33 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 30 Oct 2023 12:03:03 +0100 Subject: [PATCH 034/103] Custom cache key for local artifact store --- src/zenml/artifact_stores/base_artifact_store.py | 12 ++++++++++++ src/zenml/artifact_stores/local_artifact_store.py | 12 ++++++++++++ src/zenml/orchestrators/cache_utils.py | 3 +++ 3 files changed, 27 insertions(+) diff --git a/src/zenml/artifact_stores/base_artifact_store.py b/src/zenml/artifact_stores/base_artifact_store.py index d23f0a02dd1..d5a0c61e763 100644 --- a/src/zenml/artifact_stores/base_artifact_store.py +++ b/src/zenml/artifact_stores/base_artifact_store.py @@ -186,6 +186,18 @@ def path(self) -> str: """ return self.config.path + @property + def custom_cache_key(self) -> Optional[bytes]: + """Custom cache key. + + Any artifact store can override this property in case they need + additional control over the caching behavior. + + Returns: + Custom cache key. + """ + return None + # --- User interface --- @abstractmethod def open(self, name: PathType, mode: str = "r") -> Any: diff --git a/src/zenml/artifact_stores/local_artifact_store.py b/src/zenml/artifact_stores/local_artifact_store.py index 3d600c30273..690e4c6c326 100644 --- a/src/zenml/artifact_stores/local_artifact_store.py +++ b/src/zenml/artifact_stores/local_artifact_store.py @@ -133,6 +133,18 @@ def local_path(self) -> Optional[str]: """ return self.path + @property + def custom_cache_key(self) -> Optional[bytes]: + """Custom cache key. + + The client ID is returned here to invalidate caching when using the same + local artifact store on multiple client machines. + + Returns: + Custom cache key. + """ + return GlobalConfiguration().user_id.bytes + class LocalArtifactStoreFlavor(BaseArtifactStoreFlavor): """Class for the `LocalArtifactStoreFlavor`.""" diff --git a/src/zenml/orchestrators/cache_utils.py b/src/zenml/orchestrators/cache_utils.py index b8b010529c1..1c6d952f208 100644 --- a/src/zenml/orchestrators/cache_utils.py +++ b/src/zenml/orchestrators/cache_utils.py @@ -69,6 +69,9 @@ def generate_cache_key( hash_.update(artifact_store.id.bytes) hash_.update(artifact_store.path.encode()) + if artifact_store.custom_cache_key: + hash_.update(artifact_store.custom_cache_key) + # Step source. This currently only uses the string representation of the # source (e.g. my_module.step_class) instead of the full source to keep # the caching behavior of previous versions and to not invalidate caching From ccfbd3b3e0646f2ea61a9aebba48030d39f9be18 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 30 Oct 2023 17:24:33 +0100 Subject: [PATCH 035/103] WIP single default stack --- src/zenml/client.py | 12 --- src/zenml/models/base_models.py | 17 ++++ src/zenml/zen_server/rbac/utils.py | 15 ++-- src/zenml/zen_stores/base_zen_store.py | 78 +++++++------------ .../7500f434b71c_remove_shared_columns.py | 58 ++++++++++++-- src/zenml/zen_stores/rest_zen_store.py | 11 +++ src/zenml/zen_stores/sql_zen_store.py | 34 ++------ 7 files changed, 125 insertions(+), 100 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index 540a0dc189e..e43e5915074 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -1026,13 +1026,6 @@ def get_stack( Returns: The stack. """ - if name_id_or_prefix == "default": - name_id_or_prefix = ( - self.zen_store._get_default_stack_and_component_name( - user_id=self.active_user.id - ) - ) - if name_id_or_prefix is not None: return self._get_entity_by_id_or_name_or_prefix( get_method=self.zen_store.get_stack, @@ -1401,11 +1394,6 @@ def get_stack_component( KeyError: If no name_id_or_prefix is provided and no such component is part of the active stack. """ - if name_id_or_prefix == "default": - self.zen_store._get_default_stack_and_component_name( - user_id=self.active_user.id - ) - # If no `name_id_or_prefix` provided, try to get the active component. if not name_id_or_prefix: components = self.active_stack_model.components.get( diff --git a/src/zenml/models/base_models.py b/src/zenml/models/base_models.py index 7ab2f9b7d23..534296a125c 100644 --- a/src/zenml/models/base_models.py +++ b/src/zenml/models/base_models.py @@ -269,3 +269,20 @@ def update_model(_cls: Type[T]) -> Type[T]: value.allow_none = True return _cls + + +def internal_model(_cls: Type[T]) -> Type[T]: + """Convert a request model to an internal model. + + Args: + _cls: The class to decorate + + Returns: + The decorated class. + """ + if user_field := _cls.__fields__.get("user", None): + user_field.required = False + user_field.allow_none = True + user_field.default = None + + return _cls diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 53842d6a645..29f3ebcea6a 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -318,6 +318,7 @@ def get_allowed_resource_ids( if has_full_resource_access: return None + # TODO: this does not account for ownership right now return {UUID(id) for id in allowed_ids} @@ -390,13 +391,13 @@ def is_owned_by_authenticated_user(model: "BaseResponseModel") -> bool: auth_context = get_auth_context() assert auth_context - if ( - isinstance(model, UserScopedResponseModel) - and model.user - and model.user.id == auth_context.user.id - ): - # User is the owner of the model - return True + if isinstance(model, UserScopedResponseModel): + if model.user: + return model.user.id == auth_context.user.id + else: + # The model is server-owned and for RBAC purposes we consider + # every user to be the owner of it + return True return False diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index 1b477d7e4f9..95c3cdd1837 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -80,7 +80,7 @@ DEFAULT_USERNAME = "default" DEFAULT_PASSWORD = "" DEFAULT_WORKSPACE_NAME = "default" -DEFAULT_STACK_AND_COMPONENT_NAME_PREFIX = "default" +DEFAULT_STACK_AND_COMPONENT_NAME = "default" @make_proxy_class(SecretsStoreInterface, "_secrets_store") @@ -299,17 +299,16 @@ def _initialize_database(self) -> None: default_workspace = self._create_default_workspace() config = ServerConfiguration.get_server_config() - # If the auth scheme is external, don't create the default user and - # stack + # If the auth scheme is external, don't create the default user if config.auth_scheme != AuthScheme.EXTERNAL: try: - default_user = self._default_user + _ = self._default_user except KeyError: - default_user = self._create_default_user() - self._get_or_create_default_stack( - workspace=default_workspace, - user_id=default_user.id, - ) + self._create_default_user() + + self._get_or_create_default_stack( + workspace=default_workspace, + ) @property def url(self) -> str: @@ -454,17 +453,14 @@ def is_local_store(self) -> bool: def _get_or_create_default_stack( self, workspace: "WorkspaceResponseModel", - user_id: Optional[UUID] = None, ) -> "StackResponseModel": try: return self._get_default_stack( workspace_id=workspace.id, - user_id=user_id or self.get_user().id, ) except KeyError: return self._create_default_stack( workspace_id=workspace.id, - user_id=user_id or self.get_user().id, ) def _get_or_create_default_workspace(self) -> "WorkspaceResponseModel": @@ -516,7 +512,6 @@ def _trigger_event(self, event: StoreEvent, **kwargs: Any) -> None: def _create_default_stack( self, workspace_id: UUID, - user_id: UUID, ) -> StackResponseModel: """Create the default stack components and stack. @@ -526,28 +521,28 @@ def _create_default_stack( Args: workspace_id: ID of the workspace to which the stack belongs. - user_id: ID of the user that owns the stack. Returns: The model of the created default stack. """ with analytics_disabler(): workspace = self.get_workspace(workspace_name_or_id=workspace_id) - user = self.get_user(user_name_or_id=user_id) logger.info( - f"Creating default stack for user '{user.name}' in workspace " - f"{workspace.name}..." + f"Creating default stack in workspace {workspace.name}..." ) + from zenml.models.base_models import internal_model - name = self._get_default_stack_and_component_name(user_id=user_id) + @internal_model + class InternalComponentRequestModel(ComponentRequestModel): + pass # Register the default orchestrator orchestrator = self.create_stack_component( - component=ComponentRequestModel( - user=user.id, + component=InternalComponentRequestModel( + user=None, workspace=workspace.id, - name=name, + name=DEFAULT_STACK_AND_COMPONENT_NAME, type=StackComponentType.ORCHESTRATOR, flavor="local", configuration={}, @@ -556,10 +551,10 @@ def _create_default_stack( # Register the default artifact store artifact_store = self.create_stack_component( - component=ComponentRequestModel( - user=user.id, + component=InternalComponentRequestModel( + user=None, workspace=workspace.id, - name=name, + name=DEFAULT_STACK_AND_COMPONENT_NAME, type=StackComponentType.ARTIFACT_STORE, flavor="local", configuration={}, @@ -569,57 +564,44 @@ def _create_default_stack( components = { c.type: [c.id] for c in [orchestrator, artifact_store] } + + @internal_model + class InternalStackRequestModel(StackRequestModel): + pass + # Register the default stack - stack = StackRequestModel( - name=name, + stack = InternalStackRequestModel( + name=DEFAULT_STACK_AND_COMPONENT_NAME, components=components, workspace=workspace.id, - user=user.id, + user=None, ) return self.create_stack(stack=stack) - def _get_default_stack_and_component_name(self, user_id: UUID) -> str: - """Get the name for the default stack and its components. - - Args: - user_id: ID of the user to which the default stack belongs. - - Returns: - The default stack/component name. - """ - return f"{DEFAULT_STACK_AND_COMPONENT_NAME_PREFIX}-{user_id}" - def _get_default_stack( self, workspace_id: UUID, - user_id: UUID, ) -> StackResponseModel: """Get the default stack for a user in a workspace. Args: workspace_id: ID of the workspace. - user_id: ID of the user. Returns: - The default stack in the workspace owned by the supplied user. + The default stack in the workspace. Raises: KeyError: if the workspace or default stack doesn't exist. """ - stack_name = self._get_default_stack_and_component_name( - user_id=user_id - ) default_stacks = self.list_stacks( StackFilterModel( workspace_id=workspace_id, - user_id=user_id, - name=stack_name, + name=DEFAULT_STACK_AND_COMPONENT_NAME, ) ) if default_stacks.total == 0: raise KeyError( - f"No default stack found for user {str(user_id)} in " - f"workspace {str(workspace_id)}" + f"No default stack found in workspace {workspace_id}." ) return default_stacks.items[0] diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py index 0353e4798af..201af954acb 100644 --- a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -5,6 +5,8 @@ Create Date: 2023-10-16 15:15:34.865337 """ +from uuid import uuid4 + import sqlalchemy as sa from alembic import op @@ -15,7 +17,7 @@ depends_on = None -def _rename_default_entities(table: sa.Table) -> None: +def _rename_old_default_entities(table: sa.Table) -> None: """Include owner id in the name of default entities. Args: @@ -39,14 +41,59 @@ def _rename_default_entities(table: sa.Table) -> None: def resolve_duplicate_names() -> None: """Resolve duplicate names for shareable entities.""" + connection = op.get_bind() + meta = sa.MetaData(bind=op.get_bind()) - meta.reflect(only=("stack", "stack_component", "service_connector")) + meta.reflect( + only=("stack", "stack_component", "service_connector", "workspace") + ) stack_table = sa.Table("stack", meta) stack_component_table = sa.Table("stack_component", meta) - - _rename_default_entities(stack_table) - _rename_default_entities(stack_component_table) + workspace_table = sa.Table("workspace", meta) + + _rename_old_default_entities(stack_table) + _rename_old_default_entities(stack_component_table) + + workspace_query = sa.select(workspace_table.c.id) + + stack_components = [] + stacks = [] + for workspace_id in connection.execute(workspace_query).fetchall(): + artifact_store_id = str(uuid4()).replace("-", "") + default_artifact_store = { + "id": artifact_store_id, + "workspace": workspace_id, + "name": "default", + "type": "artifact_store", + "flavor": "local", + "configuration": {}, + } + orchestrator_id = str(uuid4()).replace("-", "") + default_orchestrator = { + "id": orchestrator_id, + "workspace": workspace_id, + "name": "default", + "type": "orchestrator", + "flavor": "local", + "configuration": {}, + } + + default_stack = { + "id": str(uuid4()).replace("-", ""), + "workspace": workspace_id, + "name": "default", + "components": { + "artifact_store": [artifact_store_id], + "orchestrator": [orchestrator_id], + }, + } + stack_components.append(default_artifact_store) + stack_components.append(default_orchestrator) + stacks.append(default_stack) + + op.bulk_insert(stack_component_table, rows=stack_components) + op.bulk_insert(stack_table, rows=stacks) service_connector_table = sa.Table("service_connector", meta) query = sa.select( @@ -55,7 +102,6 @@ def resolve_duplicate_names() -> None: service_connector_table.c.user_id, ) - connection = op.get_bind() names = set() for id, name, user_id in connection.execute(query).fetchall(): if name in names: diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index df738a0142c..6036a26670a 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -405,6 +405,17 @@ def _initialize_database(self) -> None: """Initialize the database.""" # don't do anything for a REST store + def _get_or_create_default_stack( + self, + workspace: "WorkspaceResponseModel", + ) -> "StackResponseModel": + # Overwrite this function so we don't try to create a default stack + # client-side. The default stack can't be deleted/modified so the + # fetching should always work. + return self._get_default_stack( + workspace_id=workspace.id, + ) + # ==================================== # ZenML Store interface implementation # ==================================== diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index e7b09b9dbf6..1141961a2ba 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -178,7 +178,7 @@ ) from zenml.utils.string_utils import random_str from zenml.zen_stores.base_zen_store import ( - DEFAULT_STACK_AND_COMPONENT_NAME_PREFIX, + DEFAULT_STACK_AND_COMPONENT_NAME, BaseZenStore, ) from zenml.zen_stores.enums import StoreEvent @@ -1142,13 +1142,7 @@ def update_stack( f"Unable to update stack with id '{stack_id}': Found no" f"existing stack with this id." ) - if ( - existing_stack.user_id - and existing_stack.name - == self._get_default_stack_and_component_name( - existing_stack.user_id - ) - ): + if existing_stack.name == DEFAULT_STACK_AND_COMPONENT_NAME: raise IllegalOperationError( "The default stack cannot be modified." ) @@ -1203,13 +1197,7 @@ def delete_stack(self, stack_id: UUID) -> None: if stack is None: raise KeyError(f"Stack with ID {stack_id} not found.") - if ( - stack.user_id - and stack.name - == self._get_default_stack_and_component_name( - user_id=stack.user_id - ) - ): + if stack.name == DEFAULT_STACK_AND_COMPONENT_NAME: raise IllegalOperationError( "The default stack cannot be deleted." ) @@ -1257,7 +1245,7 @@ def _fail_if_stack_name_reserved(self, stack_name: str) -> None: Raises: IllegalOperationError: If the stack name is reserved. """ - if stack_name == DEFAULT_STACK_AND_COMPONENT_NAME_PREFIX: + if stack_name == DEFAULT_STACK_AND_COMPONENT_NAME: raise IllegalOperationError( f"Unable to register stack with reserved name '{stack_name}'." ) @@ -1430,11 +1418,7 @@ def update_stack_component( ) if ( - existing_component.user_id - and existing_component.name - == self._get_default_stack_and_component_name( - user_id=existing_component.user_id - ) + existing_component.name == DEFAULT_STACK_AND_COMPONENT_NAME and existing_component.type in [ StackComponentType.ORCHESTRATOR, @@ -1505,11 +1489,7 @@ def delete_stack_component(self, component_id: UUID) -> None: if stack_component is None: raise KeyError(f"Stack with ID {component_id} not found.") if ( - stack_component.user_id - and stack_component.name - == self._get_default_stack_and_component_name( - user_id=stack_component.user_id - ) + stack_component.name == DEFAULT_STACK_AND_COMPONENT_NAME and stack_component.type in [ StackComponentType.ORCHESTRATOR, @@ -1581,7 +1561,7 @@ def _fail_if_component_name_reserved(self, component_name: str) -> None: Raises: IllegalOperationError: If the component name is reserved. """ - if component_name == DEFAULT_STACK_AND_COMPONENT_NAME_PREFIX: + if component_name == DEFAULT_STACK_AND_COMPONENT_NAME: raise IllegalOperationError( f"Unable to register component with reserved name " f"'{component_name}'." From d519c1739d01d7f4a841b13a10171842e1e92154 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 30 Oct 2023 17:28:08 +0100 Subject: [PATCH 036/103] Fix config --- .../7500f434b71c_remove_shared_columns.py | 25 +++++++++--- src/zenml/zen_stores/sql_zen_store.py | 39 ------------------- 2 files changed, 19 insertions(+), 45 deletions(-) diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py index 201af954acb..a02d50385d5 100644 --- a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -5,6 +5,8 @@ Create Date: 2023-10-16 15:15:34.865337 """ +import base64 +from datetime import datetime from uuid import uuid4 import sqlalchemy as sa @@ -56,37 +58,48 @@ def resolve_duplicate_names() -> None: _rename_old_default_entities(stack_component_table) workspace_query = sa.select(workspace_table.c.id) + utcnow = datetime.utcnow() stack_components = [] stacks = [] - for workspace_id in connection.execute(workspace_query).fetchall(): + for row in connection.execute(workspace_query).fetchall(): + workspace_id = row[0] artifact_store_id = str(uuid4()).replace("-", "") default_artifact_store = { "id": artifact_store_id, - "workspace": workspace_id, + "workspace_id": workspace_id, "name": "default", "type": "artifact_store", "flavor": "local", - "configuration": {}, + "configuration": base64.b64encode("{}".encode("utf-8")), + "is_shared": True, + "created": utcnow, + "updated": utcnow, } orchestrator_id = str(uuid4()).replace("-", "") default_orchestrator = { "id": orchestrator_id, - "workspace": workspace_id, + "workspace_id": workspace_id, "name": "default", "type": "orchestrator", "flavor": "local", - "configuration": {}, + "configuration": base64.b64encode("{}".encode("utf-8")), + "is_shared": True, + "created": utcnow, + "updated": utcnow, } default_stack = { "id": str(uuid4()).replace("-", ""), - "workspace": workspace_id, + "workspace_id": workspace_id, "name": "default", "components": { "artifact_store": [artifact_store_id], "orchestrator": [orchestrator_id], }, + "is_shared": True, + "created": utcnow, + "updated": utcnow, } stack_components.append(default_artifact_store) stack_components.append(default_orchestrator) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 1141961a2ba..1cfc8108a4a 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -1015,7 +1015,6 @@ def create_stack(self, stack: StackRequestModel) -> StackResponseModel: The registered stack. """ with Session(self.engine) as session: - self._fail_if_stack_name_reserved(stack_name=stack.name) self._fail_if_stack_with_name_exists(stack=stack, session=session) # Get the Schemas of all components mentioned @@ -1150,9 +1149,6 @@ def update_stack( # with that name if stack_update.name: if existing_stack.name != stack_update.name: - self._fail_if_stack_name_reserved( - stack_name=stack_update.name - ) self._fail_if_stack_with_name_exists( stack=stack_update, session=session ) @@ -1236,20 +1232,6 @@ def _fail_if_stack_with_name_exists( f"name in the active workspace '{workspace.name}'." ) - def _fail_if_stack_name_reserved(self, stack_name: str) -> None: - """Raise an exception if the stack name is reserved. - - Args: - stack_name: The stack name. - - Raises: - IllegalOperationError: If the stack name is reserved. - """ - if stack_name == DEFAULT_STACK_AND_COMPONENT_NAME: - raise IllegalOperationError( - f"Unable to register stack with reserved name '{stack_name}'." - ) - # ---------------- # Stack components # ---------------- @@ -1272,9 +1254,6 @@ def create_stack_component( connector. """ with Session(self.engine) as session: - self._fail_if_component_name_reserved( - component_name=component.name - ) self._fail_if_component_with_name_type_exists( name=component.name, component_type=component.type, @@ -1433,9 +1412,6 @@ def update_stack_component( # type already exists with that name if component_update.name: if existing_component.name != component_update.name: - self._fail_if_component_name_reserved( - component_name=component_update.name - ) self._fail_if_component_with_name_type_exists( name=component_update.name, component_type=existing_component.type, @@ -1552,21 +1528,6 @@ def _fail_if_component_with_name_type_exists( f" workspace '{existing_domain_component.workspace.name}'." ) - def _fail_if_component_name_reserved(self, component_name: str) -> None: - """Raise an exception if the component name is reserved. - - Args: - component_name: The component name. - - Raises: - IllegalOperationError: If the component name is reserved. - """ - if component_name == DEFAULT_STACK_AND_COMPONENT_NAME: - raise IllegalOperationError( - f"Unable to register component with reserved name " - f"'{component_name}'." - ) - # ----------------------- # Stack component flavors # ----------------------- From 0ecba873a89953bc0e9579d247cde0e54497170c Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 30 Oct 2023 18:13:56 +0100 Subject: [PATCH 037/103] Respect ownership in list endpoints --- src/zenml/models/filter_models.py | 28 +++++++++++---- src/zenml/zen_server/rbac/endpoint_utils.py | 7 +++- src/zenml/zen_server/rbac/utils.py | 1 - .../routers/service_connectors_endpoints.py | 2 ++ .../zen_server/routers/users_endpoints.py | 35 +------------------ 5 files changed, 31 insertions(+), 42 deletions(-) diff --git a/src/zenml/models/filter_models.py b/src/zenml/models/filter_models.py index 4e418cc4547..f2af582c161 100644 --- a/src/zenml/models/filter_models.py +++ b/src/zenml/models/filter_models.py @@ -300,7 +300,8 @@ class BaseFilterModel(BaseModel): default=None, description="Updated" ) - _allowed_ids: Optional[Set[UUID]] = None + _rbac_allowed_ids: Optional[Set[UUID]] = None + _rbac_user_id: Optional[UUID] = None @validator("sort_by", pre=True) def validate_sort_by(cls, v: str) -> str: @@ -393,16 +394,21 @@ def sorting_params(self) -> Tuple[str, SorterOps]: return column, operator - def set_allowed_ids(self, allowed_ids: Optional[Set[UUID]]) -> None: - """Set allowed IDs for the query. + def set_rbac_allowed_ids_and_user( + self, allowed_ids: Optional[Set[UUID]], user_id: Optional[UUID] + ) -> None: + """Set allowed IDs and user ID for the query. Args: allowed_ids: Set of IDs to limit the query to. If given, the remaining filters will be applied to entities within this set only. If `None`, the remaining filters will applied to all entries in the table. + user_id: ID of the authenticated user. If given, all entities owned + by this user will be included in addition to the `allowed_ids`. """ - self._allowed_ids = allowed_ids + self._rbac_allowed_ids = allowed_ids + self._rbac_user_id = user_id @classmethod def _generate_filter_list(cls, values: Dict[str, Any]) -> List[Filter]: @@ -768,8 +774,18 @@ def apply_filter( Returns: The query with filter applied. """ - if self._allowed_ids is not None: - query = query.where(table.id.in_(self._allowed_ids)) # type: ignore[attr-defined] + from sqlmodel import or_ + + if self._rbac_allowed_ids is not None: + if self._rbac_user_id and hasattr(table, "user"): + query = query.where( + or_( + table.id.in_(self._rbac_allowed_ids), # type: ignore[attr-defined] + getattr(table, "user") == self._rbac_user_id, + ) + ) + else: + query = query.where(table.id.in_(self._rbac_allowed_ids)) # type: ignore[attr-defined] filters = self.generate_filter(table=table) diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py index 00af2fabba2..6ced02d7ee4 100644 --- a/src/zenml/zen_server/rbac/endpoint_utils.py +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -94,8 +94,13 @@ def verify_permissions_and_list_entities( Returns: A page of entity models. """ + auth_context = get_auth_context() + assert auth_context + allowed_ids = get_allowed_resource_ids(resource_type=resource_type) - filter_model.set_allowed_ids(allowed_ids) + filter_model.set_rbac_allowed_ids_and_user( + allowed_ids=allowed_ids, user_id=auth_context.user.id + ) page = list_method(filter_model) return dehydrate_page(page) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 29f3ebcea6a..dd9e37ceb44 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -318,7 +318,6 @@ def get_allowed_resource_ids( if has_full_resource_access: return None - # TODO: this does not account for ownership right now return {UUID(id) for id in allowed_ids} diff --git a/src/zenml/zen_server/routers/service_connectors_endpoints.py b/src/zenml/zen_server/routers/service_connectors_endpoints.py index 4eac0f10200..2621008df33 100644 --- a/src/zenml/zen_server/routers/service_connectors_endpoints.py +++ b/src/zenml/zen_server/routers/service_connectors_endpoints.py @@ -108,6 +108,8 @@ def list_service_connectors( if not connector.secret_id: continue + # TODO: check for ownership. Can I always read the secret of a service connector I own? + # What if someone updates the secret? if allowed_ids and connector.id not in allowed_ids: # The user is not allowed to read secret values for this # connector. We don't raise an exception here but don't include diff --git a/src/zenml/zen_server/routers/users_endpoints.py b/src/zenml/zen_server/routers/users_endpoints.py index bac42fdbec4..cb29bfc4d21 100644 --- a/src/zenml/zen_server/routers/users_endpoints.py +++ b/src/zenml/zen_server/routers/users_endpoints.py @@ -28,7 +28,7 @@ VERSION_1, ) from zenml.enums import AuthScheme -from zenml.exceptions import AuthorizationException, IllegalOperationError +from zenml.exceptions import AuthorizationException from zenml.logger import get_logger from zenml.models import ( UserFilterModel, @@ -271,39 +271,6 @@ def deactivate_user( user.activation_token = token return user - @router.delete( - "/{user_name_or_id}", - responses={ - 401: error_response, - 404: error_response, - 422: error_response, - }, - ) - @handle_exceptions - def delete_user( - user_name_or_id: Union[str, UUID], - auth_context: AuthContext = Security(authorize), - ) -> None: - """Deletes a specific user. - - Args: - user_name_or_id: Name or ID of the user. - auth_context: The authentication context. - - Raises: - IllegalOperationError: If the user is not authorized to delete the user. - """ - user = zen_store().get_user(user_name_or_id) - - if auth_context.user.name == user.name: - raise IllegalOperationError( - "You cannot delete the user account currently used to authenticate " - "to the ZenML server. If you wish to delete this account, " - "please authenticate with another account or contact your ZenML " - "administrator." - ) - zen_store().delete_user(user_name_or_id=user_name_or_id) - @router.put( "/{user_name_or_id}" + EMAIL_ANALYTICS, response_model=UserResponseModel, From dda2905b011b697f3845cc0e7feb48fef10dcfaa Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 31 Oct 2023 09:46:50 +0100 Subject: [PATCH 038/103] Respect ownership in service connector endpoints --- .../routers/service_connectors_endpoints.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/zenml/zen_server/routers/service_connectors_endpoints.py b/src/zenml/zen_server/routers/service_connectors_endpoints.py index 2621008df33..850e5a9864a 100644 --- a/src/zenml/zen_server/routers/service_connectors_endpoints.py +++ b/src/zenml/zen_server/routers/service_connectors_endpoints.py @@ -45,6 +45,7 @@ from zenml.zen_server.rbac.utils import ( get_allowed_resource_ids, has_permissions_for_model, + is_owned_by_authenticated_user, verify_permission, verify_permission_for_model, ) @@ -108,9 +109,13 @@ def list_service_connectors( if not connector.secret_id: continue - # TODO: check for ownership. Can I always read the secret of a service connector I own? - # What if someone updates the secret? - if allowed_ids and connector.id not in allowed_ids: + if allowed_ids is None or is_owned_by_authenticated_user( + connector + ): + # The user either owns the connector or has permissions to + # read secret values for all service connectors + pass + elif connector.id not in allowed_ids: # The user is not allowed to read secret values for this # connector. We don't raise an exception here but don't include # the secret values @@ -150,8 +155,11 @@ def get_service_connector( if ( expand_secrets and connector.secret_id - and has_permissions_for_model( - connector, action=Action.READ_SECRET_VALUE + and ( + is_owned_by_authenticated_user(connector) + or has_permissions_for_model( + connector, action=Action.READ_SECRET_VALUE + ) ) ): secret = zen_store().get_secret(secret_id=connector.secret_id) From 78b82bc9ae422591cd190b6292c35e772649ab09 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 31 Oct 2023 11:36:54 +0100 Subject: [PATCH 039/103] Refactor internal request models --- src/zenml/models/__init__.py | 4 ++++ src/zenml/models/base_models.py | 4 ++-- src/zenml/models/component_models.py | 6 ++++++ src/zenml/models/stack_models.py | 6 ++++++ src/zenml/zen_server/rbac/utils.py | 3 +++ .../routers/service_connectors_endpoints.py | 7 ++----- src/zenml/zen_stores/base_zen_store.py | 13 ++----------- 7 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index 837763840c8..3162e8d9baa 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -39,6 +39,7 @@ from zenml.models.component_models import ( ComponentFilterModel, ComponentRequestModel, + InternalComponentRequestModel, ComponentResponseModel, ComponentUpdateModel, ) @@ -121,6 +122,7 @@ StackRequestModel, StackResponseModel, StackUpdateModel, + InternalStackRequestModel, ) from zenml.models.step_run_models import ( StepRunFilterModel, @@ -315,6 +317,7 @@ "CodeRepositoryUpdateModel", "ComponentFilterModel", "ComponentRequestModel", + "InternalComponentRequestModel", "ComponentResponseModel", "ComponentUpdateModel", "ExternalUserModel", @@ -396,6 +399,7 @@ "ServiceConnectorUpdateModel", "StackFilterModel", "StackRequestModel", + "InternalStackRequestModel", "StackResponseModel", "StackUpdateModel", "StepRunFilterModel", diff --git a/src/zenml/models/base_models.py b/src/zenml/models/base_models.py index 534296a125c..e9e1d341621 100644 --- a/src/zenml/models/base_models.py +++ b/src/zenml/models/base_models.py @@ -271,8 +271,8 @@ def update_model(_cls: Type[T]) -> Type[T]: return _cls -def internal_model(_cls: Type[T]) -> Type[T]: - """Convert a request model to an internal model. +def server_owned_request_model(_cls: Type[T]) -> Type[T]: + """Convert a request model to a model which does not require a user ID. Args: _cls: The class to decorate diff --git a/src/zenml/models/component_models.py b/src/zenml/models/component_models.py index dcf24a30a72..76e65988523 100644 --- a/src/zenml/models/component_models.py +++ b/src/zenml/models/component_models.py @@ -32,6 +32,7 @@ from zenml.models.base_models import ( WorkspaceScopedRequestModel, WorkspaceScopedResponseModel, + server_owned_request_model, update_model, ) from zenml.models.constants import STR_FIELD_MAX_LENGTH @@ -218,6 +219,11 @@ def name_cant_be_a_secret_reference(cls, name: str) -> str: return name +@server_owned_request_model +class InternalComponentRequestModel(ComponentRequestModel): + pass + + # ------ # # UPDATE # # ------ # diff --git a/src/zenml/models/stack_models.py b/src/zenml/models/stack_models.py index 74e26e1852c..f7168db815a 100644 --- a/src/zenml/models/stack_models.py +++ b/src/zenml/models/stack_models.py @@ -23,6 +23,7 @@ from zenml.models.base_models import ( WorkspaceScopedRequestModel, WorkspaceScopedResponseModel, + server_owned_request_model, update_model, ) from zenml.models.component_models import ComponentResponseModel @@ -175,6 +176,11 @@ def is_valid(self) -> bool: ) +@server_owned_request_model +class InternalStackRequestModel(StackRequestModel): + pass + + # ------ # # UPDATE # # ------ # diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index dd9e37ceb44..9a5bf749ef6 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -133,6 +133,9 @@ def has_permissions_for_model(model: "BaseResponseModel", action: str) -> bool: Returns: If the active user has permissions to perform the action on the model. """ + if is_owned_by_authenticated_user(model): + return True + try: verify_permission_for_model(model=model, action=action) return True diff --git a/src/zenml/zen_server/routers/service_connectors_endpoints.py b/src/zenml/zen_server/routers/service_connectors_endpoints.py index 850e5a9864a..cb6a6cbc8a5 100644 --- a/src/zenml/zen_server/routers/service_connectors_endpoints.py +++ b/src/zenml/zen_server/routers/service_connectors_endpoints.py @@ -155,11 +155,8 @@ def get_service_connector( if ( expand_secrets and connector.secret_id - and ( - is_owned_by_authenticated_user(connector) - or has_permissions_for_model( - connector, action=Action.READ_SECRET_VALUE - ) + and has_permissions_for_model( + connector, action=Action.READ_SECRET_VALUE ) ): secret = zen_store().get_secret(secret_id=connector.secret_id) diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index 95c3cdd1837..01803851bd6 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -50,9 +50,9 @@ ) from zenml.logger import get_logger from zenml.models import ( - ComponentRequestModel, + InternalComponentRequestModel, + InternalStackRequestModel, StackFilterModel, - StackRequestModel, StackResponseModel, UserRequestModel, UserResponseModel, @@ -531,11 +531,6 @@ def _create_default_stack( logger.info( f"Creating default stack in workspace {workspace.name}..." ) - from zenml.models.base_models import internal_model - - @internal_model - class InternalComponentRequestModel(ComponentRequestModel): - pass # Register the default orchestrator orchestrator = self.create_stack_component( @@ -565,10 +560,6 @@ class InternalComponentRequestModel(ComponentRequestModel): c.type: [c.id] for c in [orchestrator, artifact_store] } - @internal_model - class InternalStackRequestModel(StackRequestModel): - pass - # Register the default stack stack = InternalStackRequestModel( name=DEFAULT_STACK_AND_COMPONENT_NAME, From 9e10e059e266702447543b330dca508184b30012 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 31 Oct 2023 12:02:56 +0100 Subject: [PATCH 040/103] Disable workspace endpoints --- src/zenml/zen_stores/base_zen_store.py | 14 +-- src/zenml/zen_stores/rest_zen_store.py | 110 ++++++++++---------- src/zenml/zen_stores/sql_zen_store.py | 7 ++ src/zenml/zen_stores/zen_store_interface.py | 84 ++++++++------- 4 files changed, 108 insertions(+), 107 deletions(-) diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index 01803851bd6..8152e249fed 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Base Zen Store implementation.""" import os -from abc import ABC +from abc import ABC, abstractmethod from typing import ( Any, Callable, @@ -56,7 +56,6 @@ StackResponseModel, UserRequestModel, UserResponseModel, - WorkspaceRequestModel, WorkspaceResponseModel, ) from zenml.models.server_models import ( @@ -293,10 +292,7 @@ def get_default_store_config(path: str) -> StoreConfiguration: def _initialize_database(self) -> None: """Initialize the database on first use.""" - try: - default_workspace = self._default_workspace - except KeyError: - default_workspace = self._create_default_workspace() + default_workspace = self._get_or_create_default_workspace() config = ServerConfiguration.get_server_config() # If the auth scheme is external, don't create the default user @@ -696,17 +692,13 @@ def _default_workspace(self) -> WorkspaceResponseModel: f"The default workspace '{workspace_name}' is not configured" ) + @abstractmethod def _create_default_workspace(self) -> WorkspaceResponseModel: """Creates a default workspace. Returns: The default workspace. """ - workspace_name = self._default_workspace_name - logger.info(f"Creating default workspace '{workspace_name}' ...") - return self.create_workspace( - WorkspaceRequestModel(name=workspace_name) - ) class Config: """Pydantic configuration class.""" diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 6036a26670a..6ae0c8f4eb6 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -151,9 +151,7 @@ UserResponseModel, UserUpdateModel, WorkspaceFilterModel, - WorkspaceRequestModel, WorkspaceResponseModel, - WorkspaceUpdateModel, ) from zenml.models.base_models import ( BaseRequestModel, @@ -405,15 +403,21 @@ def _initialize_database(self) -> None: """Initialize the database.""" # don't do anything for a REST store - def _get_or_create_default_stack( + def _create_default_stack( self, - workspace: "WorkspaceResponseModel", - ) -> "StackResponseModel": - # Overwrite this function so we don't try to create a default stack - # client-side. The default stack can't be deleted/modified so the - # fetching should always work. - return self._get_default_stack( - workspace_id=workspace.id, + workspace_id: UUID, + ) -> StackResponseModel: + workspace = self.get_workspace(workspace_id) + + raise RuntimeError( + f"Unable to create default stack in workspace " + f"{workspace.name}." + ) + + def _create_default_workspace(self) -> WorkspaceResponseModel: + raise RuntimeError( + f"Unable to create default workspace " + f"{self._default_workspace_name}." ) # ==================================== @@ -812,22 +816,22 @@ def delete_user(self, user_name_or_id: Union[str, UUID]) -> None: # Workspaces # -------- - def create_workspace( - self, workspace: WorkspaceRequestModel - ) -> WorkspaceResponseModel: - """Creates a new workspace. + # def create_workspace( + # self, workspace: WorkspaceRequestModel + # ) -> WorkspaceResponseModel: + # """Creates a new workspace. - Args: - workspace: The workspace to create. + # Args: + # workspace: The workspace to create. - Returns: - The newly created workspace. - """ - return self._create_resource( - resource=workspace, - route=WORKSPACES, - response_model=WorkspaceResponseModel, - ) + # Returns: + # The newly created workspace. + # """ + # return self._create_resource( + # resource=workspace, + # route=WORKSPACES, + # response_model=WorkspaceResponseModel, + # ) def get_workspace( self, workspace_name_or_id: Union[UUID, str] @@ -864,35 +868,35 @@ def list_workspaces( filter_model=workspace_filter_model, ) - def update_workspace( - self, workspace_id: UUID, workspace_update: WorkspaceUpdateModel - ) -> WorkspaceResponseModel: - """Update an existing workspace. - - Args: - workspace_id: The ID of the workspace to be updated. - workspace_update: The update to be applied to the workspace. - - Returns: - The updated workspace. - """ - return self._update_resource( - resource_id=workspace_id, - resource_update=workspace_update, - route=WORKSPACES, - response_model=WorkspaceResponseModel, - ) - - def delete_workspace(self, workspace_name_or_id: Union[str, UUID]) -> None: - """Deletes a workspace. - - Args: - workspace_name_or_id: Name or ID of the workspace to delete. - """ - self._delete_resource( - resource_id=workspace_name_or_id, - route=WORKSPACES, - ) + # def update_workspace( + # self, workspace_id: UUID, workspace_update: WorkspaceUpdateModel + # ) -> WorkspaceResponseModel: + # """Update an existing workspace. + + # Args: + # workspace_id: The ID of the workspace to be updated. + # workspace_update: The update to be applied to the workspace. + + # Returns: + # The updated workspace. + # """ + # return self._update_resource( + # resource_id=workspace_id, + # resource_update=workspace_update, + # route=WORKSPACES, + # response_model=WorkspaceResponseModel, + # ) + + # def delete_workspace(self, workspace_name_or_id: Union[str, UUID]) -> None: + # """Deletes a workspace. + + # Args: + # workspace_name_or_id: Name or ID of the workspace to delete. + # """ + # self._delete_resource( + # resource_id=workspace_name_or_id, + # route=WORKSPACES, + # ) # --------- # Pipelines diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 1cfc8108a4a..d9da57ca498 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -667,6 +667,13 @@ class SqlZenStore(BaseZenStore): _engine: Optional[Engine] = None _alembic: Optional[Alembic] = None + def _create_default_workspace(self) -> WorkspaceResponseModel: + workspace_name = self._default_workspace_name + logger.info(f"Creating default workspace '{workspace_name}' ...") + return self.create_workspace( + WorkspaceRequestModel(name=workspace_name) + ) + @property def engine(self) -> Engine: """The SQLAlchemy engine. diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 8aaab6cdd81..44beed6e27f 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -87,9 +87,7 @@ UserResponseModel, UserUpdateModel, WorkspaceFilterModel, - WorkspaceRequestModel, WorkspaceResponseModel, - WorkspaceUpdateModel, ) from zenml.models.page_model import Page from zenml.models.run_metadata_models import RunMetadataFilterModel @@ -493,36 +491,36 @@ def update_user( KeyError: If no user with the given name exists. """ - @abstractmethod - def delete_user(self, user_name_or_id: Union[str, UUID]) -> None: - """Deletes a user. + # @abstractmethod + # def delete_user(self, user_name_or_id: Union[str, UUID]) -> None: + # """Deletes a user. - Args: - user_name_or_id: The name or ID of the user to delete. + # Args: + # user_name_or_id: The name or ID of the user to delete. - Raises: - KeyError: If no user with the given ID exists. - """ + # Raises: + # KeyError: If no user with the given ID exists. + # """ # -------- # Workspaces # -------- - @abstractmethod - def create_workspace( - self, workspace: WorkspaceRequestModel - ) -> WorkspaceResponseModel: - """Creates a new workspace. + # @abstractmethod + # def create_workspace( + # self, workspace: WorkspaceRequestModel + # ) -> WorkspaceResponseModel: + # """Creates a new workspace. - Args: - workspace: The workspace to create. + # Args: + # workspace: The workspace to create. - Returns: - The newly created workspace. + # Returns: + # The newly created workspace. - Raises: - EntityExistsError: If a workspace with the given name already exists. - """ + # Raises: + # EntityExistsError: If a workspace with the given name already exists. + # """ @abstractmethod def get_workspace( @@ -554,33 +552,33 @@ def list_workspaces( A list of all workspace matching the filter criteria. """ - @abstractmethod - def update_workspace( - self, workspace_id: UUID, workspace_update: WorkspaceUpdateModel - ) -> WorkspaceResponseModel: - """Update an existing workspace. + # @abstractmethod + # def update_workspace( + # self, workspace_id: UUID, workspace_update: WorkspaceUpdateModel + # ) -> WorkspaceResponseModel: + # """Update an existing workspace. - Args: - workspace_id: The ID of the workspace to be updated. - workspace_update: The update to be applied to the workspace. + # Args: + # workspace_id: The ID of the workspace to be updated. + # workspace_update: The update to be applied to the workspace. - Returns: - The updated workspace. + # Returns: + # The updated workspace. - Raises: - KeyError: if the workspace does not exist. - """ + # Raises: + # KeyError: if the workspace does not exist. + # """ - @abstractmethod - def delete_workspace(self, workspace_name_or_id: Union[str, UUID]) -> None: - """Deletes a workspace. + # @abstractmethod + # def delete_workspace(self, workspace_name_or_id: Union[str, UUID]) -> None: + # """Deletes a workspace. - Args: - workspace_name_or_id: Name or ID of the workspace to delete. + # Args: + # workspace_name_or_id: Name or ID of the workspace to delete. - Raises: - KeyError: If no workspace with the given name exists. - """ + # Raises: + # KeyError: If no workspace with the given name exists. + # """ # --------- # Pipelines From e57eb638cd3d4ff143e706f875f4696f3c33f1e1 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 3 Nov 2023 12:49:29 +0100 Subject: [PATCH 041/103] Model endpoints --- src/zenml/models/model_models.py | 29 +++++++ src/zenml/zen_server/rbac/models.py | 6 +- src/zenml/zen_server/rbac/utils.py | 33 ++++++++ .../zen_server/routers/models_endpoints.py | 81 ++++++++++--------- .../routers/workspaces_endpoints.py | 30 ++++--- 5 files changed, 132 insertions(+), 47 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 0f11176ac20..af2a0a037ab 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -20,6 +20,7 @@ Dict, List, Optional, + Set, Type, TypeVar, Union, @@ -75,6 +76,23 @@ class ModelScopedFilterModel(WorkspaceScopedFilterModel): """Base filter model inside Model Scope.""" _model_id: UUID = PrivateAttr(None) + _rbac_allowed_model_ids: Optional[Set[UUID]] = None + + def set_rbac_allowed_model_ids_and_user( + self, allowed_model_ids: Optional[Set[UUID]], user_id: Optional[UUID] + ) -> None: + """Set allowed model IDs and user ID for the query. + + Args: + allowed_model_ids: Set of IDs to limit the query to. If given, the + remaining filters will be applied to entities within this set + only. If `None`, the remaining filters will applied to all + entries in the table. + user_id: ID of the authenticated user. If given, all entities owned + by this user will be included in addition to the `allowed_ids`. + """ + self._rbac_allowed_model_ids = allowed_model_ids + self._rbac_user_id = user_id def set_scope_model(self, model_name_or_id: Union[str, UUID]) -> None: """Set the model to scope this response. @@ -105,8 +123,19 @@ def apply_filter( Returns: The query with filter applied. """ + from sqlmodel import or_ + query = super().apply_filter(query=query, table=table) + if self._rbac_allowed_model_ids is not None: + if self._rbac_user_id and hasattr(table, "user"): + query = query.where( + or_( + getattr(table, "model_id").in_(self._rbac_allowed_model_ids), # type: ignore[attr-defined] + getattr(table, "user") == self._rbac_user_id, + ) + ) + if self._model_id: query = query.where(getattr(table, "model_id") == self._model_id) diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index aab752d7e14..baf48687cc1 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -28,10 +28,13 @@ class Action(StrEnum): READ = "read" UPDATE = "update" DELETE = "delete" + READ_SECRET_VALUE = "read_secret_value" # Service connectors CLIENT = "client" # TODO: rename - READ_SECRET_VALUE = "read_secret_value" + + # Models + PROMOTE = "promote" class ResourceType(StrEnum): @@ -43,6 +46,7 @@ class ResourceType(StrEnum): PIPELINE = "pipeline" CODE_REPOSITORY = "code_repository" MODEL = "model" + MODEL_VERSION = "model_version" SERVICE_CONNECTOR = "service_connector" ARTIFACT = "artifact" SECRET = "secret" diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 9a5bf749ef6..b3416099fce 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -100,6 +100,9 @@ def _dehydrate_value( The recursively dehydrated value. """ if isinstance(value, BaseResponseModel): + value = get_surrogate_permission_model_for_model( + value, action=Action.READ + ) resource = get_resource_for_model(value) has_permissions = resource and (permissions or {}).get(resource, False) @@ -212,6 +215,8 @@ def batch_verify_permissions_for_models( # The model owner always has permissions continue + model = get_surrogate_permission_model_for_model(model, action=action) + if resource := get_resource_for_model(model): resources.add(resource) @@ -342,6 +347,31 @@ def get_resource_for_model(model: "BaseResponseModel") -> Optional[Resource]: return Resource(type=resource_type, id=model.id) +def get_surrogate_permission_model_for_model( + model: "BaseResponseModel", action: str +) -> "BaseResponseModel": + """Get a surrogate permission model for a model. + + In some cases a different model instead of the original model is used to + verify permissions. For example, a parent container model might be used + to verify permissions for all its children. + + Args: + model: The original model. + action: The action that the user wants to perform on the model. + + Returns: + A surrogate model or the original. + """ + from zenml.models import ModelVersionResponseModel + + if action == Action.READ == isinstance(model, ModelVersionResponseModel): + # Permissions to read a model version is the same as reading the model + return model.model + + return model + + def get_resource_type_for_model( model: "BaseResponseModel", ) -> Optional[ResourceType]: @@ -436,6 +466,9 @@ def _get_subresources_for_value(value: Any) -> Set[Resource]: if isinstance(value, BaseResponseModel): resources = set() if not is_owned_by_authenticated_user(value): + value = get_surrogate_permission_model_for_model( + value, action=Action.READ + ) if resource := get_resource_for_model(value): resources.add(resource) diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index f1c366465a3..7f1f57113f7 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -49,7 +49,13 @@ verify_permissions_and_list_entities, verify_permissions_and_update_entity, ) -from zenml.zen_server.rbac.models import ResourceType +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import ( + dehydrate_page, + dehydrate_response_model, + get_allowed_resource_ids, + verify_permission_for_model, +) from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -184,7 +190,7 @@ def list_model_versions( model_version_filter_model: ModelVersionFilterModel = Depends( make_dependable(ModelVersionFilterModel) ), - _: AuthContext = Security(authorize), + auth_context: AuthContext = Security(authorize), ) -> Page[ModelVersionResponseModel]: """Get model versions according to query filters. @@ -196,10 +202,18 @@ def list_model_versions( Returns: The model versions according to query filters. """ - return zen_store().list_model_versions( + allowed_model_ids = get_allowed_resource_ids( + resource_type=ResourceType.MODEL + ) + model_version_filter_model.set_rbac_allowed_model_ids_and_user( + allowed_model_ids=allowed_model_ids, user_id=auth_context.user.id + ) + + model_versions = zen_store().list_model_versions( model_name_or_id=model_name_or_id, model_version_filter_model=model_version_filter_model, ) + return dehydrate_page(model_versions) @router.get( @@ -229,13 +243,18 @@ def get_model_version( Returns: The model version with the given name or ID. """ - return zen_store().get_model_version( + model = zen_store().get_model(model_name_or_id) + verify_permission_for_model(model, action=Action.READ) + + model_version = zen_store().get_model_version( model_name_or_id, model_version_name_or_number_or_id if not is_number else int(model_version_name_or_number_or_id), ) + return dehydrate_response_model(model_version) + @router.put( "/{model_id}" + MODEL_VERSIONS + "/{model_version_id}", @@ -244,6 +263,7 @@ def get_model_version( ) @handle_exceptions def update_model_version( + model_id: UUID, model_version_id: UUID, model_version_update_model: ModelVersionUpdateModel, _: AuthContext = Security(authorize), @@ -251,17 +271,26 @@ def update_model_version( """Get all model versions by filter. Args: + model_id: The ID of the model that the version belongs to. model_version_id: The ID of model version to be updated. model_version_update_model: The model version to be updated. Returns: An updated model version. """ - return zen_store().update_model_version( - model_version_id=model_version_id, - model_version_update_model=model_version_update_model, + if model_version_update_model.stage: + # Make sure the user has permissions to promote the model + model = zen_store().get_model(model_id) + verify_permission_for_model(model, action=Action.PROMOTE) + + model_version = zen_store().get_model_version(model_id, model_version_id) + verify_permission_for_model(model_version, action=Action.UPDATE) + updated_model_version = zen_store().update_model_version( + model_version_id=model_version_id ) + return dehydrate_response_model(updated_model_version) + @router.delete( "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}", @@ -279,6 +308,10 @@ def delete_model_version( model_name_or_id: The name or ID of the model containing version. model_version_name_or_id: The name or ID of the model version to delete. """ + model_version = zen_store().get_model_version( + model_name_or_id, model_version_name_or_id + ) + verify_permission_for_model(model_version, action=Action.DELETE) zen_store().delete_model_version( model_name_or_id, model_version_name_or_id ) @@ -346,6 +379,11 @@ def delete_model_version_artifact_link( model_version_name_or_id: name or ID of the model version containing the link. model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. """ + model_version = zen_store().get_model_version( + model_name_or_id, model_version_name_or_id + ) + verify_permission_for_model(model_version, action=Action.UPDATE) + zen_store().delete_model_version_artifact_link( model_name_or_id, model_version_name_or_id, @@ -391,32 +429,3 @@ def list_model_version_pipeline_run_links( model_version_name_or_id=model_version_name_or_id, model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, ) - - -@router.delete( - "/{model_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_name_or_id}" - + RUNS - + "/{model_version_pipeline_run_link_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_model_version_pipeline_run_link( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], - model_version_pipeline_run_link_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize), -) -> None: - """Deletes a model version link. - - Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. - model_version_pipeline_run_link_name_or_id: name or ID of the model version link to be deleted. - """ - zen_store().delete_model_version_pipeline_run_link( - model_name_or_id, - model_version_name_or_id, - model_version_pipeline_run_link_name_or_id, - ) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 98224f5f880..0dbaa70371f 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1197,9 +1197,8 @@ def create_model_version( The created model version. Raises: - IllegalOperationError: If the workspace or user specified in the - model version does not match the current workspace or authenticated - user. + IllegalOperationError: If the workspace specified in the + model version does not match the current workspace. """ workspace = zen_store().get_workspace(workspace_name_or_id) @@ -1209,13 +1208,12 @@ def create_model_version( f"of this endpoint `{workspace_name_or_id}` is " f"not supported." ) - if model_version.user != auth_context.user.id: - raise IllegalOperationError( - "Creating models for a user other than yourself " - "is not supported." - ) - mv = zen_store().create_model_version(model_version) - return mv + + return verify_permissions_and_create_entity( + request_model=model_version, + resource_type=ResourceType.MODEL_VERSION, + create_method=zen_store().create_model_version, + ) @router.post( @@ -1267,6 +1265,12 @@ def create_model_version_artifact_link( "Creating model to artifact links for a user other than yourself " "is not supported." ) + + model_version = zen_store().get_model_version( + model_name_or_id, model_version_name_or_id + ) + verify_permission_for_model(model_version, action=Action.UPDATE) + mv = zen_store().create_model_version_artifact_link( model_version_artifact_link ) @@ -1363,6 +1367,12 @@ def create_model_version_pipeline_run_link( "Creating models for a user other than yourself " "is not supported." ) + + model_version = zen_store().get_model_version( + model_name_or_id, model_version_name_or_id + ) + verify_permission_for_model(model_version, action=Action.UPDATE) + mv = zen_store().create_model_version_pipeline_run_link( model_version_pipeline_run_link ) From 1278009bd1e6d57ecb6b1ffd9bf5f1d6b5bcc860 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 16 Nov 2023 14:20:49 +0100 Subject: [PATCH 042/103] Start post-merge cleanup --- src/zenml/cli/__init__.py | 82 +----------- src/zenml/cli/service_accounts.py | 12 -- src/zenml/client.py | 123 +++++++++--------- src/zenml/constants.py | 4 - src/zenml/models/v2/misc/user_auth.py | 7 +- .../zen_server/routers/flavors_endpoints.py | 6 +- src/zenml/zen_stores/rest_zen_store.py | 2 +- 7 files changed, 68 insertions(+), 168 deletions(-) diff --git a/src/zenml/cli/__init__.py b/src/zenml/cli/__init__.py index 3bbf9cff5a7..3a90c78c3bc 100644 --- a/src/zenml/cli/__init__.py +++ b/src/zenml/cli/__init__.py @@ -1267,11 +1267,11 @@ def my_pipeline(...): ssl_verify_server_cert: false ``` -Managing users, teams, workspaces and roles +Managing users and workspaces ------------------------------------------- -When using the ZenML service, you can manage permissions by managing users, -teams, workspaces and roles using the CLI. +When using the ZenML service, you can manage permissions by managing users and +workspaces and using the CLI. If you want to create a new user or delete an existing one, run either ```bash @@ -1282,87 +1282,11 @@ def my_pipeline(...): zenml user delete USER_NAME ``` -A freshly created user will by default be assigned the admin role. This -behavior can be overwritten: -```bash -zenml user create USER_NAME --role guest -``` - To see a list of all users, run: ```bash zenml user list ``` -A team is a grouping of many users that allows you to quickly assign and -revoke roles. If you want to create a new team, run: - -```bash -zenml team create TEAM_NAME -``` -To add one or more users to a team, run: -```bash -zenml team add TEAM_NAME --user USER_NAME [--user USER_NAME ...] -``` -Similarly, to remove users from a team run: -```bash -zenml team remove TEAM_NAME --user USER_NAME [--user USER_NAME ...] -``` -To delete a team (keep in mind this will revoke any roles assigned to this -team from the team members), run: -```bash -zenml team delete TEAM_NAME -``` - -To see a list of all teams, run: -```bash -zenml team list -``` - -A role groups permissions to resources. Currently, there are the following -globally scoped roles to choose from: 'write', 'read' and 'me'. To create -a role, run one of the following commands: -```bash -zenml role create ROLE_NAME -p write -p read -p me -zenml role create ROLE_NAME -p read -``` - -To delete a role run: -```bash -zenml role delete ROLE_NAME -``` - -To see a list of all roles, run: -```bash -zenml role list -``` - -You can also update the role name and the attached permissions of a role: -```bash -zenml role update [-n | -r | -a ] -``` - -If you want to assign or revoke a role from users or teams, you can run - -```bash -zenml role assign ROLE_NAME --user USER_NAME [--user USER_NAME ...] -zenml role assign ROLE_NAME --team TEAM_NAME [--team TEAM_NAME ...] -``` -or -```bash -zenml role revoke ROLE_NAME --user USER_NAME [--user USER_NAME ...] -zenml role revoke ROLE_NAME --team TEAM_NAME [--team TEAM_NAME ...] -``` - -You can see a list of all current role assignments by running: - -```bash -zenml role assignment list -``` - -At any point you may inspect all available permissions: -```bash -zenml permission list -``` Managing service accounts ------------------------- diff --git a/src/zenml/cli/service_accounts.py b/src/zenml/cli/service_accounts.py index 35b8c8c96f2..192925e2077 100644 --- a/src/zenml/cli/service_accounts.py +++ b/src/zenml/cli/service_accounts.py @@ -110,21 +110,11 @@ def service_account() -> None: help=("Configure the local client to use the generated API key."), is_flag=True, ) -@click.option( - "--role", - "-r", - "initial_role", - help="Give the service account an initial role.", - required=False, - type=str, - default="admin", -) def create_service_account( service_account_name: str, description: str = "", create_api_key: bool = True, set_api_key: bool = False, - initial_role: str = "admin", ) -> None: """Create a new service account. @@ -133,14 +123,12 @@ def create_service_account( description: The API key description. create_api_key: Create an API key for the service account. set_api_key: Configure the local client to use the generated API key. - initial_role: Give the service account an initial role """ client = Client() try: service_account = client.create_service_account( name=service_account_name, description=description, - initial_role=initial_role, ) cli_utils.declare(f"Created service account '{service_account.name}'.") diff --git a/src/zenml/client.py b/src/zenml/client.py index 3ab17db9118..df8acdc7b4b 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -63,7 +63,6 @@ from zenml.exceptions import ( AuthorizationException, EntityExistsError, - IllegalOperationError, InitializationException, ValidationError, ZenKeyError, @@ -149,9 +148,7 @@ UserResponse, UserUpdate, WorkspaceFilter, - WorkspaceRequest, WorkspaceResponse, - WorkspaceUpdate, ) from zenml.utils import io_utils, source_utils from zenml.utils.filesync_model import FileSyncModel @@ -802,21 +799,21 @@ def active_user(self) -> "UserResponse": # -------------------------------- Workspaces ------------------------------ - def create_workspace( - self, name: str, description: str - ) -> WorkspaceResponse: - """Create a new workspace. + # def create_workspace( + # self, name: str, description: str + # ) -> WorkspaceResponse: + # """Create a new workspace. - Args: - name: Name of the workspace. - description: Description of the workspace. + # Args: + # name: Name of the workspace. + # description: Description of the workspace. - Returns: - The created workspace. - """ - return self.zen_store.create_workspace( - WorkspaceRequest(name=name, description=description) - ) + # Returns: + # The created workspace. + # """ + # return self.zen_store.create_workspace( + # WorkspaceRequest(name=name, description=description) + # ) def get_workspace( self, @@ -880,53 +877,53 @@ def list_workspaces( ) ) - def update_workspace( - self, - name_id_or_prefix: Optional[Union[UUID, str]], - new_name: Optional[str] = None, - new_description: Optional[str] = None, - ) -> WorkspaceResponse: - """Update a workspace. - - Args: - name_id_or_prefix: Name, ID or prefix of the workspace to update. - new_name: New name of the workspace. - new_description: New description of the workspace. - - Returns: - The updated workspace. - """ - workspace = self.get_workspace( - name_id_or_prefix=name_id_or_prefix, allow_name_prefix_match=False - ) - workspace_update = WorkspaceUpdate(name=new_name or workspace.name) - if new_description: - workspace_update.description = new_description - return self.zen_store.update_workspace( - workspace_id=workspace.id, - workspace_update=workspace_update, - ) - - def delete_workspace(self, name_id_or_prefix: str) -> None: - """Delete a workspace. - - Args: - name_id_or_prefix: The name or ID of the workspace to delete. - - Raises: - IllegalOperationError: If the workspace to delete is the active - workspace. - """ - workspace = self.get_workspace( - name_id_or_prefix, allow_name_prefix_match=False - ) - if self.active_workspace.id == workspace.id: - raise IllegalOperationError( - f"Workspace '{name_id_or_prefix}' cannot be deleted since " - "it is currently active. Please set another workspace as " - "active first." - ) - self.zen_store.delete_workspace(workspace_name_or_id=workspace.id) + # def update_workspace( + # self, + # name_id_or_prefix: Optional[Union[UUID, str]], + # new_name: Optional[str] = None, + # new_description: Optional[str] = None, + # ) -> WorkspaceResponse: + # """Update a workspace. + + # Args: + # name_id_or_prefix: Name, ID or prefix of the workspace to update. + # new_name: New name of the workspace. + # new_description: New description of the workspace. + + # Returns: + # The updated workspace. + # """ + # workspace = self.get_workspace( + # name_id_or_prefix=name_id_or_prefix, allow_name_prefix_match=False + # ) + # workspace_update = WorkspaceUpdate(name=new_name or workspace.name) + # if new_description: + # workspace_update.description = new_description + # return self.zen_store.update_workspace( + # workspace_id=workspace.id, + # workspace_update=workspace_update, + # ) + + # def delete_workspace(self, name_id_or_prefix: str) -> None: + # """Delete a workspace. + + # Args: + # name_id_or_prefix: The name or ID of the workspace to delete. + + # Raises: + # IllegalOperationError: If the workspace to delete is the active + # workspace. + # """ + # workspace = self.get_workspace( + # name_id_or_prefix, allow_name_prefix_match=False + # ) + # if self.active_workspace.id == workspace.id: + # raise IllegalOperationError( + # f"Workspace '{name_id_or_prefix}' cannot be deleted since " + # "it is currently active. Please set another workspace as " + # "active first." + # ) + # self.zen_store.delete_workspace(workspace_name_or_id=workspace.id) @property def active_workspace(self) -> WorkspaceResponse: diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 02fb5326329..e24e0423b29 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -181,12 +181,8 @@ def handle_int_env_var(var: str, default: int = 0) -> int: STATISTICS = "/statistics" USERS = "/users" CURRENT_USER = "/current-user" -TEAMS = "/teams" WORKSPACES = "/workspaces" -ROLES = "/roles" FLAVORS = "/flavors" -USER_ROLE_ASSIGNMENTS = "/role_assignments" -TEAM_ROLE_ASSIGNMENTS = "/team_role_assignments" LOGIN = "/login" LOGOUT = "/logout" PIPELINES = "/pipelines" diff --git a/src/zenml/models/v2/misc/user_auth.py b/src/zenml/models/v2/misc/user_auth.py index 2bc8ee48910..b11fdf31711 100644 --- a/src/zenml/models/v2/misc/user_auth.py +++ b/src/zenml/models/v2/misc/user_auth.py @@ -15,7 +15,7 @@ import re from datetime import datetime -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from uuid import UUID from pydantic import Field, SecretStr @@ -26,8 +26,6 @@ if TYPE_CHECKING: from passlib.context import CryptContext - from zenml.models.v2.core.team import TeamResponse - class UserAuthModel(BaseZenModel): """Authentication Model for the User. @@ -51,9 +49,6 @@ class UserAuthModel(BaseZenModel): activation_token: Optional[SecretStr] = Field(default=None, exclude=True) password: Optional[SecretStr] = Field(default=None, exclude=True) - teams: Optional[List["TeamResponse"]] = Field( - default=None, title="The list of teams for this user." - ) name: str = Field( title="The unique username for the account.", max_length=STR_FIELD_MAX_LENGTH, diff --git a/src/zenml/zen_server/routers/flavors_endpoints.py b/src/zenml/zen_server/routers/flavors_endpoints.py index 5daabbb7de1..50a2de4786c 100644 --- a/src/zenml/zen_server/routers/flavors_endpoints.py +++ b/src/zenml/zen_server/routers/flavors_endpoints.py @@ -131,7 +131,7 @@ def create_flavor( @router.put( - "/{team_id}", + "/{flavor_id}", response_model=FlavorResponse, responses={401: error_response, 409: error_response, 422: error_response}, ) @@ -146,8 +146,8 @@ def update_flavor( # noqa: DAR401 Args: - flavor_id: ID of the team to update. - flavor_update: Team update. + flavor_id: ID of the flavor to update. + flavor_update: Flavor update. Returns: The updated flavor. diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index f36b0fdb76b..3913ae5cb30 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2305,7 +2305,7 @@ def create_user(self, user: UserRequest) -> UserResponse: """ return self._create_resource( resource=user, - route=USERS + "?assign_default_role=False", + route=USERS, response_model=UserResponse, ) From 2a9c7ad51a4933585392ca448786be9de05ed17a Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 17 Nov 2023 09:40:51 +0100 Subject: [PATCH 043/103] Some after merge fixes --- src/zenml/models/base_models.py | 16 --------- src/zenml/models/v2/base/base.py | 28 ++++++++++++--- src/zenml/models/v2/base/internal.py | 36 +++++++++++++++++++ src/zenml/models/v2/core/component.py | 2 +- src/zenml/models/v2/core/stack.py | 2 +- src/zenml/zen_server/rbac/rbac_interface.py | 6 ++-- src/zenml/zen_server/rbac/zenml_cloud_rbac.py | 6 ++-- .../routers/service_accounts_endpoints.py | 23 ++++++------ .../zen_server/routers/tags_endpoints.py | 11 +++--- 9 files changed, 84 insertions(+), 46 deletions(-) create mode 100644 src/zenml/models/v2/base/internal.py diff --git a/src/zenml/models/base_models.py b/src/zenml/models/base_models.py index 8fe98bb29e3..750cddf11dc 100644 --- a/src/zenml/models/base_models.py +++ b/src/zenml/models/base_models.py @@ -269,19 +269,3 @@ def update_model(_cls: Type[T]) -> Type[T]: return _cls - -def server_owned_request_model(_cls: Type[T]) -> Type[T]: - """Convert a request model to a model which does not require a user ID. - - Args: - _cls: The class to decorate - - Returns: - The decorated class. - """ - if user_field := _cls.__fields__.get("user", None): - user_field.required = False - user_field.allow_none = True - user_field.default = None - - return _cls diff --git a/src/zenml/models/v2/base/base.py b/src/zenml/models/v2/base/base.py index 708594c3f20..6204f5349d8 100644 --- a/src/zenml/models/v2/base/base.py +++ b/src/zenml/models/v2/base/base.py @@ -22,7 +22,7 @@ from zenml.analytics.models import AnalyticsTrackedModelMixin from zenml.enums import ResponseUpdateStrategy -from zenml.exceptions import HydrationError +from zenml.exceptions import HydrationError, IllegalOperationError from zenml.logger import get_logger from zenml.utils.pydantic_utils import YAMLSerializationMixin @@ -104,7 +104,7 @@ class BaseResponse(GenericModel, Generic[AnyBody, AnyMetadata], BaseZenModel): id: UUID = Field(title="The unique resource id.") # Body and metadata pair - body: "AnyBody" = Field(title="The body of the resource.") + body: Optional["AnyBody"] = Field(title="The body of the resource.") metadata: Optional["AnyMetadata"] = Field( title="The metadata related to this resource." ) @@ -255,7 +255,17 @@ def get_body(self) -> AnyBody: Returns: The body field of the response. + + Raises: + IllegalOperationError: If the user lacks permission to access the + entity represented by this response. """ + if not self.body: + raise IllegalOperationError( + f"Missing permissions to access {type(self).__name__} with " + f"ID {self.id}." + ) + return self.body def get_metadata(self) -> "AnyMetadata": @@ -263,7 +273,17 @@ def get_metadata(self) -> "AnyMetadata": Returns: The metadata field of the response. + + Raises: + IllegalOperationError: If the user lacks permission to access this + entity represented by this response. """ + if not self.body: + raise IllegalOperationError( + f"Missing permissions to access {type(self).__name__} with " + f"ID {self.id}." + ) + if self.metadata is None: # If the metadata is not there, check the class first. metadata_type = self.__fields__["metadata"].type_ @@ -302,7 +322,7 @@ def created(self) -> Optional[datetime]: Returns: the value of the property. """ - return self.body.created + return self.get_body().created @property def updated(self) -> Optional[datetime]: @@ -311,4 +331,4 @@ def updated(self) -> Optional[datetime]: Returns: the value of the property. """ - return self.body.updated + return self.get_body().updated diff --git a/src/zenml/models/v2/base/internal.py b/src/zenml/models/v2/base/internal.py new file mode 100644 index 00000000000..d5c8d63f702 --- /dev/null +++ b/src/zenml/models/v2/base/internal.py @@ -0,0 +1,36 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Utility methods for internal models.""" + +from typing import TypeVar, Type +from zenml.models.v2.base.base import BaseRequest + + +T = TypeVar("T", bound="BaseRequest") + +def server_owned_request_model(_cls: Type[T]) -> Type[T]: + """Convert a request model to a model which does not require a user ID. + + Args: + _cls: The class to decorate + + Returns: + The decorated class. + """ + if user_field := _cls.__fields__.get("user", None): + user_field.required = False + user_field.allow_none = True + user_field.default = None + + return _cls diff --git a/src/zenml/models/v2/core/component.py b/src/zenml/models/v2/core/component.py index cd7ac7b4fa0..4686e3de5e2 100644 --- a/src/zenml/models/v2/core/component.py +++ b/src/zenml/models/v2/core/component.py @@ -30,7 +30,7 @@ from zenml.constants import STR_FIELD_MAX_LENGTH from zenml.enums import StackComponentType -from zenml.models.base_models import server_owned_request_model +from zenml.models.v2.base.internal import server_owned_request_model from zenml.models.v2.base.scoped import ( WorkspaceScopedFilter, WorkspaceScopedRequest, diff --git a/src/zenml/models/v2/core/stack.py b/src/zenml/models/v2/core/stack.py index 062b473aafd..b5d18da3d0d 100644 --- a/src/zenml/models/v2/core/stack.py +++ b/src/zenml/models/v2/core/stack.py @@ -21,7 +21,7 @@ from zenml.constants import STR_FIELD_MAX_LENGTH from zenml.enums import StackComponentType -from zenml.models.base_models import server_owned_request_model +from zenml.models.v2.base.internal import server_owned_request_model from zenml.models.v2.base.scoped import ( WorkspaceScopedFilter, WorkspaceScopedRequest, diff --git a/src/zenml/zen_server/rbac/rbac_interface.py b/src/zenml/zen_server/rbac/rbac_interface.py index 8a89629d5e4..d94b3b32b48 100644 --- a/src/zenml/zen_server/rbac/rbac_interface.py +++ b/src/zenml/zen_server/rbac/rbac_interface.py @@ -19,7 +19,7 @@ from zenml.zen_server.rbac.models import Resource if TYPE_CHECKING: - from zenml.models import UserResponseModel + from zenml.models import UserResponse class RBACInterface(ABC): @@ -27,7 +27,7 @@ class RBACInterface(ABC): @abstractmethod def check_permissions( - self, user: "UserResponseModel", resources: Set[Resource], action: str + self, user: "UserResponse", resources: Set[Resource], action: str ) -> Dict[Resource, bool]: """Checks if a user has permissions to perform an action on resources. @@ -43,7 +43,7 @@ def check_permissions( @abstractmethod def list_allowed_resource_ids( - self, user: "UserResponseModel", resource: Resource, action: str + self, user: "UserResponse", resource: Resource, action: str ) -> Tuple[bool, List[str]]: """Lists all resource IDs of a resource type that a user can access. diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py index f4ee5655b58..606d6ff2b03 100644 --- a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -23,7 +23,7 @@ from zenml.zen_server.utils import server_config if TYPE_CHECKING: - from zenml.models import UserResponseModel + from zenml.models import UserResponse ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_" @@ -131,7 +131,7 @@ def __init__(self) -> None: self._session: Optional[requests.Session] = None def check_permissions( - self, user: "UserResponseModel", resources: Set[Resource], action: str + self, user: "UserResponse", resources: Set[Resource], action: str ) -> Dict[Resource, bool]: """Checks if a user has permissions to perform an action on resources. @@ -164,7 +164,7 @@ def check_permissions( return {_convert_from_cloud_resource(k): v for k, v in value.items()} def list_allowed_resource_ids( - self, user: "UserResponseModel", resource: Resource, action: str + self, user: "UserResponse", resource: Resource, action: str ) -> Tuple[bool, List[str]]: """Lists all resource IDs of a resource type that a user can access. diff --git a/src/zenml/zen_server/routers/service_accounts_endpoints.py b/src/zenml/zen_server/routers/service_accounts_endpoints.py index b1bf9c6243a..5d75e05fcd5 100644 --- a/src/zenml/zen_server/routers/service_accounts_endpoints.py +++ b/src/zenml/zen_server/routers/service_accounts_endpoints.py @@ -25,7 +25,6 @@ SERVICE_ACCOUNTS, VERSION_1, ) -from zenml.enums import PermissionType from zenml.models import ( APIKeyFilter, APIKeyRequest, @@ -69,7 +68,7 @@ @handle_exceptions def create_service_account( service_account: ServiceAccountRequest, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> ServiceAccountResponse: """Creates a service account. @@ -93,7 +92,7 @@ def create_service_account( @handle_exceptions def get_service_account( service_account_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), hydrate: bool = True, ) -> ServiceAccountResponse: """Returns a specific service account. @@ -123,7 +122,7 @@ def list_service_accounts( make_dependable(ServiceAccountFilter) ), hydrate: bool = False, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[ServiceAccountResponse]: """Returns a list of service accounts. @@ -154,7 +153,7 @@ def list_service_accounts( def update_service_account( service_account_name_or_id: Union[str, UUID], service_account_update: ServiceAccountUpdate, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> ServiceAccountResponse: """Updates a specific service account. @@ -178,7 +177,7 @@ def update_service_account( @handle_exceptions def delete_service_account( service_account_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> None: """Delete a specific service account. @@ -202,7 +201,7 @@ def delete_service_account( def create_api_key( service_account_id: UUID, api_key: APIKeyRequest, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> APIKeyResponse: """Creates an API key for a service account. @@ -231,7 +230,7 @@ def get_api_key( service_account_id: UUID, api_key_name_or_id: Union[str, UUID], hydrate: bool = True, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> APIKeyResponse: """Returns the requested API key. @@ -263,7 +262,7 @@ def list_api_keys( service_account_id: UUID, filter_model: APIKeyFilter = Depends(make_dependable(APIKeyFilter)), hydrate: bool = False, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> Page[APIKeyResponse]: """List API keys associated with a service account. @@ -296,7 +295,7 @@ def update_api_key( service_account_id: UUID, api_key_name_or_id: Union[str, UUID], api_key_update: APIKeyUpdate, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> APIKeyResponse: """Updates an API key for a service account. @@ -329,7 +328,7 @@ def rotate_api_key( service_account_id: UUID, api_key_name_or_id: Union[str, UUID], rotate_request: APIKeyRotateRequest, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> APIKeyResponse: """Rotate an API key. @@ -357,7 +356,7 @@ def rotate_api_key( def delete_api_key( service_account_id: UUID, api_key_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Deletes an API key. diff --git a/src/zenml/zen_server/routers/tags_endpoints.py b/src/zenml/zen_server/routers/tags_endpoints.py index 80e492ea416..56814b3d9c1 100644 --- a/src/zenml/zen_server/routers/tags_endpoints.py +++ b/src/zenml/zen_server/routers/tags_endpoints.py @@ -23,7 +23,6 @@ TAGS, VERSION_1, ) -from zenml.enums import PermissionType from zenml.models import ( Page, TagFilterModel, @@ -58,7 +57,7 @@ @handle_exceptions def create_tag( tag: TagRequestModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> TagResponseModel: """Create a new tag. @@ -81,7 +80,7 @@ def list_tags( tag_filter_model: TagFilterModel = Depends( make_dependable(TagFilterModel) ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> Page[TagResponseModel]: """Get tags according to query filters. @@ -106,7 +105,7 @@ def list_tags( @handle_exceptions def get_tag( tag_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), + _: AuthContext = Security(authorize), ) -> TagResponseModel: """Get a tag by name or ID. @@ -128,7 +127,7 @@ def get_tag( def update_tag( tag_id: UUID, tag_update_model: TagUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> TagResponseModel: """Updates a tag. @@ -152,7 +151,7 @@ def update_tag( @handle_exceptions def delete_tag( tag_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> None: """Delete a tag by name or ID. From ead6262cec5397f96c17b0c927da02c6adc12e77 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 17 Nov 2023 11:07:06 +0100 Subject: [PATCH 044/103] More fixes and hopeless attempts to fix type checking --- src/zenml/cli/utils.py | 6 +- src/zenml/models/base_models.py | 1 - src/zenml/models/v2/base/filter.py | 21 ++- src/zenml/models/v2/base/internal.py | 5 +- src/zenml/zen_server/rbac/endpoint_utils.py | 29 ++-- src/zenml/zen_server/rbac/utils.py | 148 +++++++++++++----- .../7500f434b71c_remove_shared_columns.py | 2 +- 7 files changed, 147 insertions(+), 65 deletions(-) diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index 64ff25a45fa..8e47e00cba9 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -618,7 +618,7 @@ def print_stack_component_configuration( if component.user: user_name = component.user.name else: - user_name = "[DELETED]" + user_name = "-" declare( f"{component.type.value.title()} '{component.name}' of flavor " @@ -1481,7 +1481,7 @@ def print_stacks_table( if stack.user: user_name = stack.user.name else: - user_name = "[DELETED]" + user_name = "-" stack_config = { "ACTIVE": ":point_right:" if is_active else "", @@ -2249,7 +2249,7 @@ def print_pipeline_runs_table( if pipeline_run.user: user_name = pipeline_run.user.name else: - user_name = "[DELETED]" + user_name = "-" if pipeline_run.pipeline is None: pipeline_name = "unlisted" diff --git a/src/zenml/models/base_models.py b/src/zenml/models/base_models.py index 750cddf11dc..2a9f6143893 100644 --- a/src/zenml/models/base_models.py +++ b/src/zenml/models/base_models.py @@ -268,4 +268,3 @@ def update_model(_cls: Type[T]) -> Type[T]: value.allow_none = True return _cls - diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 00a8b755774..3f04e43bd90 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -767,15 +767,20 @@ def apply_filter( from sqlmodel import or_ if self._rbac_allowed_ids is not None: - if self._rbac_user_id and hasattr(table, "user"): - query = query.where( - or_( - table.id.in_(self._rbac_allowed_ids), # type: ignore[attr-defined] - getattr(table, "user") == self._rbac_user_id, + expressions = [table.id.in_(self._rbac_allowed_ids)] # type: ignore[attr-defined] + + if hasattr(table, "user_id"): + # Unowned entities are considered server-owned and can be seen + # by anyone + expressions.append(getattr(table, "user_id").is_(None)) + + if self._rbac_user_id: + # The authenticated user owns this entity + expressions.append( + getattr(table, "user_id") == self._rbac_user_id ) - ) - else: - query = query.where(table.id.in_(self._rbac_allowed_ids)) # type: ignore[attr-defined] + + query = query.where(or_(*expressions)) filters = self.generate_filter(table=table) diff --git a/src/zenml/models/v2/base/internal.py b/src/zenml/models/v2/base/internal.py index d5c8d63f702..b42547d1d07 100644 --- a/src/zenml/models/v2/base/internal.py +++ b/src/zenml/models/v2/base/internal.py @@ -13,12 +13,13 @@ # permissions and limitations under the License. """Utility methods for internal models.""" -from typing import TypeVar, Type -from zenml.models.v2.base.base import BaseRequest +from typing import Type, TypeVar +from zenml.models.v2.base.base import BaseRequest T = TypeVar("T", bound="BaseRequest") + def server_owned_request_model(_cls: Type[T]) -> Type[T]: """Convert a request model to a model which does not require a user ID. diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py index 030d68c9e5a..dc578e9ebd9 100644 --- a/src/zenml/zen_server/rbac/endpoint_utils.py +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -1,17 +1,22 @@ """High-level helper functions to write endpoints with RBAC.""" -from typing import Callable, TypeVar, Union +from typing import Any, Callable, TypeVar, Union from uuid import UUID from pydantic import BaseModel from zenml.exceptions import IllegalOperationError +from zenml.models import ( + BaseFilter, + BaseRequest, + BaseResponse, + Page, + UserScopedRequest, +) from zenml.models.base_models import ( BaseRequestModel, BaseResponseModel, UserScopedRequestModel, ) -from zenml.models.filter_models import BaseFilterModel -from zenml.models.page_model import Page from zenml.zen_server.auth import get_auth_context from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( @@ -22,9 +27,13 @@ verify_permission_for_model, ) -AnyRequestModel = TypeVar("AnyRequestModel", bound=BaseRequestModel) -AnyResponseModel = TypeVar("AnyResponseModel", bound=BaseResponseModel) -AnyFilterModel = TypeVar("AnyFilterModel", bound=BaseFilterModel) +AnyRequestModel = TypeVar( + "AnyRequestModel", bound=Union[BaseRequestModel, BaseRequest] +) +AnyResponseModel = TypeVar( + "AnyResponseModel", bound=Union[BaseResponseModel, BaseResponse] +) +AnyFilterModel = TypeVar("AnyFilterModel", bound=BaseFilter) AnyUpdateModel = TypeVar("AnyUpdateModel", bound=BaseModel) UUIDOrStr = TypeVar("UUIDOrStr", UUID, Union[UUID, str]) @@ -48,7 +57,7 @@ def verify_permissions_and_create_entity( Returns: A model of the created entity. """ - if isinstance(request_model, UserScopedRequestModel): + if isinstance(request_model, (UserScopedRequest, UserScopedRequestModel)): auth_context = get_auth_context() assert auth_context @@ -65,13 +74,14 @@ def verify_permissions_and_create_entity( def verify_permissions_and_get_entity( id: UUIDOrStr, get_method: Callable[[UUIDOrStr], AnyResponseModel], - **get_method_kwargs, + **get_method_kwargs: Any, ) -> AnyResponseModel: """Verify permissions and fetch an entity. Args: id: The ID of the entity to fetch. get_method: The method to fetch the entity. + get_method_kwargs: Keyword arguments to pass to the get method. Returns: A model of the fetched entity. @@ -85,7 +95,7 @@ def verify_permissions_and_list_entities( filter_model: AnyFilterModel, resource_type: ResourceType, list_method: Callable[[AnyFilterModel], Page[AnyResponseModel]], - **list_method_kwargs, + **list_method_kwargs: Any, ) -> Page[AnyResponseModel]: """Verify permissions and list entities. @@ -93,6 +103,7 @@ def verify_permissions_and_list_entities( filter_model: The entity filter model. resource_type: The resource type of the entities to list. list_method: The method to list the entities. + list_method_kwargs: Keyword arguments to pass to the list method. Returns: A page of entity models. diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 29299b0a979..d90dbe0b4aa 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -15,21 +15,55 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Sequence, Set, Type, TypeVar +from typing import ( + Any, + Dict, + List, + Optional, + Sequence, + Set, + Type, + TypeVar, + Union, +) from uuid import UUID from fastapi import HTTPException - -from zenml.models import Page +from pydantic import BaseModel + +from zenml.models import ( + BaseResponse, + BaseResponseBody, + BaseResponseMetadata, + ComponentResponse, + Page, + StackResponse, + UserScopedResponse, +) from zenml.models.base_models import BaseResponseModel, UserScopedResponseModel from zenml.zen_server.auth import get_auth_context from zenml.zen_server.rbac.models import Action, Resource, ResourceType from zenml.zen_server.utils import rbac, server_config -M = TypeVar("M", bound=BaseResponseModel) - - -def dehydrate_page(page: Page[M]) -> Page[M]: +AnyOldResponseModel = TypeVar("AnyOldResponseModel", bound=BaseResponseModel) +AnyNewResponseModel = TypeVar( + "AnyNewResponseModel", bound=Union[StackResponse, ComponentResponse] +) +AnyResponseModel = TypeVar( + "AnyResponseModel", + bound=Union[StackResponse, ComponentResponse, BaseResponseModel], +) + +AnyResponseBody = TypeVar("AnyResponseBody", bound=BaseResponseBody) +AnyResponseMetadata = TypeVar( + "AnyResponseMetadata", bound=BaseResponseMetadata +) +AnyModel = TypeVar("AnyModel", bound=BaseModel) + + +def dehydrate_page( + page: Page[AnyResponseModel], +) -> Page[AnyResponseModel]: """Dehydrate all items of a page. Args: @@ -59,8 +93,8 @@ def dehydrate_page(page: Page[M]) -> Page[M]: def dehydrate_response_model( - model: M, permissions: Optional[Dict[Resource, bool]] = None -) -> M: + model: AnyModel, permissions: Optional[Dict[Resource, bool]] = None +) -> AnyModel: """Dehydrate a model if necessary. Args: @@ -99,7 +133,7 @@ def _dehydrate_value( Returns: The recursively dehydrated value. """ - if isinstance(value, BaseResponseModel): + if isinstance(value, (BaseResponse, BaseResponseModel)): value = get_surrogate_permission_model_for_model( value, action=Action.READ ) @@ -112,6 +146,8 @@ def _dehydrate_value( return dehydrate_response_model(value, permissions=permissions) else: return get_permission_denied_model(value) + elif isinstance(value, BaseModel): + return dehydrate_response_model(value, permissions=permissions) elif isinstance(value, Dict): return { k: _dehydrate_value(v, permissions=permissions) @@ -126,7 +162,7 @@ def _dehydrate_value( return value -def has_permissions_for_model(model: "BaseResponseModel", action: str) -> bool: +def has_permissions_for_model(model: AnyResponseModel, action: str) -> bool: """If the active user has permissions to perform the action on the model. Args: @@ -146,9 +182,32 @@ def has_permissions_for_model(model: "BaseResponseModel", action: str) -> bool: return False -def get_permission_denied_model( - model: M, keep_id: bool = True, keep_name: bool = True -) -> M: +def get_permission_denied_model(model: AnyResponseModel) -> AnyResponseModel: + if isinstance(model, BaseResponse): + return get_permission_denied_model_v2(model) + else: + return get_permission_denied_model_v1(model) + + +def get_permission_denied_model_v2( + model: AnyNewResponseModel, +) -> AnyNewResponseModel: + """Get a model to return in case of missing read permissions. + + This function removes the body and metadata of the model. + + Args: + model: The original model. + + Returns: + The model with body and metadata removed. + """ + return model.copy(exclude={"body", "metadata"}) + + +def get_permission_denied_model_v1( + model: AnyOldResponseModel, keep_id: bool = True, keep_name: bool = True +) -> AnyOldResponseModel: """Get a model to return in case of missing read permissions. This function replaces all attributes except name and ID in the given model. @@ -173,9 +232,11 @@ def get_permission_denied_model( elif field.allow_none: value = None elif isinstance(value, BaseResponseModel): - value = get_permission_denied_model( + value = get_permission_denied_model_v1( value, keep_id=False, keep_name=False ) + elif isinstance(value, BaseResponse): + value = get_permission_denied_model_v2(value) elif isinstance(value, UUID): value = UUID(int=0) elif isinstance(value, datetime): @@ -197,7 +258,7 @@ def get_permission_denied_model( def batch_verify_permissions_for_models( - models: Sequence["BaseResponseModel"], + models: Sequence[AnyResponseModel], action: str, ) -> None: """Batch permission verification for models. @@ -224,7 +285,7 @@ def batch_verify_permissions_for_models( def verify_permission_for_model( - model: "BaseResponseModel", + model: AnyResponseModel, action: str, ) -> None: """Verifies if a user has permission to perform an action on a model. @@ -329,7 +390,7 @@ def get_allowed_resource_ids( return {UUID(id) for id in allowed_ids} -def get_resource_for_model(model: "BaseResponseModel") -> Optional[Resource]: +def get_resource_for_model(model: AnyResponseModel) -> Optional[Resource]: """Get the resource associated with a model object. Args: @@ -348,8 +409,8 @@ def get_resource_for_model(model: "BaseResponseModel") -> Optional[Resource]: def get_surrogate_permission_model_for_model( - model: "BaseResponseModel", action: str -) -> "BaseResponseModel": + model: AnyResponseModel, action: str +) -> Union[BaseResponse[Any, Any], BaseResponseModel]: """Get a surrogate permission model for a model. In some cases a different model instead of the original model is used to @@ -365,7 +426,7 @@ def get_surrogate_permission_model_for_model( """ from zenml.models import ModelVersionResponseModel - if action == Action.READ == isinstance(model, ModelVersionResponseModel): + if action == Action.READ and isinstance(model, ModelVersionResponseModel): # Permissions to read a model version is the same as reading the model return model.model @@ -373,7 +434,7 @@ def get_surrogate_permission_model_for_model( def get_resource_type_for_model( - model: "BaseResponseModel", + model: AnyResponseModel, ) -> Optional[ResourceType]: """Get the resource type associated with a model object. @@ -385,33 +446,36 @@ def get_resource_type_for_model( is not associated with any resource type. """ from zenml.models import ( - ArtifactResponseModel, - CodeRepositoryResponseModel, - ComponentResponseModel, - FlavorResponseModel, + ArtifactResponse, + CodeRepositoryResponse, + ComponentResponse, + FlavorResponse, ModelResponseModel, - PipelineResponseModel, + PipelineResponse, SecretResponseModel, - ServiceConnectorResponseModel, - StackResponseModel, + ServiceConnectorResponse, + StackResponse, ) - mapping: Dict[Type[BaseResponseModel], ResourceType] = { - FlavorResponseModel: ResourceType.FLAVOR, - ServiceConnectorResponseModel: ResourceType.SERVICE_CONNECTOR, - ComponentResponseModel: ResourceType.STACK_COMPONENT, - StackResponseModel: ResourceType.STACK, - PipelineResponseModel: ResourceType.PIPELINE, - CodeRepositoryResponseModel: ResourceType.CODE_REPOSITORY, + mapping: Dict[ + Union[Type[BaseResponseModel], Type[BaseResponse[Any, Any]]], + ResourceType, + ] = { + FlavorResponse: ResourceType.FLAVOR, + ServiceConnectorResponse: ResourceType.SERVICE_CONNECTOR, + ComponentResponse: ResourceType.STACK_COMPONENT, + StackResponse: ResourceType.STACK, + PipelineResponse: ResourceType.PIPELINE, + CodeRepositoryResponse: ResourceType.CODE_REPOSITORY, SecretResponseModel: ResourceType.SECRET, ModelResponseModel: ResourceType.MODEL, - ArtifactResponseModel: ResourceType.ARTIFACT, + ArtifactResponse: ResourceType.ARTIFACT, } return mapping.get(type(model)) -def is_owned_by_authenticated_user(model: "BaseResponseModel") -> bool: +def is_owned_by_authenticated_user(model: AnyResponseModel) -> bool: """Returns whether the currently authenticated user owns the model. Args: @@ -423,7 +487,7 @@ def is_owned_by_authenticated_user(model: "BaseResponseModel") -> bool: auth_context = get_auth_context() assert auth_context - if isinstance(model, UserScopedResponseModel): + if isinstance(model, (UserScopedResponseModel, UserScopedResponse)): if model.user: return model.user.id == auth_context.user.id else: @@ -435,7 +499,7 @@ def is_owned_by_authenticated_user(model: "BaseResponseModel") -> bool: def get_subresources_for_model( - model: "BaseResponseModel", + model: AnyModel, ) -> Set[Resource]: """Get all subresources of a model which need permission verification. @@ -463,7 +527,7 @@ def _get_subresources_for_value(value: Any) -> Set[Resource]: Returns: All resources of the value which need permission verification. """ - if isinstance(value, BaseResponseModel): + if isinstance(value, (BaseResponse, BaseResponseModel)): resources = set() if not is_owned_by_authenticated_user(value): value = get_surrogate_permission_model_for_model( @@ -473,6 +537,8 @@ def _get_subresources_for_value(value: Any) -> Set[Resource]: resources.add(resource) return resources.union(get_subresources_for_model(value)) + elif isinstance(value, BaseModel): + return get_subresources_for_model(value) elif isinstance(value, Dict): resources_list = [ _get_subresources_for_value(v) for v in value.values() diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py index a02d50385d5..93313facbb7 100644 --- a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -14,7 +14,7 @@ # revision identifiers, used by Alembic. revision = "7500f434b71c" -down_revision = "0.45.5" +down_revision = "0.47.0" branch_labels = None depends_on = None From 5e23c8fab61b2c2e3dc7566eb67db4d53f5a538e Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 17 Nov 2023 14:27:02 +0100 Subject: [PATCH 045/103] Some ugly workarounds to partially make typechecking work --- src/zenml/zen_server/rbac/endpoint_utils.py | 2 +- src/zenml/zen_server/rbac/utils.py | 45 ++++++++++++--------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py index dc578e9ebd9..3a0958952bb 100644 --- a/src/zenml/zen_server/rbac/endpoint_utils.py +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -31,7 +31,7 @@ "AnyRequestModel", bound=Union[BaseRequestModel, BaseRequest] ) AnyResponseModel = TypeVar( - "AnyResponseModel", bound=Union[BaseResponseModel, BaseResponse] + "AnyResponseModel", bound=Union[BaseResponseModel, BaseResponse] # type: ignore[type-arg] ) AnyFilterModel = TypeVar("AnyFilterModel", bound=BaseFilter) AnyUpdateModel = TypeVar("AnyUpdateModel", bound=BaseModel) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index d90dbe0b4aa..8d21379c90b 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -22,9 +22,9 @@ Optional, Sequence, Set, - Type, TypeVar, Union, + cast, ) from uuid import UUID @@ -33,11 +33,7 @@ from zenml.models import ( BaseResponse, - BaseResponseBody, - BaseResponseMetadata, - ComponentResponse, Page, - StackResponse, UserScopedResponse, ) from zenml.models.base_models import BaseResponseModel, UserScopedResponseModel @@ -47,16 +43,11 @@ AnyOldResponseModel = TypeVar("AnyOldResponseModel", bound=BaseResponseModel) AnyNewResponseModel = TypeVar( - "AnyNewResponseModel", bound=Union[StackResponse, ComponentResponse] + "AnyNewResponseModel", bound=BaseResponse # type: ignore[type-arg] ) AnyResponseModel = TypeVar( "AnyResponseModel", - bound=Union[StackResponse, ComponentResponse, BaseResponseModel], -) - -AnyResponseBody = TypeVar("AnyResponseBody", bound=BaseResponseBody) -AnyResponseMetadata = TypeVar( - "AnyResponseMetadata", bound=BaseResponseMetadata + bound=Union[BaseResponse, BaseResponseModel], # type: ignore[type-arg] ) AnyModel = TypeVar("AnyModel", bound=BaseModel) @@ -183,16 +174,28 @@ def has_permissions_for_model(model: AnyResponseModel, action: str) -> bool: def get_permission_denied_model(model: AnyResponseModel) -> AnyResponseModel: + """Get a model to return in case of missing read permissions. + + Args: + model: The original model. + + Returns: + The permission denied model. + """ + if isinstance(model, BaseResponse): - return get_permission_denied_model_v2(model) + return cast(AnyResponseModel, get_permission_denied_model_v2(model)) else: - return get_permission_denied_model_v1(model) + return cast( + AnyResponseModel, + get_permission_denied_model_v1(cast(BaseResponseModel, model)), + ) def get_permission_denied_model_v2( model: AnyNewResponseModel, ) -> AnyNewResponseModel: - """Get a model to return in case of missing read permissions. + """Get a V2 model to return in case of missing read permissions. This function removes the body and metadata of the model. @@ -208,7 +211,7 @@ def get_permission_denied_model_v2( def get_permission_denied_model_v1( model: AnyOldResponseModel, keep_id: bool = True, keep_name: bool = True ) -> AnyOldResponseModel: - """Get a model to return in case of missing read permissions. + """Get a V1 model to return in case of missing read permissions. This function replaces all attributes except name and ID in the given model. @@ -276,9 +279,11 @@ def batch_verify_permissions_for_models( # The model owner always has permissions continue - model = get_surrogate_permission_model_for_model(model, action=action) + permission_model = get_surrogate_permission_model_for_model( + model, action=action + ) - if resource := get_resource_for_model(model): + if resource := get_resource_for_model(permission_model): resources.add(resource) batch_verify_permissions(resources=resources, action=action) @@ -410,7 +415,7 @@ def get_resource_for_model(model: AnyResponseModel) -> Optional[Resource]: def get_surrogate_permission_model_for_model( model: AnyResponseModel, action: str -) -> Union[BaseResponse[Any, Any], BaseResponseModel]: +) -> Union[BaseResponse, BaseResponseModel]: """Get a surrogate permission model for a model. In some cases a different model instead of the original model is used to @@ -458,7 +463,7 @@ def get_resource_type_for_model( ) mapping: Dict[ - Union[Type[BaseResponseModel], Type[BaseResponse[Any, Any]]], + Any, ResourceType, ] = { FlavorResponse: ResourceType.FLAVOR, From 55b5889d8a58f0d4681e5bc6c8e0ea78bf21bf3b Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 17 Nov 2023 14:34:19 +0100 Subject: [PATCH 046/103] Remove role from log message --- src/zenml/cli/server.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/zenml/cli/server.py b/src/zenml/cli/server.py index d613442eb18..571f43b8fc9 100644 --- a/src/zenml/cli/server.py +++ b/src/zenml/cli/server.py @@ -781,12 +781,10 @@ def connect( store_config = store_config_class.parse_obj(store_dict) try: GlobalConfiguration().set_store(store_config) - except IllegalOperationError as e: + except IllegalOperationError: cli_utils.warning( f"User '{username}' does not have sufficient permissions to " - f"to access the server at '{url}'. Please ask the server " - f"administrator to assign a role with permissions to your " - f"username: {str(e)}" + f"to access the server at '{url}'." ) if workspace: From b1829e7e576982f7ca97d7f4b6db1c5f2708577e Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 17 Nov 2023 15:14:06 +0100 Subject: [PATCH 047/103] Some adjustments for new model endpoints --- src/zenml/models/model_models.py | 21 +++++++++++++------ .../routers/workspaces_endpoints.py | 13 ++++-------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 2c12bc9defc..956e03f9fdf 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -130,13 +130,22 @@ def apply_filter( query = super().apply_filter(query=query, table=table) if self._rbac_allowed_model_ids is not None: - if self._rbac_user_id and hasattr(table, "user"): - query = query.where( - or_( - getattr(table, "model_id").in_(self._rbac_allowed_model_ids), # type: ignore[attr-defined] - getattr(table, "user") == self._rbac_user_id, + expressions = [ + getattr(table, "model_id").in_(self._rbac_allowed_model_ids) + ] + + if hasattr(table, "user_id"): + # Unowned entities are considered server-owned and can be seen + # by anyone + expressions.append(getattr(table, "user_id").is_(None)) + + if self._rbac_user_id: + # The authenticated user owns this entity + expressions.append( + getattr(table, "user_id") == self._rbac_user_id ) - ) + + query = query.where(or_(*expressions)) if self._model_id: query = query.where(getattr(table, "model_id") == self._model_id) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 7e28e20947a..4d4f054fc62 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1265,7 +1265,7 @@ def create_model_version( @handle_exceptions def create_model_version_artifact_link( workspace_name_or_id: Union[str, UUID], - model_version_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link: ModelVersionArtifactRequestModel, auth_context: AuthContext = Security(authorize), ) -> ModelVersionArtifactResponseModel: @@ -1273,7 +1273,7 @@ def create_model_version_artifact_link( Args: workspace_name_or_id: Name or ID of the workspace. - model_version_id: Name or ID of the model version. + model_version_id: ID of the model version. model_version_artifact_link: The model version to artifact link to create. auth_context: Authentication context. @@ -1287,7 +1287,6 @@ def create_model_version_artifact_link( """ workspace = zen_store().get_workspace(workspace_name_or_id) if str(model_version_id) != str(model_version_artifact_link.model_version): - breakpoint() raise IllegalOperationError( f"The model version id in your path `{model_version_id}` does not " f"match the model version specified in the request model " @@ -1306,9 +1305,7 @@ def create_model_version_artifact_link( "is not supported." ) - model_version = zen_store().get_model_version( - model_name_or_id, model_version_name_or_id - ) + model_version = zen_store().get_model_version(model_version_id) verify_permission_for_model(model_version, action=Action.UPDATE) mv = zen_store().create_model_version_artifact_link( @@ -1409,9 +1406,7 @@ def create_model_version_pipeline_run_link( "is not supported." ) - model_version = zen_store().get_model_version( - model_name_or_id, model_version_name_or_id - ) + model_version = zen_store().get_model_version(model_version_id) verify_permission_for_model(model_version, action=Action.UPDATE) mv = zen_store().create_model_version_pipeline_run_link( From 094c1f6293034bdf64342401cebb46669a3e271a Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 17 Nov 2023 15:34:30 +0100 Subject: [PATCH 048/103] Fix migrations --- .../7500f434b71c_remove_shared_columns.py | 4 +- .../86fa52918b54_remove_teams_and_roles.py | 135 ++++++++++++++++++ 2 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 src/zenml/zen_stores/migrations/versions/86fa52918b54_remove_teams_and_roles.py diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py index 93313facbb7..dfdaadcbd8d 100644 --- a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -1,7 +1,7 @@ """Remove shared columns [7500f434b71c]. Revision ID: 7500f434b71c -Revises: 0.45.4 +Revises: 14d687c8fa1c Create Date: 2023-10-16 15:15:34.865337 """ @@ -14,7 +14,7 @@ # revision identifiers, used by Alembic. revision = "7500f434b71c" -down_revision = "0.47.0" +down_revision = "14d687c8fa1c" branch_labels = None depends_on = None diff --git a/src/zenml/zen_stores/migrations/versions/86fa52918b54_remove_teams_and_roles.py b/src/zenml/zen_stores/migrations/versions/86fa52918b54_remove_teams_and_roles.py new file mode 100644 index 00000000000..34f4373c93e --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/86fa52918b54_remove_teams_and_roles.py @@ -0,0 +1,135 @@ +"""Remove teams and roles [86fa52918b54]. + +Revision ID: 86fa52918b54 +Revises: 7500f434b71c +Create Date: 2023-11-17 15:33:56.501617 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "86fa52918b54" +down_revision = "7500f434b71c" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("role_permission") + op.drop_table("role") + op.drop_table("team_role_assignment") + op.drop_table("team") + op.drop_table("team_assignment") + op.drop_table("user_role_assignment") + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "user_role_assignment", + sa.Column("id", sa.CHAR(length=32), nullable=False), + sa.Column("role_id", sa.CHAR(length=32), nullable=False), + sa.Column("user_id", sa.CHAR(length=32), nullable=False), + sa.Column("workspace_id", sa.CHAR(length=32), nullable=True), + sa.Column("created", sa.DATETIME(), nullable=False), + sa.Column("updated", sa.DATETIME(), nullable=False), + sa.ForeignKeyConstraint( + ["role_id"], + ["role.id"], + name="fk_user_role_assignment_role_id_role", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_user_role_assignment_user_id_user", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name="fk_user_role_assignment_workspace_id_workspace", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "team_assignment", + sa.Column("user_id", sa.CHAR(length=32), nullable=False), + sa.Column("team_id", sa.CHAR(length=32), nullable=False), + sa.ForeignKeyConstraint( + ["team_id"], + ["team.id"], + name="fk_team_assignment_team_id_team", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_team_assignment_user_id_user", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("user_id", "team_id"), + ) + op.create_table( + "team", + sa.Column("id", sa.CHAR(length=32), nullable=False), + sa.Column("name", sa.VARCHAR(), nullable=False), + sa.Column("created", sa.DATETIME(), nullable=False), + sa.Column("updated", sa.DATETIME(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "team_role_assignment", + sa.Column("id", sa.CHAR(length=32), nullable=False), + sa.Column("role_id", sa.CHAR(length=32), nullable=False), + sa.Column("team_id", sa.CHAR(length=32), nullable=False), + sa.Column("workspace_id", sa.CHAR(length=32), nullable=True), + sa.Column("created", sa.DATETIME(), nullable=False), + sa.Column("updated", sa.DATETIME(), nullable=False), + sa.ForeignKeyConstraint( + ["role_id"], + ["role.id"], + name="fk_team_role_assignment_role_id_role", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["team_id"], + ["team.id"], + name="fk_team_role_assignment_team_id_team", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name="fk_team_role_assignment_workspace_id_workspace", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "role", + sa.Column("id", sa.CHAR(length=32), nullable=False), + sa.Column("name", sa.VARCHAR(), nullable=False), + sa.Column("created", sa.DATETIME(), nullable=False), + sa.Column("updated", sa.DATETIME(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "role_permission", + sa.Column("name", sa.VARCHAR(), nullable=False), + sa.Column("role_id", sa.CHAR(length=32), nullable=False), + sa.ForeignKeyConstraint( + ["role_id"], + ["role.id"], + name="fk_role_permission_role_id_role", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("name", "role_id"), + ) + # ### end Alembic commands ### From 446b76bb2aaa87728ff3efaf46524831dda560f6 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 17 Nov 2023 15:38:45 +0100 Subject: [PATCH 049/103] More mypy fixes --- src/zenml/zen_server/rbac/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 8d21379c90b..00b94de519e 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -415,7 +415,7 @@ def get_resource_for_model(model: AnyResponseModel) -> Optional[Resource]: def get_surrogate_permission_model_for_model( model: AnyResponseModel, action: str -) -> Union[BaseResponse, BaseResponseModel]: +) -> Union[BaseResponse[Any, Any], BaseResponseModel]: """Get a surrogate permission model for a model. In some cases a different model instead of the original model is used to From 7faefed1b63f3978b439e84e07fce67af0f47472 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 17 Nov 2023 15:48:38 +0100 Subject: [PATCH 050/103] Some docstrings --- src/zenml/zen_server/routers/model_versions_endpoints.py | 1 + src/zenml/zen_server/routers/models_endpoints.py | 1 + src/zenml/zen_server/routers/workspaces_endpoints.py | 4 ++-- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py index 1ba2b54e544..bf6cbec355b 100644 --- a/src/zenml/zen_server/routers/model_versions_endpoints.py +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -79,6 +79,7 @@ def list_model_versions( Args: model_version_filter_model: Filter model used for pagination, sorting, filtering + auth_context: The authentication context. Returns: The model versions according to query filters. diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 769ef31651c..248e04ad3f3 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -189,6 +189,7 @@ def list_model_versions( model_name_or_id: The name or ID of the model to list in. model_version_filter_model: Filter model used for pagination, sorting, filtering + auth_context: The authentication context. Returns: The model versions according to query filters. diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 4d4f054fc62..c904c99f6c4 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -281,7 +281,7 @@ def list_workspace_stacks( def create_stack( workspace_name_or_id: Union[str, UUID], stack: StackRequest, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> StackResponse: """Creates a stack for a particular workspace. @@ -370,7 +370,7 @@ def list_workspace_stack_components( def create_stack_component( workspace_name_or_id: Union[str, UUID], component: ComponentRequest, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> ComponentResponse: """Creates a stack component. From 05d28d8c8ecbc76972ac090606856e9af8647167 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 20 Nov 2023 12:41:55 +0100 Subject: [PATCH 051/103] Permissions for workspaces, service accounts and tags --- src/zenml/client.py | 123 ++++++------- src/zenml/zen_server/rbac/models.py | 3 + .../routers/service_accounts_endpoints.py | 54 ++++-- .../zen_server/routers/tags_endpoints.py | 40 ++++- .../zen_server/routers/users_endpoints.py | 2 +- .../routers/workspaces_endpoints.py | 169 ++++++++++-------- src/zenml/zen_stores/rest_zen_store.py | 88 ++++----- src/zenml/zen_stores/zen_store_interface.py | 68 +++---- .../functional/zen_stores/utils.py | 8 - 9 files changed, 314 insertions(+), 241 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index a02ad06bbae..66e22e996cd 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -64,6 +64,7 @@ from zenml.exceptions import ( AuthorizationException, EntityExistsError, + IllegalOperationError, InitializationException, ValidationError, ZenKeyError, @@ -149,7 +150,9 @@ UserResponse, UserUpdate, WorkspaceFilter, + WorkspaceRequest, WorkspaceResponse, + WorkspaceUpdate, ) from zenml.utils import io_utils, source_utils from zenml.utils.filesync_model import FileSyncModel @@ -802,21 +805,21 @@ def active_user(self) -> "UserResponse": # -------------------------------- Workspaces ------------------------------ - # def create_workspace( - # self, name: str, description: str - # ) -> WorkspaceResponse: - # """Create a new workspace. + def create_workspace( + self, name: str, description: str + ) -> WorkspaceResponse: + """Create a new workspace. - # Args: - # name: Name of the workspace. - # description: Description of the workspace. + Args: + name: Name of the workspace. + description: Description of the workspace. - # Returns: - # The created workspace. - # """ - # return self.zen_store.create_workspace( - # WorkspaceRequest(name=name, description=description) - # ) + Returns: + The created workspace. + """ + return self.zen_store.create_workspace( + WorkspaceRequest(name=name, description=description) + ) def get_workspace( self, @@ -880,53 +883,53 @@ def list_workspaces( ) ) - # def update_workspace( - # self, - # name_id_or_prefix: Optional[Union[UUID, str]], - # new_name: Optional[str] = None, - # new_description: Optional[str] = None, - # ) -> WorkspaceResponse: - # """Update a workspace. - - # Args: - # name_id_or_prefix: Name, ID or prefix of the workspace to update. - # new_name: New name of the workspace. - # new_description: New description of the workspace. - - # Returns: - # The updated workspace. - # """ - # workspace = self.get_workspace( - # name_id_or_prefix=name_id_or_prefix, allow_name_prefix_match=False - # ) - # workspace_update = WorkspaceUpdate(name=new_name or workspace.name) - # if new_description: - # workspace_update.description = new_description - # return self.zen_store.update_workspace( - # workspace_id=workspace.id, - # workspace_update=workspace_update, - # ) - - # def delete_workspace(self, name_id_or_prefix: str) -> None: - # """Delete a workspace. - - # Args: - # name_id_or_prefix: The name or ID of the workspace to delete. - - # Raises: - # IllegalOperationError: If the workspace to delete is the active - # workspace. - # """ - # workspace = self.get_workspace( - # name_id_or_prefix, allow_name_prefix_match=False - # ) - # if self.active_workspace.id == workspace.id: - # raise IllegalOperationError( - # f"Workspace '{name_id_or_prefix}' cannot be deleted since " - # "it is currently active. Please set another workspace as " - # "active first." - # ) - # self.zen_store.delete_workspace(workspace_name_or_id=workspace.id) + def update_workspace( + self, + name_id_or_prefix: Optional[Union[UUID, str]], + new_name: Optional[str] = None, + new_description: Optional[str] = None, + ) -> WorkspaceResponse: + """Update a workspace. + + Args: + name_id_or_prefix: Name, ID or prefix of the workspace to update. + new_name: New name of the workspace. + new_description: New description of the workspace. + + Returns: + The updated workspace. + """ + workspace = self.get_workspace( + name_id_or_prefix=name_id_or_prefix, allow_name_prefix_match=False + ) + workspace_update = WorkspaceUpdate(name=new_name or workspace.name) + if new_description: + workspace_update.description = new_description + return self.zen_store.update_workspace( + workspace_id=workspace.id, + workspace_update=workspace_update, + ) + + def delete_workspace(self, name_id_or_prefix: str) -> None: + """Delete a workspace. + + Args: + name_id_or_prefix: The name or ID of the workspace to delete. + + Raises: + IllegalOperationError: If the workspace to delete is the active + workspace. + """ + workspace = self.get_workspace( + name_id_or_prefix, allow_name_prefix_match=False + ) + if self.active_workspace.id == workspace.id: + raise IllegalOperationError( + f"Workspace '{name_id_or_prefix}' cannot be deleted since " + "it is currently active. Please set another workspace as " + "active first." + ) + self.zen_store.delete_workspace(workspace_name_or_id=workspace.id) @property def active_workspace(self) -> WorkspaceResponse: diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index baf48687cc1..fd983545832 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -50,6 +50,9 @@ class ResourceType(StrEnum): SERVICE_CONNECTOR = "service_connector" ARTIFACT = "artifact" SECRET = "secret" + TAG = "tag" + SERVICE_ACCOUNT = "service_account" + WORKSPACE = "workspace" class Resource(BaseModel): diff --git a/src/zenml/zen_server/routers/service_accounts_endpoints.py b/src/zenml/zen_server/routers/service_accounts_endpoints.py index 5d75e05fcd5..5fabea257c4 100644 --- a/src/zenml/zen_server/routers/service_accounts_endpoints.py +++ b/src/zenml/zen_server/routers/service_accounts_endpoints.py @@ -39,6 +39,15 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_create_entity, + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import verify_permission_for_model from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -78,10 +87,11 @@ def create_service_account( Returns: The created service account. """ - new_service_account = zen_store().create_service_account( - service_account=service_account + return verify_permissions_and_create_entity( + request_model=service_account, + resource_type=ResourceType.SERVICE_ACCOUNT, + create_method=zen_store().create_service_account, ) - return new_service_account @router.get( @@ -105,8 +115,9 @@ def get_service_account( Returns: The service account matching the given name or ID. """ - return zen_store().get_service_account( - service_account_name_or_id=service_account_name_or_id, + return verify_permissions_and_get_entity( + id=service_account_name_or_id, + get_method=zen_store().get_service_account, hydrate=hydrate, ) @@ -135,8 +146,11 @@ def list_service_accounts( Returns: A list of service accounts matching the filter. """ - return zen_store().list_service_accounts( - filter_model=filter_model, hydrate=hydrate + return verify_permissions_and_list_entities( + filter_model=filter_model, + resource_type=ResourceType.SERVICE_ACCOUNT, + list_method=zen_store().list_service_accounts, + hydrate=hydrate, ) @@ -164,9 +178,11 @@ def update_service_account( Returns: The updated service account. """ - return zen_store().update_service_account( - service_account_name_or_id=service_account_name_or_id, - service_account_update=service_account_update, + return verify_permissions_and_update_entity( + id=service_account_name_or_id, + update_model=service_account_update, + get_method=zen_store().get_service_account, + update_method=zen_store().update_service_account, ) @@ -184,7 +200,11 @@ def delete_service_account( Args: service_account_name_or_id: Name or ID of the service account. """ - zen_store().delete_service_account(service_account_name_or_id) + return verify_permissions_and_delete_entity( + id=service_account_name_or_id, + get_method=zen_store().get_service_account, + delete_method=zen_store().delete_service_account, + ) # -------- @@ -213,6 +233,8 @@ def create_api_key( Returns: The created API key. """ + service_account = zen_store().get_service_account(service_account_id) + verify_permission_for_model(service_account, action=Action.UPDATE) created_api_key = zen_store().create_api_key( service_account_id=service_account_id, api_key=api_key, @@ -244,6 +266,8 @@ def get_api_key( Returns: The requested API key. """ + service_account = zen_store().get_service_account(service_account_id) + verify_permission_for_model(service_account, action=Action.READ) api_key = zen_store().get_api_key( service_account_id=service_account_id, api_key_name_or_id=api_key_name_or_id, @@ -278,6 +302,8 @@ def list_api_keys( All API keys matching the filter and associated with the supplied service account. """ + service_account = zen_store().get_service_account(service_account_id) + verify_permission_for_model(service_account, action=Action.READ) return zen_store().list_api_keys( service_account_id=service_account_id, filter_model=filter_model, @@ -308,6 +334,8 @@ def update_api_key( Returns: The updated API key. """ + service_account = zen_store().get_service_account(service_account_id) + verify_permission_for_model(service_account, action=Action.UPDATE) return zen_store().update_api_key( service_account_id=service_account_id, api_key_name_or_id=api_key_name_or_id, @@ -341,6 +369,8 @@ def rotate_api_key( Returns: The updated API key. """ + service_account = zen_store().get_service_account(service_account_id) + verify_permission_for_model(service_account, action=Action.UPDATE) return zen_store().rotate_api_key( service_account_id=service_account_id, api_key_name_or_id=api_key_name_or_id, @@ -365,6 +395,8 @@ def delete_api_key( belongs. api_key_name_or_id: Name or ID of the API key to delete. """ + service_account = zen_store().get_service_account(service_account_id) + verify_permission_for_model(service_account, action=Action.UPDATE) zen_store().delete_api_key( service_account_id=service_account_id, api_key_name_or_id=api_key_name_or_id, diff --git a/src/zenml/zen_server/routers/tags_endpoints.py b/src/zenml/zen_server/routers/tags_endpoints.py index 56814b3d9c1..d20111146d9 100644 --- a/src/zenml/zen_server/routers/tags_endpoints.py +++ b/src/zenml/zen_server/routers/tags_endpoints.py @@ -32,6 +32,14 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_create_entity, + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import ResourceType from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -45,7 +53,7 @@ router = APIRouter( prefix=API + VERSION_1 + TAGS, tags=["tags"], - responses={401: error_response}, + responses={401: error_response, 403: error_response}, ) @@ -67,7 +75,11 @@ def create_tag( Returns: The created tag. """ - return zen_store().create_tag(tag) + return verify_permissions_and_create_entity( + request_model=tag, + resource_type=ResourceType.TAG, + create_method=zen_store().create_tag, + ) @router.get( @@ -92,8 +104,10 @@ def list_tags( Returns: The tags according to query filters. """ - return zen_store().list_tags( - tag_filter_model=tag_filter_model, + return verify_permissions_and_list_entities( + filter_model=tag_filter_model, + resource_type=ResourceType.TAG, + list_method=zen_store().list_tags, ) @@ -115,7 +129,9 @@ def get_tag( Returns: The tag with the given name or ID. """ - return zen_store().get_tag(tag_name_or_id) + return verify_permissions_and_get_entity( + id=tag_name_or_id, get_method=zen_store().get_tag + ) @router.put( @@ -138,9 +154,11 @@ def update_tag( Returns: The updated tag. """ - return zen_store().update_tag( - tag_name_or_id=tag_id, - tag_update_model=tag_update_model, + return verify_permissions_and_update_entity( + id=tag_id, + update_model=tag_update_model, + get_method=zen_store().get_tag, + update_method=zen_store().update_tag, ) @@ -158,4 +176,8 @@ def delete_tag( Args: tag_name_or_id: The name or ID of the tag to delete. """ - zen_store().delete_tag(tag_name_or_id) + return verify_permissions_and_delete_entity( + id=tag_name_or_id, + get_method=zen_store().get_tag, + delete_method=zen_store().delete_tag, + ) diff --git a/src/zenml/zen_server/routers/users_endpoints.py b/src/zenml/zen_server/routers/users_endpoints.py index 5e43065a366..97233db3a2d 100644 --- a/src/zenml/zen_server/routers/users_endpoints.py +++ b/src/zenml/zen_server/routers/users_endpoints.py @@ -332,7 +332,7 @@ def email_opt_in_response( ) else: raise AuthorizationException( - "Users can not opt in on behalf of another " "user." + "Users can not opt in on behalf of another user." ) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index c904c99f6c4..42dfb9d6c5d 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -85,13 +85,18 @@ StackRequest, StackResponse, WorkspaceFilter, + WorkspaceRequest, WorkspaceResponse, + WorkspaceUpdate, ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response from zenml.zen_server.rbac.endpoint_utils import ( verify_permissions_and_create_entity, + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, verify_permissions_and_list_entities, + verify_permissions_and_update_entity, ) from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( @@ -135,32 +140,38 @@ def list_workspaces( Returns: A list of workspaces. """ - return zen_store().list_workspaces( - workspace_filter_model=workspace_filter_model, hydrate=hydrate + return verify_permissions_and_list_entities( + filter_model=workspace_filter_model, + resource_type=ResourceType.WORKSPACE, + list_method=zen_store().list_workspaces, + hydrate=hydrate, ) -# @router.post( -# WORKSPACES, -# response_model=WorkspaceResponseModel, -# responses={401: error_response, 409: error_response, 422: error_response}, -# ) -# @handle_exceptions -# def create_workspace( -# workspace: WorkspaceRequest, -# _: AuthContext = Security(authorize), -# ) -> WorkspaceResponse: -# """Creates a workspace based on the requestBody. +@router.post( + WORKSPACES, + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def create_workspace( + workspace: WorkspaceRequest, + _: AuthContext = Security(authorize), +) -> WorkspaceResponse: + """Creates a workspace based on the requestBody. -# # noqa: DAR401 + # noqa: DAR401 -# Args: -# workspace: Workspace to create. + Args: + workspace: Workspace to create. -# Returns: -# The created workspace. -# """ -# return zen_store().create_workspace(workspace=workspace) + Returns: + The created workspace. + """ + return verify_permissions_and_create_entity( + request_model=workspace, + resource_type=ResourceType.WORKSPACE, + create_method=zen_store().create_workspace, + ) @router.get( @@ -186,53 +197,61 @@ def get_workspace( Returns: The requested workspace. """ - return zen_store().get_workspace( - workspace_name_or_id=workspace_name_or_id, hydrate=hydrate + return verify_permissions_and_get_entity( + id=workspace_name_or_id, + get_method=zen_store().get_workspace, + hydrate=hydrate, ) -# @router.put( -# WORKSPACES + "/{workspace_name_or_id}", -# responses={401: error_response, 404: error_response, 422: error_response}, -# ) -# @handle_exceptions -# def update_workspace( -# workspace_name_or_id: UUID, -# workspace_update: WorkspaceUpdate, -# _: AuthContext = Security(authorize), -# ) -> WorkspaceResponse: -# """Get a workspace for given name. - -# # noqa: DAR401 - -# Args: -# workspace_name_or_id: Name or ID of the workspace to update. -# workspace_update: the workspace to use to update - -# Returns: -# The updated workspace. -# """ -# return zen_store().update_workspace( -# workspace_id=workspace_name_or_id, -# workspace_update=workspace_update, -# ) - - -# @router.delete( -# WORKSPACES + "/{workspace_name_or_id}", -# responses={401: error_response, 404: error_response, 422: error_response}, -# ) -# @handle_exceptions -# def delete_workspace( -# workspace_name_or_id: Union[str, UUID], -# _: AuthContext = Security(authorize), -# ) -> None: -# """Deletes a workspace. - -# Args: -# workspace_name_or_id: Name or ID of the workspace. -# """ -# zen_store().delete_workspace(workspace_name_or_id=workspace_name_or_id) +@router.put( + WORKSPACES + "/{workspace_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def update_workspace( + workspace_name_or_id: UUID, + workspace_update: WorkspaceUpdate, + _: AuthContext = Security(authorize), +) -> WorkspaceResponse: + """Get a workspace for given name. + + # noqa: DAR401 + + Args: + workspace_name_or_id: Name or ID of the workspace to update. + workspace_update: the workspace to use to update + + Returns: + The updated workspace. + """ + return verify_permissions_and_update_entity( + id=workspace_name_or_id, + update_model=workspace_update, + get_method=zen_store().get_workspace, + update_method=zen_store().update_model, + ) + + +@router.delete( + WORKSPACES + "/{workspace_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_workspace( + workspace_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize), +) -> None: + """Deletes a workspace. + + Args: + workspace_name_or_id: Name or ID of the workspace. + """ + return verify_permissions_and_delete_entity( + id=workspace_name_or_id, + get_method=zen_store().get_workspace, + delete_method=zen_store().delete_workspace, + ) @router.get( @@ -333,7 +352,7 @@ def list_workspace_stack_components( make_dependable(ComponentFilter) ), hydrate: bool = False, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> Page[ComponentResponse]: """List stack components that are part of a specific workspace. @@ -345,7 +364,6 @@ def list_workspace_stack_components( filtering. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. - auth_context: Authentication Context. Returns: All stack components part of the specified workspace. @@ -868,21 +886,20 @@ def create_run_metadata( def create_secret( workspace_name_or_id: Union[str, UUID], secret: SecretRequestModel, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> SecretResponseModel: """Creates a secret. Args: workspace_name_or_id: Name or ID of the workspace. secret: Secret to create. - auth_context: Authentication context. Returns: The created secret. Raises: - IllegalOperationError: If the workspace or user specified in the - secret does not match the current workspace or authenticated user. + IllegalOperationError: If the workspace specified in the + secret does not match the current workspace. """ workspace = zen_store().get_workspace(workspace_name_or_id) @@ -892,12 +909,12 @@ def create_secret( f"of this endpoint `{workspace_name_or_id}` is " f"not supported." ) - if secret.user != auth_context.user.id: - raise IllegalOperationError( - "Creating secrets for a user other than yourself " - "is not supported." - ) - return zen_store().create_secret(secret=secret) + + return verify_permissions_and_create_entity( + request_model=secret, + resource_type=ResourceType.SECRET, + create_method=zen_store().create_secret, + ) @router.get( diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 5a3a93ffd7b..571b63bbbf2 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -185,9 +185,11 @@ UserResponse, UserUpdate, WorkspaceFilter, + WorkspaceRequest, WorkspaceResponse, WorkspaceScopedRequest, WorkspaceScopedRequestModel, + WorkspaceUpdate, ) from zenml.service_connectors.service_connector_registry import ( service_connector_registry, @@ -2396,22 +2398,22 @@ def delete_user(self, user_name_or_id: Union[str, UUID]) -> None: # ----------------------------- Workspaces ----------------------------- - # def create_workspace( - # self, workspace: WorkspaceRequest - # ) -> WorkspaceResponse: - # """Creates a new workspace. + def create_workspace( + self, workspace: WorkspaceRequest + ) -> WorkspaceResponse: + """Creates a new workspace. - # Args: - # workspace: The workspace to create. + Args: + workspace: The workspace to create. - # Returns: - # The newly created workspace. - # """ - # return self._create_resource( - # resource=workspace, - # route=WORKSPACES, - # response_model=WorkspaceResponse, - # ) + Returns: + The newly created workspace. + """ + return self._create_resource( + resource=workspace, + route=WORKSPACES, + response_model=WorkspaceResponse, + ) def get_workspace( self, workspace_name_or_id: Union[UUID, str], hydrate: bool = True @@ -2456,35 +2458,35 @@ def list_workspaces( params={"hydrate": hydrate}, ) - # def update_workspace( - # self, workspace_id: UUID, workspace_update: WorkspaceUpdate - # ) -> WorkspaceResponse: - # """Update an existing workspace. - - # Args: - # workspace_id: The ID of the workspace to be updated. - # workspace_update: The update to be applied to the workspace. - - # Returns: - # The updated workspace. - # """ - # return self._update_resource( - # resource_id=workspace_id, - # resource_update=workspace_update, - # route=WORKSPACES, - # response_model=WorkspaceResponse, - # ) - - # def delete_workspace(self, workspace_name_or_id: Union[str, UUID]) -> None: - # """Deletes a workspace. - - # Args: - # workspace_name_or_id: Name or ID of the workspace to delete. - # """ - # self._delete_resource( - # resource_id=workspace_name_or_id, - # route=WORKSPACES, - # ) + def update_workspace( + self, workspace_id: UUID, workspace_update: WorkspaceUpdate + ) -> WorkspaceResponse: + """Update an existing workspace. + + Args: + workspace_id: The ID of the workspace to be updated. + workspace_update: The update to be applied to the workspace. + + Returns: + The updated workspace. + """ + return self._update_resource( + resource_id=workspace_id, + resource_update=workspace_update, + route=WORKSPACES, + response_model=WorkspaceResponse, + ) + + def delete_workspace(self, workspace_name_or_id: Union[str, UUID]) -> None: + """Deletes a workspace. + + Args: + workspace_name_or_id: Name or ID of the workspace to delete. + """ + self._delete_resource( + resource_id=workspace_name_or_id, + route=WORKSPACES, + ) ######### # Model diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 8daa32b77db..b5f74f12647 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -107,7 +107,9 @@ UserResponse, UserUpdate, WorkspaceFilter, + WorkspaceRequest, WorkspaceResponse, + WorkspaceUpdate, ) @@ -1768,21 +1770,21 @@ def delete_user(self, user_name_or_id: Union[str, UUID]) -> None: # -------------------- Workspaces -------------------- - # @abstractmethod - # def create_workspace( - # self, workspace: WorkspaceRequest - # ) -> WorkspaceResponse: - # """Creates a new workspace. + @abstractmethod + def create_workspace( + self, workspace: WorkspaceRequest + ) -> WorkspaceResponse: + """Creates a new workspace. - # Args: - # workspace: The workspace to create. + Args: + workspace: The workspace to create. - # Returns: - # The newly created workspace. + Returns: + The newly created workspace. - # Raises: - # EntityExistsError: If a workspace with the given name already exists. - # """ + Raises: + EntityExistsError: If a workspace with the given name already exists. + """ @abstractmethod def get_workspace( @@ -1820,33 +1822,33 @@ def list_workspaces( A list of all workspace matching the filter criteria. """ - # @abstractmethod - # def update_workspace( - # self, workspace_id: UUID, workspace_update: WorkspaceUpdate - # ) -> WorkspaceResponse: - # """Update an existing workspace. + @abstractmethod + def update_workspace( + self, workspace_id: UUID, workspace_update: WorkspaceUpdate + ) -> WorkspaceResponse: + """Update an existing workspace. - # Args: - # workspace_id: The ID of the workspace to be updated. - # workspace_update: The update to be applied to the workspace. + Args: + workspace_id: The ID of the workspace to be updated. + workspace_update: The update to be applied to the workspace. - # Returns: - # The updated workspace. + Returns: + The updated workspace. - # Raises: - # KeyError: if the workspace does not exist. - # """ + Raises: + KeyError: if the workspace does not exist. + """ - # @abstractmethod - # def delete_workspace(self, workspace_name_or_id: Union[str, UUID]) -> None: - # """Deletes a workspace. + @abstractmethod + def delete_workspace(self, workspace_name_or_id: Union[str, UUID]) -> None: + """Deletes a workspace. - # Args: - # workspace_name_or_id: Name or ID of the workspace to delete. + Args: + workspace_name_or_id: Name or ID of the workspace to delete. - # Raises: - # KeyError: If no workspace with the given name exists. - # """ + Raises: + KeyError: If no workspace with the given name exists. + """ # -------------------- Model -------------------- diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index c2a58eec8af..b4ea16ec7ad 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -187,10 +187,6 @@ def __enter__(self): name=self.user_name, password=self.password, active=True ) self.created_user = self.store.create_user(new_user) - self.client.create_user_role_assignment( - role_name_or_id="admin", - user_name_or_id=self.created_user.id, - ) else: self.created_user = self.store.get_user(self.user_name) @@ -253,10 +249,6 @@ def __enter__(self): self.created_service_account = self.store.create_service_account( new_account ) - self.client.create_user_role_assignment( - role_name_or_id="admin", - user_name_or_id=self.created_service_account.id, - ) else: self.created_service_account = self.store.get_service_account( self.name From 1e5f0807cdeb9300bf756999951f58ac8fd2f2da Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 20 Nov 2023 14:53:13 +0100 Subject: [PATCH 052/103] Early exit if rbac disabled --- src/zenml/zen_server/rbac/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 00b94de519e..ad43dff910a 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -98,6 +98,9 @@ def dehydrate_response_model( Returns: The (potentially) dehydrated model. """ + if not server_config().rbac_enabled: + return model + dehydrated_fields = {} for field_name in model.__fields__.keys(): From c0412a41ae3000f23cd04b0ba6f672a810c328e0 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 20 Nov 2023 14:53:44 +0100 Subject: [PATCH 053/103] Run test only with rest --- tests/integration/functional/zen_stores/test_zen_store.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 2a0cfc27947..8d80903fb09 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2750,6 +2750,8 @@ def test_connector_name_reuse_for_same_user_fails(): def test_connector_name_reuse_for_different_user_fails(): """Tests that a connector's name cannot be re-used by another user.""" + if Client().zen_store.type == StoreType.SQL: + pytest.skip("SQL Zen Stores do not support user switching.") with ServiceConnectorContext( connector_type="cat'o'matic", From 2b9bebcbbfdbafdaa0f45af2be26003217c76bbb Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 20 Nov 2023 15:25:26 +0100 Subject: [PATCH 054/103] Fix sql pagination --- src/zenml/zen_server/routers/workspaces_endpoints.py | 2 +- src/zenml/zen_stores/sql_zen_store.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 42dfb9d6c5d..be95df0114d 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -229,7 +229,7 @@ def update_workspace( id=workspace_name_or_id, update_model=workspace_update, get_method=zen_store().get_workspace, - update_method=zen_store().update_model, + update_method=zen_store().update_workspace, ) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index f28548454ac..92e83c57ca0 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -840,7 +840,7 @@ def filter_and_paginate( "since it does not have a `to_model` method." ) - return Page[B]( + return Page[Any]( total=total, total_pages=total_pages, items=items, From 438e7182a212474f1796e353875c84d68b0d2742 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 20 Nov 2023 15:44:34 +0100 Subject: [PATCH 055/103] Small adjustments --- src/zenml/models/v2/base/base.py | 1 + src/zenml/zen_server/rbac/utils.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/zenml/models/v2/base/base.py b/src/zenml/models/v2/base/base.py index 6204f5349d8..33ec7cc2691 100644 --- a/src/zenml/models/v2/base/base.py +++ b/src/zenml/models/v2/base/base.py @@ -102,6 +102,7 @@ class BaseResponse(GenericModel, Generic[AnyBody, AnyMetadata], BaseZenModel): """Base domain model.""" id: UUID = Field(title="The unique resource id.") + permission_denied: bool = False # Body and metadata pair body: Optional["AnyBody"] = Field(title="The body of the resource.") diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index ad43dff910a..7c39002dbb3 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -101,6 +101,15 @@ def dehydrate_response_model( if not server_config().rbac_enabled: return model + if not permissions: + auth_context = get_auth_context() + assert auth_context + + resources = get_subresources_for_model(model) + permissions = rbac().check_permissions( + user=auth_context.user, resources=resources, action=Action.READ + ) + dehydrated_fields = {} for field_name in model.__fields__.keys(): @@ -208,7 +217,9 @@ def get_permission_denied_model_v2( Returns: The model with body and metadata removed. """ - return model.copy(exclude={"body", "metadata"}) + return model.copy( + exclude={"body", "metadata"}, update={"permission_denied": True} + ) def get_permission_denied_model_v1( From ebda193c4685ded774aca649f3683334034de264 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 20 Nov 2023 17:12:10 +0100 Subject: [PATCH 056/103] Fix service accounts --- src/zenml/zen_server/rbac/zenml_cloud_rbac.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py index 606d6ff2b03..f03584456c2 100644 --- a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -150,6 +150,10 @@ def check_permissions( # No need to send a request if there are no resources return {} + if user.is_service_account: + # Service accounts have full permissions for now + return {resource: True for resource in resources} + params = { "user_id": str(user.external_user_id), "resources": [ @@ -183,6 +187,11 @@ def list_allowed_resource_ids( """ assert not resource.id assert user.external_user_id + + if user.is_service_account: + # Service accounts have full permissions for now + return True, [] + params = { "user_id": str(user.external_user_id), "resource": _convert_to_cloud_resource(resource), From 5047b91262e18c4bd16b7062110853c82d0b88f9 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 20 Nov 2023 17:17:11 +0100 Subject: [PATCH 057/103] Reorder migration table drops --- .../86fa52918b54_remove_teams_and_roles.py | 70 +++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/src/zenml/zen_stores/migrations/versions/86fa52918b54_remove_teams_and_roles.py b/src/zenml/zen_stores/migrations/versions/86fa52918b54_remove_teams_and_roles.py index 34f4373c93e..be285c6af52 100644 --- a/src/zenml/zen_stores/migrations/versions/86fa52918b54_remove_teams_and_roles.py +++ b/src/zenml/zen_stores/migrations/versions/86fa52918b54_remove_teams_and_roles.py @@ -19,11 +19,11 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" # ### commands auto generated by Alembic - please adjust! ### op.drop_table("role_permission") - op.drop_table("role") op.drop_table("team_role_assignment") - op.drop_table("team") - op.drop_table("team_assignment") op.drop_table("user_role_assignment") + op.drop_table("team_assignment") + op.drop_table("team") + op.drop_table("role") # ### end Alembic commands ### @@ -31,31 +31,19 @@ def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" # ### commands auto generated by Alembic - please adjust! ### op.create_table( - "user_role_assignment", + "role", sa.Column("id", sa.CHAR(length=32), nullable=False), - sa.Column("role_id", sa.CHAR(length=32), nullable=False), - sa.Column("user_id", sa.CHAR(length=32), nullable=False), - sa.Column("workspace_id", sa.CHAR(length=32), nullable=True), + sa.Column("name", sa.VARCHAR(), nullable=False), + sa.Column("created", sa.DATETIME(), nullable=False), + sa.Column("updated", sa.DATETIME(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "team", + sa.Column("id", sa.CHAR(length=32), nullable=False), + sa.Column("name", sa.VARCHAR(), nullable=False), sa.Column("created", sa.DATETIME(), nullable=False), sa.Column("updated", sa.DATETIME(), nullable=False), - sa.ForeignKeyConstraint( - ["role_id"], - ["role.id"], - name="fk_user_role_assignment_role_id_role", - ondelete="CASCADE", - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["user.id"], - name="fk_user_role_assignment_user_id_user", - ondelete="CASCADE", - ), - sa.ForeignKeyConstraint( - ["workspace_id"], - ["workspace.id"], - name="fk_user_role_assignment_workspace_id_workspace", - ondelete="CASCADE", - ), sa.PrimaryKeyConstraint("id"), ) op.create_table( @@ -77,11 +65,31 @@ def downgrade() -> None: sa.PrimaryKeyConstraint("user_id", "team_id"), ) op.create_table( - "team", + "user_role_assignment", sa.Column("id", sa.CHAR(length=32), nullable=False), - sa.Column("name", sa.VARCHAR(), nullable=False), + sa.Column("role_id", sa.CHAR(length=32), nullable=False), + sa.Column("user_id", sa.CHAR(length=32), nullable=False), + sa.Column("workspace_id", sa.CHAR(length=32), nullable=True), sa.Column("created", sa.DATETIME(), nullable=False), sa.Column("updated", sa.DATETIME(), nullable=False), + sa.ForeignKeyConstraint( + ["role_id"], + ["role.id"], + name="fk_user_role_assignment_role_id_role", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_user_role_assignment_user_id_user", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name="fk_user_role_assignment_workspace_id_workspace", + ondelete="CASCADE", + ), sa.PrimaryKeyConstraint("id"), ) op.create_table( @@ -112,14 +120,6 @@ def downgrade() -> None: ), sa.PrimaryKeyConstraint("id"), ) - op.create_table( - "role", - sa.Column("id", sa.CHAR(length=32), nullable=False), - sa.Column("name", sa.VARCHAR(), nullable=False), - sa.Column("created", sa.DATETIME(), nullable=False), - sa.Column("updated", sa.DATETIME(), nullable=False), - sa.PrimaryKeyConstraint("id"), - ) op.create_table( "role_permission", sa.Column("name", sa.VARCHAR(), nullable=False), From ad4f2bfc825672cc6a8553fe8a42614c84fd10e2 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 20 Nov 2023 17:40:11 +0100 Subject: [PATCH 058/103] Some more endpoints fixed --- src/zenml/zen_server/rbac/models.py | 3 + .../routers/pipeline_builds_endpoints.py | 23 ++++++-- .../routers/pipeline_deployments_endpoints.py | 25 +++++++-- .../zen_server/routers/runs_endpoints.py | 51 ++++++++++++++--- .../zen_server/routers/steps_endpoints.py | 39 +++++++++++-- .../routers/workspaces_endpoints.py | 55 +++++++++++-------- 6 files changed, 151 insertions(+), 45 deletions(-) diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index fd983545832..6c0d20f05de 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -53,6 +53,9 @@ class ResourceType(StrEnum): TAG = "tag" SERVICE_ACCOUNT = "service_account" WORKSPACE = "workspace" + PIPELINE_RUN = "pipeline_run" + PIPELINE_DEPLOYMENT = "pipeline_deployment" + PIPELINE_BUILD = "pipeline_build" class Resource(BaseModel): diff --git a/src/zenml/zen_server/routers/pipeline_builds_endpoints.py b/src/zenml/zen_server/routers/pipeline_builds_endpoints.py index 7e70de2405f..36f014ca832 100644 --- a/src/zenml/zen_server/routers/pipeline_builds_endpoints.py +++ b/src/zenml/zen_server/routers/pipeline_builds_endpoints.py @@ -20,6 +20,12 @@ from zenml.models import Page, PipelineBuildFilter, PipelineBuildResponse from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, +) +from zenml.zen_server.rbac.models import ResourceType from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -57,8 +63,11 @@ def list_builds( Returns: List of build objects. """ - return zen_store().list_builds( - build_filter_model=build_filter_model, hydrate=hydrate + return verify_permissions_and_list_entities( + filter_model=build_filter_model, + resource_type=ResourceType.PIPELINE_BUILD, + list_method=zen_store().list_builds, + hydrate=hydrate, ) @@ -83,7 +92,9 @@ def get_build( Returns: A specific build object. """ - return zen_store().get_build(build_id=build_id, hydrate=hydrate) + return verify_permissions_and_get_entity( + id=build_id, get_method=zen_store().get_build, hydrate=hydrate + ) @router.delete( @@ -100,4 +111,8 @@ def delete_build( Args: build_id: ID of the build to delete. """ - zen_store().delete_build(build_id=build_id) + verify_permissions_and_delete_entity( + id=build_id, + get_method=zen_store().get_build, + delete_method=zen_store().delete_build, + ) diff --git a/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py b/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py index 9f882dbc52b..66caed6c994 100644 --- a/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py +++ b/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py @@ -24,6 +24,12 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, +) +from zenml.zen_server.rbac.models import ResourceType from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -61,8 +67,11 @@ def list_deployments( Returns: List of deployment objects. """ - return zen_store().list_deployments( - deployment_filter_model=deployment_filter_model, hydrate=hydrate + return verify_permissions_and_list_entities( + filter_model=deployment_filter_model, + resource_type=ResourceType.PIPELINE_DEPLOYMENT, + list_method=zen_store().list_deployments, + hydrate=hydrate, ) @@ -87,8 +96,10 @@ def get_deployment( Returns: A specific deployment object. """ - return zen_store().get_deployment( - deployment_id=deployment_id, hydrate=hydrate + return verify_permissions_and_get_entity( + id=deployment_id, + get_method=zen_store().get_deployment, + hydrate=hydrate, ) @@ -106,4 +117,8 @@ def delete_deployment( Args: deployment_id: ID of the deployment to delete. """ - zen_store().delete_deployment(deployment_id=deployment_id) + verify_permissions_and_delete_entity( + id=deployment_id, + get_method=zen_store().get_deployment, + delete_method=zen_store().delete_deployment, + ) diff --git a/src/zenml/zen_server/routers/runs_endpoints.py b/src/zenml/zen_server/routers/runs_endpoints.py index 7692cb81f44..d2493286eac 100644 --- a/src/zenml/zen_server/routers/runs_endpoints.py +++ b/src/zenml/zen_server/routers/runs_endpoints.py @@ -38,6 +38,13 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import ResourceType from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -74,8 +81,11 @@ def list_runs( Returns: The pipeline runs according to query filters. """ - return zen_store().list_runs( - runs_filter_model=runs_filter_model, hydrate=hydrate + return verify_permissions_and_list_entities( + filter_model=runs_filter_model, + resource_type=ResourceType.PIPELINE_RUN, + list_method=zen_store().list_runs, + hydrate=hydrate, ) @@ -100,7 +110,9 @@ def get_run( Returns: The pipeline run. """ - return zen_store().get_run(run_name_or_id=run_id, hydrate=hydrate) + return verify_permissions_and_get_entity( + id=run_id, get_method=zen_store().get_run, hydrate=hydrate + ) @router.put( @@ -123,7 +135,12 @@ def update_run( Returns: The updated run model. """ - return zen_store().update_run(run_id=run_id, run_update=run_model) + return verify_permissions_and_update_entity( + id=run_id, + update_model=run_model, + get_method=zen_store().get_run, + update_method=zen_store().update_run, + ) @router.delete( @@ -140,7 +157,11 @@ def delete_run( Args: run_id: ID of the run. """ - zen_store().delete_run(run_id=run_id) + verify_permissions_and_delete_entity( + id=run_id, + get_method=zen_store().get_run, + delete_method=zen_store().delete_run, + ) @router.get( @@ -161,7 +182,9 @@ def get_run_dag( Returns: The DAG for a given pipeline run. """ - run = zen_store().get_run(run_name_or_id=run_id) + run = verify_permissions_and_get_entity( + id=run_id, get_method=zen_store().get_run, hydrate=True + ) graph = LineageGraph() graph.generate_run_nodes_and_edges(run) return graph @@ -174,6 +197,7 @@ def get_run_dag( ) @handle_exceptions def get_run_steps( + run_id: UUID, step_run_filter_model: StepRunFilter = Depends( make_dependable(StepRunFilter) ), @@ -182,12 +206,17 @@ def get_run_steps( """Get all steps for a given pipeline run. Args: + run_id: ID of the pipeline run. step_run_filter_model: Filter model used for pagination, sorting, filtering Returns: The steps for a given pipeline run. """ + verify_permissions_and_get_entity( + id=run_id, get_method=zen_store().get_run, hydrate=False + ) + step_run_filter_model.pipeline_run_id = run_id return zen_store().list_run_steps(step_run_filter_model) @@ -209,7 +238,10 @@ def get_pipeline_configuration( Returns: The pipeline configuration of the pipeline run. """ - return zen_store().get_run(run_name_or_id=run_id).config.dict() + run = verify_permissions_and_get_entity( + id=run_id, get_method=zen_store().get_run, hydrate=True + ) + return run.config.dict() @router.get( @@ -230,4 +262,7 @@ def get_run_status( Returns: The status of the pipeline run. """ - return zen_store().get_run(run_id).status + run = verify_permissions_and_get_entity( + id=run_id, get_method=zen_store().get_run, hydrate=False + ) + return run.status diff --git a/src/zenml/zen_server/routers/steps_endpoints.py b/src/zenml/zen_server/routers/steps_endpoints.py index 59ab254c614..3105a5875af 100644 --- a/src/zenml/zen_server/routers/steps_endpoints.py +++ b/src/zenml/zen_server/routers/steps_endpoints.py @@ -40,6 +40,11 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.models import Action +from zenml.zen_server.rbac.utils import ( + dehydrate_response_model, + verify_permission_for_model, +) from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -100,6 +105,9 @@ def create_run_step( Returns: The created run step. """ + pipeline_run = zen_store().get_run(step.pipeline_run_id) + verify_permission_for_model(pipeline_run, action=Action.UPDATE) + return zen_store().create_run_step(step_run=step) @@ -124,7 +132,11 @@ def get_step( Returns: The step. """ - return zen_store().get_run_step(step_id, hydrate=hydrate) + step = zen_store().get_run_step(step_id, hydrate=hydrate) + pipeline_run = zen_store().get_run(step.pipeline_run_id) + verify_permission_for_model(pipeline_run, action=Action.READ) + + return dehydrate_response_model(step) @router.put( @@ -147,9 +159,14 @@ def update_step( Returns: The updated step model. """ - return zen_store().update_run_step( + step = zen_store().get_run_step(step_id, hydrate=True) + pipeline_run = zen_store().get_run(step.pipeline_run_id) + verify_permission_for_model(pipeline_run, action=Action.UPDATE) + + updated_step = zen_store().update_run_step( step_run_id=step_id, step_run_update=step_model ) + return dehydrate_response_model(updated_step) @router.get( @@ -170,7 +187,11 @@ def get_step_configuration( Returns: The step configuration. """ - return zen_store().get_run_step(step_id).config.dict() + step = zen_store().get_run_step(step_id, hydrate=True) + pipeline_run = zen_store().get_run(step.pipeline_run_id) + verify_permission_for_model(pipeline_run, action=Action.READ) + + return step.config.dict() @router.get( @@ -191,7 +212,11 @@ def get_step_status( Returns: The status of the step. """ - return zen_store().get_run_step(step_id).status + step = zen_store().get_run_step(step_id, hydrate=True) + pipeline_run = zen_store().get_run(step.pipeline_run_id) + verify_permission_for_model(pipeline_run, action=Action.READ) + + return step.status @router.get( @@ -215,8 +240,12 @@ def get_step_logs( Raises: HTTPException: If no logs are available for this step. """ + step = zen_store().get_run_step(step_id, hydrate=True) + pipeline_run = zen_store().get_run(step.pipeline_run_id) + verify_permission_for_model(pipeline_run, action=Action.READ) + store = zen_store() - logs = store.get_run_step(step_id).logs + logs = step.logs if logs is None: raise HTTPException( status_code=404, detail="No logs available for this step" diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index be95df0114d..dd75187bba8 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -545,8 +545,12 @@ def list_workspace_builds( """ workspace = zen_store().get_workspace(workspace_name_or_id) build_filter_model.set_scope_workspace(workspace.id) - return zen_store().list_builds( - build_filter_model=build_filter_model, hydrate=hydrate + + return verify_permissions_and_list_entities( + filter_model=build_filter_model, + resource_type=ResourceType.PIPELINE_BUILD, + list_method=zen_store().list_builds, + hydrate=hydrate, ) @@ -572,8 +576,8 @@ def create_build( The created build. Raises: - IllegalOperationError: If the workspace or user specified in the build - does not match the current workspace or authenticated user. + IllegalOperationError: If the workspace specified in the build + does not match the current workspace. """ workspace = zen_store().get_workspace(workspace_name_or_id) @@ -583,13 +587,12 @@ def create_build( f"of this endpoint `{workspace_name_or_id}` is " f"not supported." ) - if build.user != auth_context.user.id: - raise IllegalOperationError( - "Creating builds for a user other than yourself " - "is not supported." - ) - return zen_store().create_build(build=build) + return verify_permissions_and_create_entity( + request_model=build, + resource_type=ResourceType.PIPELINE_BUILD, + create_method=zen_store().create_build, + ) @router.get( @@ -622,8 +625,12 @@ def list_workspace_deployments( """ workspace = zen_store().get_workspace(workspace_name_or_id) deployment_filter_model.set_scope_workspace(workspace.id) - return zen_store().list_deployments( - deployment_filter_model=deployment_filter_model, hydrate=hydrate + + return verify_permissions_and_list_entities( + filter_model=deployment_filter_model, + resource_type=ResourceType.PIPELINE_DEPLOYMENT, + list_method=zen_store().list_deployments, + hydrate=hydrate, ) @@ -649,9 +656,8 @@ def create_deployment( The created deployment. Raises: - IllegalOperationError: If the workspace or user specified in the - deployment does not match the current workspace or authenticated - user. + IllegalOperationError: If the workspace specified in the + deployment does not match the current workspace. """ workspace = zen_store().get_workspace(workspace_name_or_id) @@ -661,13 +667,12 @@ def create_deployment( f"of this endpoint `{workspace_name_or_id}` is " f"not supported." ) - if deployment.user != auth_context.user.id: - raise IllegalOperationError( - "Creating deployments for a user other than yourself " - "is not supported." - ) - return zen_store().create_deployment(deployment=deployment) + return verify_permissions_and_create_entity( + request_model=deployment, + resource_type=ResourceType.PIPELINE_DEPLOYMENT, + create_method=zen_store().create_deployment, + ) @router.get( @@ -698,8 +703,12 @@ def list_runs( """ workspace = zen_store().get_workspace(workspace_name_or_id) runs_filter_model.set_scope_workspace(workspace.id) - return zen_store().list_runs( - runs_filter_model=runs_filter_model, hydrate=hydrate + + return verify_permissions_and_list_entities( + filter_model=runs_filter_model, + resource_type=ResourceType.PIPELINE_RUN, + list_method=zen_store().list_runs, + hydrate=hydrate, ) From 1c19a6efdd2b9ae9c7934afa5cdb11df57f4c92e Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 21 Nov 2023 09:18:01 +0100 Subject: [PATCH 059/103] Generalize rbac on filter models --- src/zenml/models/model_models.py | 38 ------------ src/zenml/models/v2/base/filter.py | 60 ++++++++++++------- src/zenml/zen_server/rbac/endpoint_utils.py | 4 +- .../routers/model_versions_endpoints.py | 5 +- .../zen_server/routers/models_endpoints.py | 4 +- 5 files changed, 44 insertions(+), 67 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 956e03f9fdf..3e76a963ac0 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -21,7 +21,6 @@ Dict, List, Optional, - Set, Type, TypeVar, Union, @@ -78,23 +77,6 @@ class ModelScopedFilterModel(WorkspaceScopedFilter): """Base filter model inside Model Scope.""" _model_id: UUID = PrivateAttr(None) - _rbac_allowed_model_ids: Optional[Set[UUID]] = None - - def set_rbac_allowed_model_ids_and_user( - self, allowed_model_ids: Optional[Set[UUID]], user_id: Optional[UUID] - ) -> None: - """Set allowed model IDs and user ID for the query. - - Args: - allowed_model_ids: Set of IDs to limit the query to. If given, the - remaining filters will be applied to entities within this set - only. If `None`, the remaining filters will applied to all - entries in the table. - user_id: ID of the authenticated user. If given, all entities owned - by this user will be included in addition to the `allowed_ids`. - """ - self._rbac_allowed_model_ids = allowed_model_ids - self._rbac_user_id = user_id def set_scope_model(self, model_name_or_id: Union[str, UUID]) -> None: """Set the model to scope this response. @@ -125,28 +107,8 @@ def apply_filter( Returns: The query with filter applied. """ - from sqlmodel import or_ - query = super().apply_filter(query=query, table=table) - if self._rbac_allowed_model_ids is not None: - expressions = [ - getattr(table, "model_id").in_(self._rbac_allowed_model_ids) - ] - - if hasattr(table, "user_id"): - # Unowned entities are considered server-owned and can be seen - # by anyone - expressions.append(getattr(table, "user_id").is_(None)) - - if self._rbac_user_id: - # The authenticated user owns this entity - expressions.append( - getattr(table, "user_id") == self._rbac_user_id - ) - - query = query.where(or_(*expressions)) - if self._model_id: query = query.where(getattr(table, "model_id") == self._model_id) diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 3f04e43bd90..e5e61450ee3 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -290,8 +290,9 @@ class BaseFilter(BaseModel): default=None, description="Updated" ) - _rbac_allowed_ids: Optional[Set[UUID]] = None - _rbac_user_id: Optional[UUID] = None + _rbac_configuration: Optional[ + Tuple[UUID, Dict[str, Optional[Set[UUID]]]] + ] = None @validator("sort_by", pre=True) def validate_sort_by(cls, v: str) -> str: @@ -384,8 +385,10 @@ def sorting_params(self) -> Tuple[str, SorterOps]: return column, operator - def set_rbac_allowed_ids_and_user( - self, allowed_ids: Optional[Set[UUID]], user_id: Optional[UUID] + def configure_rbac( + self, + authenticated_user_id: UUID, + **column_allowed_ids: Optional[Set[UUID]], ) -> None: """Set allowed IDs and user ID for the query. @@ -397,8 +400,35 @@ def set_rbac_allowed_ids_and_user( user_id: ID of the authenticated user. If given, all entities owned by this user will be included in addition to the `allowed_ids`. """ - self._rbac_allowed_ids = allowed_ids - self._rbac_user_id = user_id + self._rbac_configuration = (authenticated_user_id, column_allowed_ids) + + def apply_rbac_filter( + self, + query: Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"], + table: Type["AnySchema"], + ) -> Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"]: + from sqlmodel import or_ + + if not self._rbac_configuration: + return query + + expressions = [] + + for column_name, allowed_ids in self._rbac_configuration[1].items(): + if allowed_ids is not None: + expression = getattr(table, column_name).in_(allowed_ids) + expressions.append(expression) + + if hasattr(table, "user_id"): + # Unowned entities are considered server-owned and can be seen + # by anyone + expressions.append(getattr(table, "user_id").is_(None)) + # The authenticated user owns this entity + expressions.append( + getattr(table, "user_id") == self._rbac_configuration[0] + ) + + return query.where(or_(False, *expressions)) @classmethod def _generate_filter_list(cls, values: Dict[str, Any]) -> List[Filter]: @@ -764,23 +794,7 @@ def apply_filter( Returns: The query with filter applied. """ - from sqlmodel import or_ - - if self._rbac_allowed_ids is not None: - expressions = [table.id.in_(self._rbac_allowed_ids)] # type: ignore[attr-defined] - - if hasattr(table, "user_id"): - # Unowned entities are considered server-owned and can be seen - # by anyone - expressions.append(getattr(table, "user_id").is_(None)) - - if self._rbac_user_id: - # The authenticated user owns this entity - expressions.append( - getattr(table, "user_id") == self._rbac_user_id - ) - - query = query.where(or_(*expressions)) + query = self.apply_rbac_filter(query) filters = self.generate_filter(table=table) diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py index 3a0958952bb..912ba55f24b 100644 --- a/src/zenml/zen_server/rbac/endpoint_utils.py +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -112,8 +112,8 @@ def verify_permissions_and_list_entities( assert auth_context allowed_ids = get_allowed_resource_ids(resource_type=resource_type) - filter_model.set_rbac_allowed_ids_and_user( - allowed_ids=allowed_ids, user_id=auth_context.user.id + filter_model.configure_rbac( + authenticated_user_id=auth_context.user.id, id=allowed_ids ) page = list_method(filter_model, **list_method_kwargs) return dehydrate_page(page) diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py index bf6cbec355b..c7c9885769e 100644 --- a/src/zenml/zen_server/routers/model_versions_endpoints.py +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -87,8 +87,9 @@ def list_model_versions( allowed_model_ids = get_allowed_resource_ids( resource_type=ResourceType.MODEL ) - model_version_filter_model.set_rbac_allowed_model_ids_and_user( - allowed_model_ids=allowed_model_ids, user_id=auth_context.user.id + model_version_filter_model.configure_rbac( + authenticated_user_id=auth_context.user.id, + model_id=allowed_model_ids, ) model_versions = zen_store().list_model_versions( diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 248e04ad3f3..a1a15c81851 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -197,8 +197,8 @@ def list_model_versions( allowed_model_ids = get_allowed_resource_ids( resource_type=ResourceType.MODEL ) - model_version_filter_model.set_rbac_allowed_model_ids_and_user( - allowed_model_ids=allowed_model_ids, user_id=auth_context.user.id + model_version_filter_model.configure_rbac( + authenticated_user_id=auth_context.user.id, model_id=allowed_model_ids ) model_versions = zen_store().list_model_versions( From 00bd06f3cdefec9e49399e2e41496ad93efc7fe0 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 21 Nov 2023 10:07:42 +0100 Subject: [PATCH 060/103] Add rbac check to service connector resources endpoint --- src/zenml/client.py | 1 - src/zenml/models/v2/core/service_connector.py | 4 --- .../zen_server/routers/secrets_endpoints.py | 9 +++-- .../routers/workspaces_endpoints.py | 18 +++++++++- src/zenml/zen_stores/rest_zen_store.py | 2 -- src/zenml/zen_stores/sql_zen_store.py | 35 +++++++------------ src/zenml/zen_stores/zen_store_interface.py | 2 -- 7 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index 66e22e996cd..52c6f4eaf6e 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -4248,7 +4248,6 @@ def list_service_connector_resources( connectors have access to. """ return self.zen_store.list_service_connector_resources( - user_name_or_id=self.active_user.id, workspace_name_or_id=self.active_workspace.id, connector_type=connector_type, resource_type=resource_type, diff --git a/src/zenml/models/v2/core/service_connector.py b/src/zenml/models/v2/core/service_connector.py index bbf42aadc22..f2b297c06ad 100644 --- a/src/zenml/models/v2/core/service_connector.py +++ b/src/zenml/models/v2/core/service_connector.py @@ -622,10 +622,6 @@ class ServiceConnectorFilter(WorkspaceScopedFilter): description="The type to scope this query to.", ) - is_shared: Optional[Union[bool, str]] = Field( - default=None, - description="If the service connector is shared or private", - ) name: Optional[str] = Field( default=None, description="The name to filter by", diff --git a/src/zenml/zen_server/routers/secrets_endpoints.py b/src/zenml/zen_server/routers/secrets_endpoints.py index fad4f1073f8..afba2b99539 100644 --- a/src/zenml/zen_server/routers/secrets_endpoints.py +++ b/src/zenml/zen_server/routers/secrets_endpoints.py @@ -36,6 +36,7 @@ from zenml.zen_server.rbac.utils import ( get_allowed_resource_ids, has_permissions_for_model, + is_owned_by_authenticated_user, ) from zenml.zen_server.utils import ( handle_exceptions, @@ -86,8 +87,12 @@ def list_secrets( if allowed_ids is not None: for secret in secrets.items: - if secret.id not in allowed_ids: - secret.remove_secrets() + if secret.id in allowed_ids or is_owned_by_authenticated_user( + secret + ): + continue + + secret.remove_secrets() return secrets diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index dd75187bba8..7b7d210670f 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -101,6 +101,7 @@ from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( batch_verify_permissions_for_models, + get_allowed_resource_ids, verify_permission_for_model, ) from zenml.zen_server.utils import ( @@ -1147,13 +1148,28 @@ def list_service_connector_resources( The matching list of resources that available service connectors have access to. """ - # TODO: missing permissions + workspace = zen_store().get_workspace(workspace_name_or_id) + + filter_model = ServiceConnectorFilter( + connector_type=connector_type, + resource_type=resource_type, + ) + filter_model.set_scope_workspace(workspace.id) + + allowed_ids = get_allowed_resource_ids( + resource_type=ResourceType.SERVICE_CONNECTOR + ) + filter_model.configure_rbac( + authenticated_user_id=auth_context.user.id, id=allowed_ids + ) + return zen_store().list_service_connector_resources( user_name_or_id=auth_context.user.id, workspace_name_or_id=workspace_name_or_id, connector_type=connector_type, resource_type=resource_type, resource_id=resource_id, + filter_model=filter_model, ) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 571b63bbbf2..4a9a9b60ada 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -1943,7 +1943,6 @@ def get_service_connector_client( def list_service_connector_resources( self, - user_name_or_id: Union[str, UUID], workspace_name_or_id: Union[str, UUID], connector_type: Optional[str] = None, resource_type: Optional[str] = None, @@ -1952,7 +1951,6 @@ def list_service_connector_resources( """List resources that can be accessed by service connectors. Args: - user_name_or_id: The name or ID of the user to scope to. workspace_name_or_id: The name or ID of the workspace to scope to. connector_type: The type of service connector to scope to. resource_type: The type of resource to scope to. diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 92e83c57ca0..38682f744ac 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -4149,53 +4149,42 @@ def get_service_connector_client( def list_service_connector_resources( self, - user_name_or_id: Union[str, UUID], workspace_name_or_id: Union[str, UUID], connector_type: Optional[str] = None, resource_type: Optional[str] = None, resource_id: Optional[str] = None, + filter_model: Optional[ServiceConnectorFilter] = None, ) -> List[ServiceConnectorResourcesModel]: """List resources that can be accessed by service connectors. Args: - user_name_or_id: The name or ID of the user to scope to. workspace_name_or_id: The name or ID of the workspace to scope to. connector_type: The type of service connector to scope to. resource_type: The type of resource to scope to. resource_id: The ID of the resource to scope to. + filter_model: Optional filter model to use when fetching service + connectors. Returns: The matching list of resources that available service connectors have access to. """ - user = self.get_user(user_name_or_id) workspace = self.get_workspace(workspace_name_or_id) - connector_filter_model = ServiceConnectorFilter( - connector_type=connector_type, - resource_type=resource_type, - is_shared=True, - workspace_id=workspace.id, - ) - shared_connectors = self.list_service_connectors( - filter_model=connector_filter_model - ).items - - connector_filter_model = ServiceConnectorFilter( - connector_type=connector_type, - resource_type=resource_type, - is_shared=False, - user_id=user.id, - workspace_id=workspace.id, - ) + if not filter_model: + filter_model = ServiceConnectorFilter( + connector_type=connector_type, + resource_type=resource_type, + workspace_id=workspace.id, + ) - private_connectors = self.list_service_connectors( - filter_model=connector_filter_model + service_connectors = self.list_service_connectors( + filter_model=filter_model ).items resource_list: List[ServiceConnectorResourcesModel] = [] - for connector in list(shared_connectors) + list(private_connectors): + for connector in service_connectors: if not service_connector_registry.is_registered(connector.type): # For connectors that we can instantiate, i.e. those that have a # connector type available locally, we return complete diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index b5f74f12647..7ea335ae27e 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1478,7 +1478,6 @@ def get_service_connector_client( @abstractmethod def list_service_connector_resources( self, - user_name_or_id: Union[str, UUID], workspace_name_or_id: Union[str, UUID], connector_type: Optional[str] = None, resource_type: Optional[str] = None, @@ -1487,7 +1486,6 @@ def list_service_connector_resources( """List resources that can be accessed by service connectors. Args: - user_name_or_id: The name or ID of the user to scope to. workspace_name_or_id: The name or ID of the workspace to scope to. connector_type: The type of service connector to scope to. resource_type: The type of resource to scope to. From 5a3b63f322d691b79b58c8de53c1417f14373b3a Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 21 Nov 2023 10:57:28 +0100 Subject: [PATCH 061/103] run metadata --- src/zenml/models/v2/base/filter.py | 2 +- src/zenml/zen_server/rbac/models.py | 1 + .../routers/run_metadata_endpoints.py | 11 ++++- .../zen_server/routers/steps_endpoints.py | 18 +++++-- .../routers/workspaces_endpoints.py | 48 ++++++++++++------- 5 files changed, 56 insertions(+), 24 deletions(-) diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index e5e61450ee3..2fce04e6008 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -794,7 +794,7 @@ def apply_filter( Returns: The query with filter applied. """ - query = self.apply_rbac_filter(query) + query = self.apply_rbac_filter(query, table=table) filters = self.generate_filter(table=table) diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index 6c0d20f05de..43a20b155c4 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -56,6 +56,7 @@ class ResourceType(StrEnum): PIPELINE_RUN = "pipeline_run" PIPELINE_DEPLOYMENT = "pipeline_deployment" PIPELINE_BUILD = "pipeline_build" + RUN_METADATA = "run_metadata" class Resource(BaseModel): diff --git a/src/zenml/zen_server/routers/run_metadata_endpoints.py b/src/zenml/zen_server/routers/run_metadata_endpoints.py index d7bf64be22a..326c45f5d47 100644 --- a/src/zenml/zen_server/routers/run_metadata_endpoints.py +++ b/src/zenml/zen_server/routers/run_metadata_endpoints.py @@ -20,6 +20,10 @@ from zenml.models import Page, RunMetadataFilter, RunMetadataResponse from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_list_entities, +) +from zenml.zen_server.rbac.models import ResourceType from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -57,6 +61,9 @@ def list_run_metadata( Returns: The pipeline runs according to query filters. """ - return zen_store().list_run_metadata( - run_metadata_filter_model, hydrate=hydrate + return verify_permissions_and_list_entities( + filter_model=run_metadata_filter_model, + resource_type=ResourceType.RUN_METADATA, + list_method=zen_store().list_run_metadata, + hydrate=hydrate, ) diff --git a/src/zenml/zen_server/routers/steps_endpoints.py b/src/zenml/zen_server/routers/steps_endpoints.py index 3105a5875af..21f484ef7dc 100644 --- a/src/zenml/zen_server/routers/steps_endpoints.py +++ b/src/zenml/zen_server/routers/steps_endpoints.py @@ -40,9 +40,11 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response -from zenml.zen_server.rbac.models import Action +from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( + dehydrate_page, dehydrate_response_model, + get_allowed_resource_ids, verify_permission_for_model, ) from zenml.zen_server.utils import ( @@ -69,7 +71,7 @@ def list_run_steps( make_dependable(StepRunFilter) ), hydrate: bool = False, - _: AuthContext = Security(authorize), + auth_context: AuthContext = Security(authorize), ) -> Page[StepRunResponse]: """Get run steps according to query filters. @@ -78,13 +80,23 @@ def list_run_steps( filtering. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. + auth_context: Authentication context. Returns: The run steps according to query filters. """ - return zen_store().list_run_steps( + allowed_pipeline_run_ids = get_allowed_resource_ids( + resource_type=ResourceType.PIPELINE_RUN + ) + step_run_filter_model.configure_rbac( + authenticated_user_id=auth_context.user.id, + pipeline_run_id=allowed_pipeline_run_ids, + ) + + page = zen_store().list_run_steps( step_run_filter_model=step_run_filter_model, hydrate=hydrate ) + return dehydrate_page(page) @router.post( diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 7b7d210670f..feebac3b74b 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -102,6 +102,7 @@ from zenml.zen_server.rbac.utils import ( batch_verify_permissions_for_models, get_allowed_resource_ids, + verify_permission, verify_permission_for_model, ) from zenml.zen_server.utils import ( @@ -763,25 +764,20 @@ def create_schedule( def create_pipeline_run( workspace_name_or_id: Union[str, UUID], pipeline_run: PipelineRunRequest, - auth_context: AuthContext = Security(authorize), - get_if_exists: bool = False, + _: AuthContext = Security(authorize), ) -> PipelineRunResponse: """Creates a pipeline run. Args: workspace_name_or_id: Name or ID of the workspace. pipeline_run: Pipeline run to create. - auth_context: Authentication context. - get_if_exists: If a similar pipeline run already exists, return it - instead of raising an error. Returns: The created pipeline run. Raises: - IllegalOperationError: If the workspace or user specified in the - pipeline run does not match the current workspace or authenticated - user. + IllegalOperationError: If the workspace specified in the + pipeline run does not match the current workspace. """ workspace = zen_store().get_workspace(workspace_name_or_id) @@ -791,15 +787,12 @@ def create_pipeline_run( f"of this endpoint `{workspace_name_or_id}` is " f"not supported." ) - if pipeline_run.user != auth_context.user.id: - raise IllegalOperationError( - "Creating pipeline runs for a user other than yourself " - "is not supported." - ) - if get_if_exists: - return zen_store().get_or_create_run(pipeline_run=pipeline_run)[0] - return zen_store().create_run(pipeline_run=pipeline_run) + return verify_permissions_and_create_entity( + request_model=pipeline_run, + resource_type=ResourceType.PIPELINE_RUN, + create_method=zen_store().create_run, + ) @router.post( @@ -841,6 +834,10 @@ def get_or_create_pipeline_run( "Creating pipeline runs for a user other than yourself " "is not supported." ) + + verify_permission( + resource_type=ResourceType.PIPELINE_RUN, action=Action.CREATE + ) return zen_store().get_or_create_run(pipeline_run=pipeline_run) @@ -884,7 +881,23 @@ def create_run_metadata( "is not supported." ) - return zen_store().create_run_metadata(run_metadata=run_metadata) + if run_metadata.pipeline_run_id: + run = zen_store().get_run(run_metadata.pipeline_run_id) + verify_permission_for_model(run, action=Action.UPDATE) + + if run_metadata.step_run_id: + step_run = zen_store().get_run_step(run_metadata.step_run_id) + verify_permission_for_model(step_run, action=Action.UPDATE) + + if run_metadata.artifact_id: + artifact = zen_store().get_artifact(run_metadata.artifact_id) + verify_permission_for_model(artifact, action=Action.UPDATE) + + verify_permission( + resource_type=ResourceType.RUN_METADATA, action=Action.CREATE + ) + + return zen_store().create_run_metadata(run_metadata) @router.post( @@ -1164,7 +1177,6 @@ def list_service_connector_resources( ) return zen_store().list_service_connector_resources( - user_name_or_id=auth_context.user.id, workspace_name_or_id=workspace_name_or_id, connector_type=connector_type, resource_type=resource_type, From 96ee0dd041b1a8e1ba788703a62597ce89c9bccb Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 21 Nov 2023 11:21:42 +0100 Subject: [PATCH 062/103] docstrings --- src/zenml/models/v2/base/filter.py | 23 +++++++++++++------ .../routers/service_accounts_endpoints.py | 2 +- .../zen_server/routers/tags_endpoints.py | 2 +- .../routers/workspaces_endpoints.py | 2 +- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 2fce04e6008..80f2d1ec652 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -390,15 +390,15 @@ def configure_rbac( authenticated_user_id: UUID, **column_allowed_ids: Optional[Set[UUID]], ) -> None: - """Set allowed IDs and user ID for the query. + """Configure RBAC allowed column values. Args: - allowed_ids: Set of IDs to limit the query to. If given, the - remaining filters will be applied to entities within this set - only. If `None`, the remaining filters will applied to all - entries in the table. - user_id: ID of the authenticated user. If given, all entities owned - by this user will be included in addition to the `allowed_ids`. + authenticated_user_id: ID of the authenticated user. All entities + owned by this user will be included. + column_allowed_ids: Set of IDs per column to limit the query to. + If given, the remaining filters will be applied to entities + within this set only. If `None`, the remaining filters will + applied to all entries in the table. """ self._rbac_configuration = (authenticated_user_id, column_allowed_ids) @@ -407,6 +407,15 @@ def apply_rbac_filter( query: Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"], table: Type["AnySchema"], ) -> Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"]: + """Applies the RBAC filter to a query. + + Args: + query: The query to which to apply the filter. + table: The query table. + + Returns: + The query with RBAC filter applied. + """ from sqlmodel import or_ if not self._rbac_configuration: diff --git a/src/zenml/zen_server/routers/service_accounts_endpoints.py b/src/zenml/zen_server/routers/service_accounts_endpoints.py index 5fabea257c4..8dedb498cf5 100644 --- a/src/zenml/zen_server/routers/service_accounts_endpoints.py +++ b/src/zenml/zen_server/routers/service_accounts_endpoints.py @@ -200,7 +200,7 @@ def delete_service_account( Args: service_account_name_or_id: Name or ID of the service account. """ - return verify_permissions_and_delete_entity( + verify_permissions_and_delete_entity( id=service_account_name_or_id, get_method=zen_store().get_service_account, delete_method=zen_store().delete_service_account, diff --git a/src/zenml/zen_server/routers/tags_endpoints.py b/src/zenml/zen_server/routers/tags_endpoints.py index d20111146d9..865d67b0237 100644 --- a/src/zenml/zen_server/routers/tags_endpoints.py +++ b/src/zenml/zen_server/routers/tags_endpoints.py @@ -176,7 +176,7 @@ def delete_tag( Args: tag_name_or_id: The name or ID of the tag to delete. """ - return verify_permissions_and_delete_entity( + verify_permissions_and_delete_entity( id=tag_name_or_id, get_method=zen_store().get_tag, delete_method=zen_store().delete_tag, diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index feebac3b74b..850977091fd 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -249,7 +249,7 @@ def delete_workspace( Args: workspace_name_or_id: Name or ID of the workspace. """ - return verify_permissions_and_delete_entity( + verify_permissions_and_delete_entity( id=workspace_name_or_id, get_method=zen_store().get_workspace, delete_method=zen_store().delete_workspace, From 43552b906e09e8650d472a106dc21451d3f40ac1 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 21 Nov 2023 11:59:11 +0100 Subject: [PATCH 063/103] Fix filter model edge case --- src/zenml/models/v2/base/filter.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 80f2d1ec652..b580801e991 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -428,7 +428,11 @@ def apply_rbac_filter( expression = getattr(table, column_name).in_(allowed_ids) expressions.append(expression) - if hasattr(table, "user_id"): + if expressions and hasattr(table, "user_id"): + # If `expressions` is not empty, we do not have full access to all + # rows of the table. In this case, we also include rows which the + # user owns. + # Unowned entities are considered server-owned and can be seen # by anyone expressions.append(getattr(table, "user_id").is_(None)) @@ -437,7 +441,10 @@ def apply_rbac_filter( getattr(table, "user_id") == self._rbac_configuration[0] ) - return query.where(or_(False, *expressions)) + if expressions: + return query.where(or_(*expressions)) + else: + return query @classmethod def _generate_filter_list(cls, values: Dict[str, Any]) -> List[Filter]: From 5514a1cdf413b040cdb4d69fa37a35179d5a11f1 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 21 Nov 2023 12:16:31 +0100 Subject: [PATCH 064/103] Linting --- src/zenml/models/v2/core/component.py | 2 ++ src/zenml/models/v2/core/stack.py | 2 ++ src/zenml/zen_server/rbac/utils.py | 1 - 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/zenml/models/v2/core/component.py b/src/zenml/models/v2/core/component.py index 4686e3de5e2..62e04aabfcc 100644 --- a/src/zenml/models/v2/core/component.py +++ b/src/zenml/models/v2/core/component.py @@ -360,4 +360,6 @@ def generate_filter( @server_owned_request_model class InternalComponentRequest(ComponentRequest): + """Internal component request model.""" + pass diff --git a/src/zenml/models/v2/core/stack.py b/src/zenml/models/v2/core/stack.py index b5d18da3d0d..221501135af 100644 --- a/src/zenml/models/v2/core/stack.py +++ b/src/zenml/models/v2/core/stack.py @@ -245,4 +245,6 @@ class StackFilter(WorkspaceScopedFilter): @server_owned_request_model class InternalStackRequest(StackRequest): + """Internal stack request model.""" + pass diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 7c39002dbb3..16ca216e3e9 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -194,7 +194,6 @@ def get_permission_denied_model(model: AnyResponseModel) -> AnyResponseModel: Returns: The permission denied model. """ - if isinstance(model, BaseResponse): return cast(AnyResponseModel, get_permission_denied_model_v2(model)) else: From a3416ea20107e9f5d5590853b77a4dce3acd29e9 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 21 Nov 2023 12:53:17 +0100 Subject: [PATCH 065/103] Request timeouts --- src/zenml/zen_server/rbac/zenml_cloud_rbac.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py index f03584456c2..0e70ae20670 100644 --- a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -223,11 +223,11 @@ def _get(self, endpoint: str, params: Dict[str, Any]) -> requests.Response: """ url = self._config.api_url + endpoint - response = self.session.get(url=url, params=params) + response = self.session.get(url=url, params=params, timeout=7) if response.status_code == 401: # Refresh the auth token and try again self._clear_session() - response = self.session.get(url=url, params=params) + response = self.session.get(url=url, params=params, timeout=7) try: response.raise_for_status() @@ -275,7 +275,9 @@ def _fetch_auth_token(self) -> str: "grant_type": "client_credentials", } try: - response = requests.post(auth0_url, headers=headers, data=payload) + response = requests.post( + auth0_url, headers=headers, data=payload, timeout=7 + ) response.raise_for_status() except Exception as e: raise RuntimeError(f"Error fetching auth token from auth0: {e}") From 1e0501db582daa2b2c463c5f66277ebb415a9bc0 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 21 Nov 2023 17:07:19 +0100 Subject: [PATCH 066/103] Add stack composition and fix CLI deleted columns --- src/zenml/cli/utils.py | 6 ++--- .../7500f434b71c_remove_shared_columns.py | 24 ++++++++++++++----- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index 8e47e00cba9..5d5f3226b25 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -1549,7 +1549,7 @@ def print_components_table( "NAME": component.name, "COMPONENT ID": component.id, "FLAVOR": component.flavor, - "OWNER": f"{component.user.name if component.user else 'DELETED!'}", + "OWNER": f"{component.user.name if component.user else '-'}", } configurations.append(component_config) print_table(configurations) @@ -1657,7 +1657,7 @@ def print_service_connectors_table( "TYPE": connector.emojified_connector_type, "RESOURCE TYPES": "\n".join(connector.emojified_resource_types), "RESOURCE NAME": resource_name, - "OWNER": f"{connector.user.name if connector.user else 'DELETED!'}", + "OWNER": f"{connector.user.name if connector.user else '-'}", "EXPIRES IN": expires_in( connector.expires_at, ":name_badge: Expired!" ) @@ -1761,7 +1761,7 @@ def print_service_connector_configuration( else: user_name = connector.user.name else: - user_name = "[DELETED]" + user_name = "-" if isinstance(connector, ServiceConnectorResponse): declare( diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py index dfdaadcbd8d..cf283285826 100644 --- a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -47,11 +47,18 @@ def resolve_duplicate_names() -> None: meta = sa.MetaData(bind=op.get_bind()) meta.reflect( - only=("stack", "stack_component", "service_connector", "workspace") + only=( + "stack", + "stack_component", + "stack_composition", + "service_connector", + "workspace", + ) ) stack_table = sa.Table("stack", meta) stack_component_table = sa.Table("stack_component", meta) + stack_composition_table = sa.Table("stack_composition", meta) workspace_table = sa.Table("workspace", meta) _rename_old_default_entities(stack_table) @@ -62,6 +69,7 @@ def resolve_duplicate_names() -> None: stack_components = [] stacks = [] + stack_compositions = [] for row in connection.execute(workspace_query).fetchall(): workspace_id = row[0] artifact_store_id = str(uuid4()).replace("-", "") @@ -89,24 +97,28 @@ def resolve_duplicate_names() -> None: "updated": utcnow, } + stack_id = str(uuid4()).replace("-", "") default_stack = { - "id": str(uuid4()).replace("-", ""), + "id": stack_id, "workspace_id": workspace_id, "name": "default", - "components": { - "artifact_store": [artifact_store_id], - "orchestrator": [orchestrator_id], - }, "is_shared": True, "created": utcnow, "updated": utcnow, } + stack_compositions.append( + {"stack_id": stack_id, "component_id": artifact_store_id} + ) + stack_compositions.append( + {"stack_id": stack_id, "component_id": orchestrator_id} + ) stack_components.append(default_artifact_store) stack_components.append(default_orchestrator) stacks.append(default_stack) op.bulk_insert(stack_component_table, rows=stack_components) op.bulk_insert(stack_table, rows=stacks) + op.bulk_insert(stack_composition_table, rows=stack_compositions) service_connector_table = sa.Table("service_connector", meta) query = sa.select( From 045f6c63ecf5e5cc9d0203d33d629ca7e869dbca Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 21 Nov 2023 17:51:37 +0100 Subject: [PATCH 067/103] Use shorter suffix in migration --- .../7500f434b71c_remove_shared_columns.py | 94 ++++++++++++------- 1 file changed, 60 insertions(+), 34 deletions(-) diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py index cf283285826..9465a59be88 100644 --- a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -6,7 +6,9 @@ """ import base64 +from collections import defaultdict from datetime import datetime +from typing import Optional, Set from uuid import uuid4 import sqlalchemy as sa @@ -19,26 +21,70 @@ depends_on = None -def _rename_old_default_entities(table: sa.Table) -> None: - """Include owner id in the name of default entities. +def _rename_duplicate_entities( + table: sa.Table, reserved_names: Optional[Set[str]] = None +) -> None: + """Include owner id in the name of duplicate entities. Args: - table: The table in which to rename the default entities. + table: The table in which to rename the duplicate entities. + reserved_names: Optional reserved names not to use. """ connection = op.get_bind() query = sa.select( table.c.id, + table.c.name, table.c.user_id, - ).where(table.c.name == "default") + ) - res = connection.execute(query).fetchall() - for id, owner_id in res: - name = f"default-{owner_id}" + names = reserved_names or set() + for id, name, user_id in connection.execute(query).fetchall(): + if name in names: + for suffix_length in range(4, len(user_id)): + new_name = f"{name}-{user_id[:suffix_length]}" + if new_name not in names: + name = new_name + break - connection.execute( - sa.update(table).where(table.c.id == id).values(name=name) - ) + connection.execute( + sa.update(table).where(table.c.id == id).values(name=name) + ) + + names.add(name) + + +def _rename_duplicate_components(table: sa.Table) -> None: + """Include owner id in the name of duplicate entities. + + Args: + table: The table in which to rename the duplicate entities. + """ + connection = op.get_bind() + + query = sa.select( + table.c.id, + table.c.type, + table.c.name, + table.c.user_id, + ) + + names_per_type = defaultdict(lambda: {"default"}) + + for id, type_, name, user_id in connection.execute(query).fetchall(): + names = names_per_type[type_] + if name in names: + for suffix_length in range(4, len(user_id)): + new_name = f"{name}-{user_id[:suffix_length]}" + if new_name not in names: + name = new_name + break + + connection.execute( + sa.update(table).where(table.c.id == id).values(name=name) + ) + + names.add(name) def resolve_duplicate_names() -> None: @@ -60,9 +106,11 @@ def resolve_duplicate_names() -> None: stack_component_table = sa.Table("stack_component", meta) stack_composition_table = sa.Table("stack_composition", meta) workspace_table = sa.Table("workspace", meta) + service_connector_table = sa.Table("service_connector", meta) - _rename_old_default_entities(stack_table) - _rename_old_default_entities(stack_component_table) + _rename_duplicate_entities(stack_table, reserved_names={"default"}) + _rename_duplicate_components(stack_component_table) + _rename_duplicate_entities(service_connector_table) workspace_query = sa.select(workspace_table.c.id) utcnow = datetime.utcnow() @@ -120,28 +168,6 @@ def resolve_duplicate_names() -> None: op.bulk_insert(stack_table, rows=stacks) op.bulk_insert(stack_composition_table, rows=stack_compositions) - service_connector_table = sa.Table("service_connector", meta) - query = sa.select( - service_connector_table.c.id, - service_connector_table.c.name, - service_connector_table.c.user_id, - ) - - names = set() - for id, name, user_id in connection.execute(query).fetchall(): - if name in names: - name = f"{name}-{user_id}" - # This will never happen, as we had a constraint on unique names - # per user - assert name not in names - connection.execute( - sa.update(service_connector_table) - .where(service_connector_table.c.id == id) - .values(name=name) - ) - - names.add(name) - def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" From b2055cf3712e6f3e507341b8ab101a7ae23de611 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 09:03:45 +0100 Subject: [PATCH 068/103] Remove jwt token permissions --- src/zenml/zen_server/jwt.py | 7 ------- src/zenml/zen_server/routers/auth_endpoints.py | 1 - tests/unit/zen_server/test_jwt.py | 6 ------ 3 files changed, 14 deletions(-) diff --git a/src/zenml/zen_server/jwt.py b/src/zenml/zen_server/jwt.py index 408891ba168..dbe5f81a6cb 100644 --- a/src/zenml/zen_server/jwt.py +++ b/src/zenml/zen_server/jwt.py @@ -17,7 +17,6 @@ from typing import ( Any, Dict, - List, Optional, cast, ) @@ -43,12 +42,10 @@ class JWTToken(BaseModel): was issued. pipeline_id: The id of the pipeline for which the token was issued. schedule_id: The id of the schedule for which the token was issued. - permissions: The permissions scope of the authenticated user. claims: The original token claims. """ user_id: UUID - permissions: List[str] device_id: Optional[UUID] = None api_key_id: Optional[UUID] = None pipeline_id: Optional[UUID] = None @@ -145,15 +142,12 @@ def decode_token( "UUID" ) - permissions: List[str] = claims.get("permissions", []) - return JWTToken( user_id=user_id, device_id=device_id, api_key_id=api_key_id, pipeline_id=pipeline_id, schedule_id=schedule_id, - permissions=list(set(permissions)), claims=claims, ) @@ -173,7 +167,6 @@ def encode(self, expires: Optional[datetime] = None) -> str: claims: Dict[str, Any] = dict( sub=str(self.user_id), - permissions=list(self.permissions), ) claims["iss"] = config.get_jwt_token_issuer() claims["aud"] = config.get_jwt_token_audience() diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index dcb5fe9b711..f343a1b3ddd 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -228,7 +228,6 @@ def generate_access_token( access_token = JWTToken( user_id=user_id, device_id=device.id if device else None, - permissions=[], api_key_id=api_key.id if api_key else None, ).encode(expires=expires) diff --git a/tests/unit/zen_server/test_jwt.py b/tests/unit/zen_server/test_jwt.py index 8ac9915d102..2a67bba85de 100644 --- a/tests/unit/zen_server/test_jwt.py +++ b/tests/unit/zen_server/test_jwt.py @@ -33,7 +33,6 @@ def test_encode_decode_works(): """Test that encoding and decoding JWT tokens generally works.""" user_id = uuid.uuid4() - permissions = ["read", "write"] device_id = uuid.uuid4() api_key_id = uuid.uuid4() pipeline_id = uuid.uuid4() @@ -45,7 +44,6 @@ def test_encode_decode_works(): token = JWTToken( user_id=user_id, - permissions=permissions, device_id=device_id, api_key_id=api_key_id, pipeline_id=pipeline_id, @@ -57,7 +55,6 @@ def test_encode_decode_works(): decoded_token = JWTToken.decode_token(encoded_token) assert decoded_token.user_id == user_id - assert set(decoded_token.permissions) == set(permissions) assert decoded_token.device_id == device_id assert decoded_token.api_key_id == api_key_id assert decoded_token.pipeline_id == pipeline_id @@ -71,7 +68,6 @@ def test_token_expiration(): """Test that tokens expire after the specified time.""" token = JWTToken( user_id=uuid.uuid4(), - permissions=[], ) expires = datetime.utcnow() @@ -92,7 +88,6 @@ def test_token_wrong_signature(): token = JWTToken( user_id=uuid.uuid4(), - permissions=[], ) encoded_token = token.encode() @@ -130,7 +125,6 @@ def _hack_token() -> Generator[Dict[str, Any], None, None]: token = JWTToken( user_id=uuid.uuid4(), - permissions=[], ) encoded_token = token.encode() From 7e1009924606769e6151be2731271047e4364d89 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 09:31:28 +0100 Subject: [PATCH 069/103] Move rbacImplementationSource to helm chart auth section --- .../deploy/helm/templates/server-deployment.yaml | 4 ++-- src/zenml/zen_server/deploy/helm/values.yaml | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/zenml/zen_server/deploy/helm/templates/server-deployment.yaml b/src/zenml/zen_server/deploy/helm/templates/server-deployment.yaml index 1c937289ced..380ba137277 100644 --- a/src/zenml/zen_server/deploy/helm/templates/server-deployment.yaml +++ b/src/zenml/zen_server/deploy/helm/templates/server-deployment.yaml @@ -137,9 +137,9 @@ spec: - name: ZENML_SERVER_ROOT_URL_PATH value: {{ .Values.zenml.rootUrlPath | quote }} {{- end }} - {{- if .Values.zenml.rbacImplementationSource }} + {{- if .Values.zenml.auth.rbacImplementationSource }} - name: ZENML_SERVER_RBAC_IMPLEMENTATION_SOURCE - value: {{ .Values.zenml.rbacImplementationSource | quote }} + value: {{ .Values.zenml.auth.rbacImplementationSource | quote }} {{- end }} - name: ZENML_DEFAULT_PROJECT_NAME value: {{ .Values.zenml.defaultProject | quote }} diff --git a/src/zenml/zen_server/deploy/helm/values.yaml b/src/zenml/zen_server/deploy/helm/values.yaml index 5c65416bed5..41a1bcec522 100644 --- a/src/zenml/zen_server/deploy/helm/values.yaml +++ b/src/zenml/zen_server/deploy/helm/values.yaml @@ -150,6 +150,11 @@ zenml: # used. externalServerID: + # Source pointing to a class implementing the RBAC interface defined by + # `zenml.zen_server.rbac_interface.RBACInterface`. If not specified, + # RBAC will not be enabled for this server. + rbacImplementationSource: + # The root URL path to use when behind a proxy. This is useful when the # `rewrite-target` annotation is used in the ingress controller, e.g.: # @@ -347,11 +352,6 @@ zenml: # mounted as environment variables in the ZenML server container. secretEnvironment: {} - # Source pointing to a class implementing the RBAC interface defined by - # `zenml.zen_server.rbac_interface.RBACInterface`. If not specified, - # RBAC will not be enabled for this server. - rbacImplementationSource: - service: type: ClusterIP port: 80 From a38ebea7d3585ddd55ca12fc3ef7e20d90ebe4f6 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 09:33:18 +0100 Subject: [PATCH 070/103] Update src/zenml/cli/server.py Co-authored-by: Stefan Nica --- src/zenml/cli/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/cli/server.py b/src/zenml/cli/server.py index 571f43b8fc9..6205108f4b2 100644 --- a/src/zenml/cli/server.py +++ b/src/zenml/cli/server.py @@ -784,7 +784,7 @@ def connect( except IllegalOperationError: cli_utils.warning( f"User '{username}' does not have sufficient permissions to " - f"to access the server at '{url}'." + f"access the server at '{url}'." ) if workspace: From a7fc9d82081a681852306d9cf8c155c7a1ddcbd1 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 12:34:41 +0100 Subject: [PATCH 071/103] Move database initialization into SQLZenStore --- src/zenml/cli/server.py | 2 +- src/zenml/client.py | 2 +- src/zenml/constants.py | 5 + src/zenml/zen_server/auth.py | 2 +- src/zenml/zen_stores/base_zen_store.py | 236 +----------------- src/zenml/zen_stores/rest_zen_store.py | 21 -- src/zenml/zen_stores/sql_zen_store.py | 206 ++++++++++++++- tests/harness/deployment/server_docker.py | 2 +- .../deployment/server_docker_compose.py | 2 +- tests/harness/deployment/server_local.py | 2 +- .../functional/cli/test_user_management.py | 2 +- .../functional/zen_stores/test_zen_store.py | 14 +- 12 files changed, 220 insertions(+), 276 deletions(-) diff --git a/src/zenml/cli/server.py b/src/zenml/cli/server.py index 6205108f4b2..c8de42ac0c7 100644 --- a/src/zenml/cli/server.py +++ b/src/zenml/cli/server.py @@ -191,7 +191,7 @@ def up( assert gc.store is not None if not blocking: - from zenml.zen_stores.base_zen_store import ( + from zenml.constants import ( DEFAULT_PASSWORD, DEFAULT_USERNAME, ) diff --git a/src/zenml/client.py b/src/zenml/client.py index 52c6f4eaf6e..c1ec1eb017e 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -948,7 +948,7 @@ def active_workspace(self) -> WorkspaceResponse: workspace_id = os.environ[ENV_ZENML_ACTIVE_WORKSPACE_ID] return self.get_workspace(workspace_id) - from zenml.zen_stores.base_zen_store import DEFAULT_WORKSPACE_NAME + from zenml.constants import DEFAULT_WORKSPACE_NAME # If running in a ZenML server environment, the active workspace is # not relevant diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 3684c87fa0b..7ab3673282a 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -138,6 +138,11 @@ def handle_int_env_var(var: str, default: int = 0) -> int: # Default store directory subpath: DEFAULT_STORE_DIRECTORY_NAME = "default_zen_store" +DEFAULT_USERNAME = "default" +DEFAULT_PASSWORD = "" +DEFAULT_WORKSPACE_NAME = "default" +DEFAULT_STACK_AND_COMPONENT_NAME = "default" + # Secrets Manager ZENML_SCHEMA_NAME = "zenml_schema_name" LOCAL_SECRETS_FILENAME = "secrets.yaml" diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 67c1e904242..f7eecd7f8e5 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -33,6 +33,7 @@ from zenml.analytics.context import AnalyticsContext from zenml.constants import ( API, + DEFAULT_USERNAME, EXTERNAL_AUTHENTICATOR_TIMEOUT, LOGIN, VERSION_1, @@ -54,7 +55,6 @@ ) from zenml.zen_server.jwt import JWTToken from zenml.zen_server.utils import server_config, zen_store -from zenml.zen_stores.base_zen_store import DEFAULT_USERNAME logger = get_logger(__name__) diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index a23aa246f14..3ea24fd20ce 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -12,8 +12,7 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Base Zen Store implementation.""" -import os -from abc import ABC, abstractmethod +from abc import ABC from typing import ( Any, Callable, @@ -32,20 +31,14 @@ from requests import ConnectionError import zenml -from zenml.analytics.utils import analytics_disabler from zenml.config.global_config import GlobalConfiguration from zenml.config.server_config import ServerConfiguration from zenml.config.store_config import StoreConfiguration from zenml.constants import ( - ENV_ZENML_DEFAULT_USER_NAME, - ENV_ZENML_DEFAULT_USER_PASSWORD, - ENV_ZENML_DEFAULT_WORKSPACE_NAME, IS_DEBUG_ENV, ) from zenml.enums import ( - AuthScheme, SecretsStoreType, - StackComponentType, StoreType, ) from zenml.exceptions import AuthorizationException @@ -53,15 +46,11 @@ from zenml.models import ( ServerDatabaseType, ServerModel, - StackFilter, StackResponse, UserFilter, - UserRequest, UserResponse, WorkspaceResponse, ) -from zenml.models.v2.core.component import InternalComponentRequest -from zenml.models.v2.core.stack import InternalStackRequest from zenml.utils.proxy_utils import make_proxy_class from zenml.zen_stores.enums import StoreEvent from zenml.zen_stores.secrets_stores.base_secrets_store import BaseSecretsStore @@ -75,11 +64,6 @@ logger = get_logger(__name__) -DEFAULT_USERNAME = "default" -DEFAULT_PASSWORD = "" -DEFAULT_WORKSPACE_NAME = "default" -DEFAULT_STACK_AND_COMPONENT_NAME = "default" - @make_proxy_class(SecretsStoreInterface, "_secrets_store") class BaseZenStore( @@ -299,19 +283,6 @@ def get_default_store_config(path: str) -> StoreConfiguration: def _initialize_database(self) -> None: """Initialize the database on first use.""" - default_workspace = self._get_or_create_default_workspace() - - config = ServerConfiguration.get_server_config() - # If the auth scheme is external, don't create the default user - if config.auth_scheme != AuthScheme.EXTERNAL: - try: - _ = self._default_user - except KeyError: - self._create_default_user() - - self._get_or_create_default_stack( - workspace=default_workspace, - ) @property def url(self) -> str: @@ -452,24 +423,6 @@ def is_local_store(self) -> bool: """ return self.get_store_info().is_local() - def _get_or_create_default_stack( - self, workspace: WorkspaceResponse - ) -> StackResponse: - try: - return self._get_default_stack( - workspace_id=workspace.id, - ) - except KeyError: - return self._create_default_stack( - workspace_id=workspace.id, - ) - - def _get_or_create_default_workspace(self) -> WorkspaceResponse: - try: - return self._default_workspace - except KeyError: - return self._create_default_workspace() - # -------------- # Event Handlers # -------------- @@ -506,152 +459,6 @@ def _trigger_event(self, event: StoreEvent, **kwargs: Any) -> None: exc_info=True, ) - # ------ - # Stacks - # ------ - - def _create_default_stack( - self, - workspace_id: UUID, - ) -> StackResponse: - """Create the default stack components and stack. - - The default stack contains a local orchestrator and a local artifact - store. - - Args: - workspace_id: ID of the workspace to which the stack - belongs. - - Returns: - The model of the created default stack. - """ - with analytics_disabler(): - workspace = self.get_workspace(workspace_name_or_id=workspace_id) - - logger.info( - f"Creating default stack in workspace {workspace.name}..." - ) - - orchestrator = self.create_stack_component( - component=InternalComponentRequest( - # Passing `None` for the user here means the orchestrator - # is owned by the server, which for RBAC indicates that - # everyone can read it - user=None, - workspace=workspace.id, - name=DEFAULT_STACK_AND_COMPONENT_NAME, - type=StackComponentType.ORCHESTRATOR, - flavor="local", - configuration={}, - ), - ) - - artifact_store = self.create_stack_component( - component=InternalComponentRequest( - # Passing `None` for the user here means the stack is owned - # by the server, which for RBAC indicates that everyone can - # read it - user=None, - workspace=workspace.id, - name=DEFAULT_STACK_AND_COMPONENT_NAME, - type=StackComponentType.ARTIFACT_STORE, - flavor="local", - configuration={}, - ), - ) - - components = { - c.type: [c.id] for c in [orchestrator, artifact_store] - } - - stack = InternalStackRequest( - # Passing `None` for the user here means the stack is owned by - # the server, which for RBAC indicates that everyone can read it - user=None, - name=DEFAULT_STACK_AND_COMPONENT_NAME, - components=components, - workspace=workspace.id, - ) - return self.create_stack(stack=stack) - - def _get_default_stack( - self, - workspace_id: UUID, - ) -> StackResponse: - """Get the default stack for a user in a workspace. - - Args: - workspace_id: ID of the workspace. - - Returns: - The default stack in the workspace. - - Raises: - KeyError: if the workspace or default stack doesn't exist. - """ - default_stacks = self.list_stacks( - StackFilter( - workspace_id=workspace_id, - name=DEFAULT_STACK_AND_COMPONENT_NAME, - ) - ) - if default_stacks.total == 0: - raise KeyError( - f"No default stack found in workspace {workspace_id}." - ) - return default_stacks.items[0] - - # ----- - # Users - # ----- - - @property - def _default_user_name(self) -> str: - """Get the default user name. - - Returns: - The default user name. - """ - return os.getenv(ENV_ZENML_DEFAULT_USER_NAME, DEFAULT_USERNAME) - - @property - def _default_user(self) -> UserResponse: - """Get the default user. - - Returns: - The default user. - - Raises: - KeyError: If the default user doesn't exist. - """ - user_name = self._default_user_name - try: - return self.get_user(user_name) - except KeyError: - raise KeyError(f"The default user '{user_name}' is not configured") - - def _create_default_user(self) -> UserResponse: - """Creates a default user. - - Returns: - The default user. - """ - user_name = os.getenv(ENV_ZENML_DEFAULT_USER_NAME, DEFAULT_USERNAME) - user_password = os.getenv( - ENV_ZENML_DEFAULT_USER_PASSWORD, DEFAULT_PASSWORD - ) - - logger.info(f"Creating default user '{user_name}' ...") - new_user = self.create_user( - UserRequest( - name=user_name, - active=True, - password=user_password, - ) - ) - return new_user - def get_external_user(self, user_id: UUID) -> UserResponse: """Get a user by external ID. @@ -669,47 +476,6 @@ def get_external_user(self, user_id: UUID) -> UserResponse: raise KeyError(f"User with external ID '{user_id}' not found.") return users.items[0] - # -------- - # Workspaces - # -------- - - @property - def _default_workspace_name(self) -> str: - """Get the default workspace name. - - Returns: - The default workspace name. - """ - return os.getenv( - ENV_ZENML_DEFAULT_WORKSPACE_NAME, DEFAULT_WORKSPACE_NAME - ) - - @property - def _default_workspace(self) -> WorkspaceResponse: - """Get the default workspace. - - Returns: - The default workspace. - - Raises: - KeyError: if the default workspace doesn't exist. - """ - workspace_name = self._default_workspace_name - try: - return self.get_workspace(workspace_name) - except KeyError: - raise KeyError( - f"The default workspace '{workspace_name}' is not configured" - ) - - @abstractmethod - def _create_default_workspace(self) -> WorkspaceResponse: - """Creates a default workspace. - - Returns: - The default workspace. - """ - class Config: """Pydantic configuration class.""" diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 4a9a9b60ada..8d0701dd242 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -448,27 +448,6 @@ class RestZenStore(BaseZenStore): _api_token: Optional[str] = None _session: Optional[requests.Session] = None - def _initialize_database(self) -> None: - """Initialize the database.""" - # don't do anything for a REST store - - def _create_default_stack( - self, - workspace_id: UUID, - ) -> StackResponse: - workspace = self.get_workspace(workspace_id) - - raise RuntimeError( - f"Unable to create default stack in workspace " - f"{workspace.name}." - ) - - def _create_default_workspace(self) -> WorkspaceResponse: - raise RuntimeError( - f"Unable to create default workspace " - f"{self._default_workspace_name}." - ) - # ==================================== # ZenML Store interface implementation # ==================================== diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 38682f744ac..c9cb0afb150 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -54,15 +54,24 @@ from sqlmodel.sql.expression import Select, SelectOfScalar from zenml.analytics.enums import AnalyticsEvent -from zenml.analytics.utils import track_decorator +from zenml.analytics.utils import analytics_disabler, track_decorator from zenml.config.global_config import GlobalConfiguration from zenml.config.secrets_store_config import SecretsStoreConfiguration +from zenml.config.server_config import ServerConfiguration from zenml.config.store_config import StoreConfiguration from zenml.constants import ( + DEFAULT_PASSWORD, + DEFAULT_STACK_AND_COMPONENT_NAME, + DEFAULT_USERNAME, + DEFAULT_WORKSPACE_NAME, + ENV_ZENML_DEFAULT_USER_NAME, + ENV_ZENML_DEFAULT_USER_PASSWORD, + ENV_ZENML_DEFAULT_WORKSPACE_NAME, ENV_ZENML_DISABLE_DATABASE_MIGRATION, TEXT_FIELD_MAX_LENGTH, ) from zenml.enums import ( + AuthScheme, LoggingLevels, ModelStages, SecretScope, @@ -190,6 +199,8 @@ WorkspaceResponse, WorkspaceUpdate, ) +from zenml.models.v2.core.component import InternalComponentRequest +from zenml.models.v2.core.stack import InternalStackRequest from zenml.service_connectors.service_connector_registry import ( service_connector_registry, ) @@ -201,7 +212,6 @@ ) from zenml.utils.string_utils import random_str from zenml.zen_stores.base_zen_store import ( - DEFAULT_STACK_AND_COMPONENT_NAME, BaseZenStore, ) from zenml.zen_stores.enums import StoreEvent @@ -262,6 +272,7 @@ SelectOfScalar.inherit_cache = True Select.inherit_cache = True + logger = get_logger(__name__) ZENML_SQLITE_DB_FILENAME = "zenml.db" @@ -690,11 +701,6 @@ class SqlZenStore(BaseZenStore): _engine: Optional[Engine] = None _alembic: Optional[Alembic] = None - def _create_default_workspace(self) -> WorkspaceResponse: - workspace_name = self._default_workspace_name - logger.info(f"Creating default workspace '{workspace_name}' ...") - return self.create_workspace(WorkspaceRequest(name=workspace_name)) - @property def engine(self) -> Engine: """The SQLAlchemy engine. @@ -910,6 +916,19 @@ def _initialize(self) -> None: ): self.migrate_database() + def _initialize_database(self) -> None: + """Initialize the database on first use.""" + default_workspace = self._get_or_create_default_workspace() + + config = ServerConfiguration.get_server_config() + # If the auth scheme is external, don't create the default user + if config.auth_scheme != AuthScheme.EXTERNAL: + self._get_or_create_default_user() + + self._get_or_create_default_stack( + workspace=default_workspace, + ) + def _create_mysql_database( self, url: URL, @@ -4522,6 +4541,118 @@ def _fail_if_stack_with_name_exists( ) return None + def _get_default_stack( + self, + workspace_id: UUID, + ) -> StackResponse: + """Get the default stack for a user in a workspace. + + Args: + workspace_id: ID of the workspace. + + Returns: + The default stack in the workspace. + + Raises: + KeyError: if the workspace or default stack doesn't exist. + """ + default_stacks = self.list_stacks( + StackFilter( + workspace_id=workspace_id, + name=DEFAULT_STACK_AND_COMPONENT_NAME, + ) + ) + if default_stacks.total == 0: + raise KeyError( + f"No default stack found in workspace {workspace_id}." + ) + return default_stacks.items[0] + + def _create_default_stack( + self, + workspace_id: UUID, + ) -> StackResponse: + """Create the default stack components and stack. + + The default stack contains a local orchestrator and a local artifact + store. + + Args: + workspace_id: ID of the workspace to which the stack + belongs. + + Returns: + The model of the created default stack. + """ + with analytics_disabler(): + workspace = self.get_workspace(workspace_name_or_id=workspace_id) + + logger.info( + f"Creating default stack in workspace {workspace.name}..." + ) + + orchestrator = self.create_stack_component( + component=InternalComponentRequest( + # Passing `None` for the user here means the orchestrator + # is owned by the server, which for RBAC indicates that + # everyone can read it + user=None, + workspace=workspace.id, + name=DEFAULT_STACK_AND_COMPONENT_NAME, + type=StackComponentType.ORCHESTRATOR, + flavor="local", + configuration={}, + ), + ) + + artifact_store = self.create_stack_component( + component=InternalComponentRequest( + # Passing `None` for the user here means the stack is owned + # by the server, which for RBAC indicates that everyone can + # read it + user=None, + workspace=workspace.id, + name=DEFAULT_STACK_AND_COMPONENT_NAME, + type=StackComponentType.ARTIFACT_STORE, + flavor="local", + configuration={}, + ), + ) + + components = { + c.type: [c.id] for c in [orchestrator, artifact_store] + } + + stack = InternalStackRequest( + # Passing `None` for the user here means the stack is owned by + # the server, which for RBAC indicates that everyone can read it + user=None, + name=DEFAULT_STACK_AND_COMPONENT_NAME, + components=components, + workspace=workspace.id, + ) + return self.create_stack(stack=stack) + + def _get_or_create_default_stack( + self, workspace: WorkspaceResponse + ) -> StackResponse: + """Get or create the default stack if it doesn't exist. + + Args: + workspace: The workspace for which to create the default stack. + + Returns: + The default stack. + """ + try: + return self._get_default_stack( + workspace_id=workspace.id, + ) + except KeyError: + return self._create_default_stack( + workspace_id=workspace.id, + ) + # ----------------------------- Step runs ----------------------------- def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse: @@ -5203,6 +5334,38 @@ def delete_user(self, user_name_or_id: Union[str, UUID]) -> None: session.delete(user) session.commit() + @property + def _default_user_name(self) -> str: + """Get the default user name. + + Returns: + The default user name. + """ + return os.getenv(ENV_ZENML_DEFAULT_USER_NAME, DEFAULT_USERNAME) + + def _get_or_create_default_user(self) -> UserResponse: + """Get or create the default user if it doesn't exist. + + Returns: + The default user. + """ + default_user_name = self._default_user_name + try: + return self.get_user(default_user_name) + except KeyError: + password = os.getenv( + ENV_ZENML_DEFAULT_USER_PASSWORD, DEFAULT_PASSWORD + ) + + logger.info(f"Creating default user '{default_user_name}' ...") + return self.create_user( + UserRequest( + name=default_user_name, + active=True, + password=password, + ) + ) + # ----------------------------- Workspaces ----------------------------- @track_decorator(AnalyticsEvent.CREATED_WORKSPACE) @@ -5360,6 +5523,35 @@ def delete_workspace(self, workspace_name_or_id: Union[str, UUID]) -> None: session.delete(workspace) session.commit() + @property + def _default_workspace_name(self) -> str: + """Get the default workspace name. + + Returns: + The default workspace name. + """ + return os.getenv( + ENV_ZENML_DEFAULT_WORKSPACE_NAME, DEFAULT_WORKSPACE_NAME + ) + + def _get_or_create_default_workspace(self) -> WorkspaceResponse: + """Get or create the default workspace if it doesn't exist. + + Returns: + The default workspace. + """ + default_workspace_name = self._default_workspace_name + + try: + return self.get_workspace(default_workspace_name) + except KeyError: + logger.info( + f"Creating default workspace '{default_workspace_name}' ..." + ) + return self.create_workspace( + WorkspaceRequest(name=default_workspace_name) + ) + # ======================= # Internal helper methods # ======================= diff --git a/tests/harness/deployment/server_docker.py b/tests/harness/deployment/server_docker.py index 4555e9f893e..d267826763a 100644 --- a/tests/harness/deployment/server_docker.py +++ b/tests/harness/deployment/server_docker.py @@ -152,7 +152,7 @@ def get_store_config(self) -> Optional[DeploymentStoreConfig]: Raises: RuntimeError: If the deployment is not running. """ - from zenml.zen_stores.base_zen_store import ( + from zenml.constants import ( DEFAULT_PASSWORD, DEFAULT_USERNAME, ) diff --git a/tests/harness/deployment/server_docker_compose.py b/tests/harness/deployment/server_docker_compose.py index e07011a0643..39928ba725f 100644 --- a/tests/harness/deployment/server_docker_compose.py +++ b/tests/harness/deployment/server_docker_compose.py @@ -264,7 +264,7 @@ def get_store_config(self) -> Optional[DeploymentStoreConfig]: Raises: RuntimeError: If the deployment is not running. """ - from zenml.zen_stores.base_zen_store import ( + from zenml.constants import ( DEFAULT_PASSWORD, DEFAULT_USERNAME, ) diff --git a/tests/harness/deployment/server_local.py b/tests/harness/deployment/server_local.py index 5b3ad3e983c..1bbdaab19b5 100644 --- a/tests/harness/deployment/server_local.py +++ b/tests/harness/deployment/server_local.py @@ -162,7 +162,7 @@ def get_store_config(self) -> Optional[DeploymentStoreConfig]: Raises: RuntimeError: If the deployment is not running. """ - from zenml.zen_stores.base_zen_store import ( + from zenml.constants import ( DEFAULT_PASSWORD, DEFAULT_USERNAME, ) diff --git a/tests/integration/functional/cli/test_user_management.py b/tests/integration/functional/cli/test_user_management.py index 930f8ce1fef..88dcf056332 100644 --- a/tests/integration/functional/cli/test_user_management.py +++ b/tests/integration/functional/cli/test_user_management.py @@ -20,7 +20,7 @@ user_delete_command, user_update_command, ) -from zenml.zen_stores.base_zen_store import ( +from zenml.constants import ( DEFAULT_USERNAME, ) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 8d80903fb09..d6ebfbb3598 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -42,7 +42,14 @@ StubLocalRepositoryContext, ) from zenml.client import Client -from zenml.constants import ACTIVATE, DEACTIVATE, USERS +from zenml.constants import ( + ACTIVATE, + DEACTIVATE, + DEFAULT_STACK_AND_COMPONENT_NAME, + DEFAULT_USERNAME, + DEFAULT_WORKSPACE_NAME, + USERS, +) from zenml.enums import ( ColorVariants, ModelStages, @@ -104,11 +111,6 @@ _load_file_from_artifact_store, ) from zenml.utils.enum_utils import StrEnum -from zenml.zen_stores.base_zen_store import ( - DEFAULT_STACK_AND_COMPONENT_NAME, - DEFAULT_USERNAME, - DEFAULT_WORKSPACE_NAME, -) from zenml.zen_stores.sql_zen_store import SqlZenStore DEFAULT_NAME = "default" From 4d64e5d0dec0f55e05d37aee671685151f078646 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 12:47:17 +0100 Subject: [PATCH 072/103] Remove unused security_scopes args --- src/zenml/zen_server/auth.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index f7eecd7f8e5..5edad921028 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -25,7 +25,6 @@ HTTPBasic, HTTPBasicCredentials, OAuth2PasswordBearer, - SecurityScopes, ) from pydantic import BaseModel from starlette.requests import Request @@ -652,13 +651,11 @@ def authenticate_api_key( def http_authentication( - security_scopes: SecurityScopes, credentials: HTTPBasicCredentials = Depends(HTTPBasic()), ) -> AuthContext: """Authenticates any request to the ZenML Server with basic HTTP authentication. Args: - security_scopes: Security scope will be ignored for http_auth credentials: HTTP basic auth credentials passed to the request. Returns: @@ -705,7 +702,6 @@ async def __call__(self, request: Request) -> Optional[str]: def oauth2_authentication( - security_scopes: SecurityScopes, token: str = Depends( CookieOAuth2TokenBearer( tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN, @@ -715,7 +711,6 @@ def oauth2_authentication( """Authenticates any request to the ZenML server with OAuth2 JWT tokens. Args: - security_scopes: Security scope for this token token: The JWT bearer token to be authenticated. Returns: @@ -736,12 +731,9 @@ def oauth2_authentication( return auth_context -def no_authentication(security_scopes: SecurityScopes) -> AuthContext: +def no_authentication() -> AuthContext: """Doesn't authenticate requests to the ZenML server. - Args: - security_scopes: Security scope will be ignored for http_auth - Returns: The authentication context reflecting the default user. """ From 701e8d57564c5bcc77d8774a1ddad1301cbce3c5 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 13:09:11 +0100 Subject: [PATCH 073/103] Make flavors user scoped --- src/zenml/models/v2/core/component.py | 14 +++--- src/zenml/models/v2/core/flavor.py | 46 ++++++++----------- src/zenml/models/v2/core/stack.py | 14 +++--- src/zenml/stack/flavor.py | 3 -- src/zenml/stack/flavor_registry.py | 2 - .../7500f434b71c_remove_shared_columns.py | 4 +- 6 files changed, 36 insertions(+), 47 deletions(-) diff --git a/src/zenml/models/v2/core/component.py b/src/zenml/models/v2/core/component.py index 62e04aabfcc..6b5e64b04b4 100644 --- a/src/zenml/models/v2/core/component.py +++ b/src/zenml/models/v2/core/component.py @@ -124,6 +124,13 @@ def name_cant_be_a_secret_reference(cls, name: str) -> str: return name +@server_owned_request_model +class InternalComponentRequest(ComponentRequest): + """Internal component request model.""" + + pass + + # ------------------ Update Model ------------------ @@ -356,10 +363,3 @@ def generate_filter( type_filter = getattr(table, "type") == self.scope_type return and_(base_filter, type_filter) return base_filter - - -@server_owned_request_model -class InternalComponentRequest(ComponentRequest): - """Internal component request model.""" - - pass diff --git a/src/zenml/models/v2/core/flavor.py b/src/zenml/models/v2/core/flavor.py index 9678285229b..6ed122e67d9 100644 --- a/src/zenml/models/v2/core/flavor.py +++ b/src/zenml/models/v2/core/flavor.py @@ -20,26 +20,26 @@ from zenml.constants import STR_FIELD_MAX_LENGTH from zenml.enums import StackComponentType -from zenml.models.v2.base.base import ( - BaseRequest, - BaseResponse, - BaseResponseBody, - BaseResponseMetadata, +from zenml.models.v2.base.internal import server_owned_request_model +from zenml.models.v2.base.scoped import ( + UserScopedRequest, + UserScopedResponse, + UserScopedResponseBody, + UserScopedResponseMetadata, + WorkspaceScopedFilter, ) -from zenml.models.v2.base.scoped import WorkspaceScopedFilter from zenml.models.v2.base.update import update_model if TYPE_CHECKING: from zenml.models import ( ServiceConnectorRequirements, ) - from zenml.models.v2.core.user import UserResponse from zenml.models.v2.core.workspace import WorkspaceResponse # ------------------ Request Model ------------------ -class FlavorRequest(BaseRequest): +class FlavorRequest(UserScopedRequest): """Request model for flavors.""" ANALYTICS_FIELDS: ClassVar[List[str]] = [ @@ -98,14 +98,18 @@ class FlavorRequest(BaseRequest): title="Whether or not this flavor is a custom, user created flavor.", default=True, ) - user: Optional[UUID] = Field( - default=None, title="The id of the user that created this resource." - ) workspace: Optional[UUID] = Field( default=None, title="The workspace to which this resource belongs." ) +@server_owned_request_model +class InternalFlavorRequest(FlavorRequest): + """Internal flavor request model.""" + + pass + + # ------------------ Update Model ------------------ @@ -117,12 +121,9 @@ class FlavorUpdate(FlavorRequest): # ------------------ Response Model ------------------ -class FlavorResponseBody(BaseResponseBody): +class FlavorResponseBody(UserScopedResponseBody): """Response body for flavor.""" - user: Union["UserResponse", None] = Field( - title="The user that created this resource.", nullable=True - ) type: StackComponentType = Field(title="The type of the Flavor.") integration: Optional[str] = Field( title="The name of the integration that the Flavor belongs to.", @@ -135,7 +136,7 @@ class FlavorResponseBody(BaseResponseBody): ) -class FlavorResponseMetadata(BaseResponseMetadata): +class FlavorResponseMetadata(UserScopedResponseMetadata): """Response metadata for flavors.""" workspace: Optional["WorkspaceResponse"] = Field( @@ -180,7 +181,9 @@ class FlavorResponseMetadata(BaseResponseMetadata): ) -class FlavorResponse(BaseResponse[FlavorResponseBody, FlavorResponseMetadata]): +class FlavorResponse( + UserScopedResponse[FlavorResponseBody, FlavorResponseMetadata] +): """Response model for flavors.""" # Analytics @@ -229,15 +232,6 @@ def connector_requirements( ) # Body and metadata properties - @property - def user(self) -> Union["UserResponse", None]: - """The `user` property. - - Returns: - the value of the property. - """ - return self.get_body().user - @property def type(self) -> StackComponentType: """The `type` property. diff --git a/src/zenml/models/v2/core/stack.py b/src/zenml/models/v2/core/stack.py index 221501135af..cc01165937a 100644 --- a/src/zenml/models/v2/core/stack.py +++ b/src/zenml/models/v2/core/stack.py @@ -71,6 +71,13 @@ def is_valid(self) -> bool: ) +@server_owned_request_model +class InternalStackRequest(StackRequest): + """Internal stack request model.""" + + pass + + # ------------------ Update Model ------------------ @@ -241,10 +248,3 @@ class StackFilter(WorkspaceScopedFilter): component_id: Optional[Union[UUID, str]] = Field( default=None, description="Component in the stack" ) - - -@server_owned_request_model -class InternalStackRequest(StackRequest): - """Internal stack request model.""" - - pass diff --git a/src/zenml/stack/flavor.py b/src/zenml/stack/flavor.py index ac9f8c44b72..99ba0dcaf82 100644 --- a/src/zenml/stack/flavor.py +++ b/src/zenml/stack/flavor.py @@ -135,15 +135,12 @@ def from_model(cls, flavor_model: FlavorResponse) -> "Flavor": def to_model( self, integration: Optional[str] = None, - scoped_by_workspace: bool = True, is_custom: bool = True, ) -> FlavorRequest: """Converts a flavor to a model. Args: integration: The integration to use for the model. - scoped_by_workspace: Whether this flavor should live in the scope - of the active workspace is_custom: Whether the flavor is a custom flavor. Custom flavors are then scoped by user and workspace diff --git a/src/zenml/stack/flavor_registry.py b/src/zenml/stack/flavor_registry.py index 13d73ab0328..bc9baedb536 100644 --- a/src/zenml/stack/flavor_registry.py +++ b/src/zenml/stack/flavor_registry.py @@ -112,7 +112,6 @@ def register_builtin_flavors(self, store: BaseZenStore) -> None: for flavor in self.builtin_flavors: flavor_request_model = flavor().to_model( integration="built-in", - scoped_by_workspace=False, is_custom=False, ) existing_flavor = store.list_flavors( @@ -146,7 +145,6 @@ def register_integration_flavors(store: BaseZenStore) -> None: for flavor in integrated_flavors: flavor_request_model = flavor().to_model( integration=name, - scoped_by_workspace=False, is_custom=False, ) existing_flavor = store.list_flavors( diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py index 9465a59be88..60ac05bbabc 100644 --- a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -8,7 +8,7 @@ import base64 from collections import defaultdict from datetime import datetime -from typing import Optional, Set +from typing import Dict, Optional, Set from uuid import uuid4 import sqlalchemy as sa @@ -69,7 +69,7 @@ def _rename_duplicate_components(table: sa.Table) -> None: table.c.user_id, ) - names_per_type = defaultdict(lambda: {"default"}) + names_per_type: Dict[str, Set[str]] = defaultdict(lambda: {"default"}) for id, type_, name, user_id in connection.execute(query).fetchall(): names = names_per_type[type_] From cdf7f05d77d5585ac6ea3758a3e67e94846e3a9e Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 13:15:07 +0100 Subject: [PATCH 074/103] Fix mypy issues --- src/zenml/zen_stores/base_zen_store.py | 78 +++++++++++++++++++++++--- src/zenml/zen_stores/sql_zen_store.py | 40 ------------- 2 files changed, 71 insertions(+), 47 deletions(-) diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index 3ea24fd20ce..710a6d46c43 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -12,6 +12,7 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Base Zen Store implementation.""" +import os from abc import ABC from typing import ( Any, @@ -35,6 +36,9 @@ from zenml.config.server_config import ServerConfiguration from zenml.config.store_config import StoreConfiguration from zenml.constants import ( + DEFAULT_STACK_AND_COMPONENT_NAME, + DEFAULT_WORKSPACE_NAME, + ENV_ZENML_DEFAULT_WORKSPACE_NAME, IS_DEBUG_ENV, ) from zenml.enums import ( @@ -46,6 +50,7 @@ from zenml.models import ( ServerDatabaseType, ServerModel, + StackFilter, StackResponse, UserFilter, UserResponse, @@ -344,7 +349,7 @@ def validate_active_config( active_workspace_name_or_id ) except KeyError: - active_workspace = self._get_or_create_default_workspace() + active_workspace = self._get_default_workspace() logger.warning( f"The current {config_name} active workspace is no longer " @@ -352,7 +357,7 @@ def validate_active_config( f"'{active_workspace.name}'." ) else: - active_workspace = self._get_or_create_default_workspace() + active_workspace = self._get_default_workspace() logger.info( f"Setting the {config_name} active workspace " @@ -372,8 +377,8 @@ def validate_active_config( "Resetting the active stack to default.", config_name, ) - active_stack = self._get_or_create_default_stack( - active_workspace + active_stack = self._get_default_stack( + workspace_id=active_workspace.id ) else: if active_stack.workspace.id != active_workspace.id: @@ -382,15 +387,18 @@ def validate_active_config( "workspace. Resetting the active stack to default.", config_name, ) - active_stack = self._get_or_create_default_stack( - active_workspace + active_stack = self._get_default_stack( + workspace_id=active_workspace.id ) + else: logger.warning( "Setting the %s active stack to default.", config_name, ) - active_stack = self._get_or_create_default_stack(active_workspace) + active_stack = self._get_default_stack( + workspace_id=active_workspace.id + ) return active_workspace, active_stack @@ -423,6 +431,62 @@ def is_local_store(self) -> bool: """ return self.get_store_info().is_local() + # ----------------------------- + # Default workspaces and stacks + # ----------------------------- + + @property + def _default_workspace_name(self) -> str: + """Get the default workspace name. + + Returns: + The default workspace name. + """ + return os.getenv( + ENV_ZENML_DEFAULT_WORKSPACE_NAME, DEFAULT_WORKSPACE_NAME + ) + + def _get_default_workspace(self) -> WorkspaceResponse: + """Get the default workspace. + + Raises: + KeyError: If the default workspace doesn't exist. + + Returns: + The default workspace. + """ + try: + return self.get_workspace(self._default_workspace_name) + except KeyError: + raise KeyError("Unable to find default workspace.") + + def _get_default_stack( + self, + workspace_id: UUID, + ) -> StackResponse: + """Get the default stack for a user in a workspace. + + Args: + workspace_id: ID of the workspace. + + Returns: + The default stack in the workspace. + + Raises: + KeyError: if the workspace or default stack doesn't exist. + """ + default_stacks = self.list_stacks( + StackFilter( + workspace_id=workspace_id, + name=DEFAULT_STACK_AND_COMPONENT_NAME, + ) + ) + if default_stacks.total == 0: + raise KeyError( + f"No default stack found in workspace {workspace_id}." + ) + return default_stacks.items[0] + # -------------- # Event Handlers # -------------- diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index c9cb0afb150..2467cdc285f 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -63,10 +63,8 @@ DEFAULT_PASSWORD, DEFAULT_STACK_AND_COMPONENT_NAME, DEFAULT_USERNAME, - DEFAULT_WORKSPACE_NAME, ENV_ZENML_DEFAULT_USER_NAME, ENV_ZENML_DEFAULT_USER_PASSWORD, - ENV_ZENML_DEFAULT_WORKSPACE_NAME, ENV_ZENML_DISABLE_DATABASE_MIGRATION, TEXT_FIELD_MAX_LENGTH, ) @@ -4541,33 +4539,6 @@ def _fail_if_stack_with_name_exists( ) return None - def _get_default_stack( - self, - workspace_id: UUID, - ) -> StackResponse: - """Get the default stack for a user in a workspace. - - Args: - workspace_id: ID of the workspace. - - Returns: - The default stack in the workspace. - - Raises: - KeyError: if the workspace or default stack doesn't exist. - """ - default_stacks = self.list_stacks( - StackFilter( - workspace_id=workspace_id, - name=DEFAULT_STACK_AND_COMPONENT_NAME, - ) - ) - if default_stacks.total == 0: - raise KeyError( - f"No default stack found in workspace {workspace_id}." - ) - return default_stacks.items[0] - def _create_default_stack( self, workspace_id: UUID, @@ -5523,17 +5494,6 @@ def delete_workspace(self, workspace_name_or_id: Union[str, UUID]) -> None: session.delete(workspace) session.commit() - @property - def _default_workspace_name(self) -> str: - """Get the default workspace name. - - Returns: - The default workspace name. - """ - return os.getenv( - ENV_ZENML_DEFAULT_WORKSPACE_NAME, DEFAULT_WORKSPACE_NAME - ) - def _get_or_create_default_workspace(self) -> WorkspaceResponse: """Get or create the default workspace if it doesn't exist. From d55cb379b16dc922b60ce469ea3be440f61c31ac Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 13:17:53 +0100 Subject: [PATCH 075/103] Use internal flavor request model --- src/zenml/stack/flavor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/zenml/stack/flavor.py b/src/zenml/stack/flavor.py index 99ba0dcaf82..6fba8ef9689 100644 --- a/src/zenml/stack/flavor.py +++ b/src/zenml/stack/flavor.py @@ -22,6 +22,7 @@ FlavorResponse, ServiceConnectorRequirements, ) +from zenml.models.v2.core.flavor import InternalFlavorRequest from zenml.stack.stack_component import StackComponent, StackComponentConfig from zenml.utils import source_utils @@ -166,7 +167,8 @@ def to_model( if connector_requirements else None ) - model = FlavorRequest( + model_class = FlavorRequest if is_custom else InternalFlavorRequest + model = model_class( user=client.active_user.id if is_custom else None, workspace=client.active_workspace.id if is_custom else None, name=self.name, From 139f8072a11f398aac2b8d6e9ebfd9006000d1d8 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 14:29:03 +0100 Subject: [PATCH 076/103] Re-add delete user endpoint --- .../zen_server/routers/users_endpoints.py | 35 ++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/src/zenml/zen_server/routers/users_endpoints.py b/src/zenml/zen_server/routers/users_endpoints.py index 97233db3a2d..58ad5c26e83 100644 --- a/src/zenml/zen_server/routers/users_endpoints.py +++ b/src/zenml/zen_server/routers/users_endpoints.py @@ -28,7 +28,7 @@ VERSION_1, ) from zenml.enums import AuthScheme -from zenml.exceptions import AuthorizationException +from zenml.exceptions import AuthorizationException, IllegalOperationError from zenml.logger import get_logger from zenml.models import ( Page, @@ -282,6 +282,39 @@ def deactivate_user( user.get_body().activation_token = token return user + @router.delete( + "/{user_name_or_id}", + responses={ + 401: error_response, + 404: error_response, + 422: error_response, + }, + ) + @handle_exceptions + def delete_user( + user_name_or_id: Union[str, UUID], + auth_context: AuthContext = Security(authorize), + ) -> None: + """Deletes a specific user. + + Args: + user_name_or_id: Name or ID of the user. + auth_context: The authentication context. + + Raises: + IllegalOperationError: If the user is not authorized to delete the user. + """ + user = zen_store().get_user(user_name_or_id) + + if auth_context.user.name == user.name: + raise IllegalOperationError( + "You cannot delete the user account currently used to authenticate " + "to the ZenML server. If you wish to delete this account, " + "please authenticate with another account or contact your ZenML " + "administrator." + ) + zen_store().delete_user(user_name_or_id=user_name_or_id) + @router.put( "/{user_name_or_id}" + EMAIL_ANALYTICS, response_model=UserResponse, From 33a4105493cd97c3700431f0d068720db526844b Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 14:29:33 +0100 Subject: [PATCH 077/103] Enable user/service account deletion tests for rest --- .../functional/zen_stores/test_zen_store.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index d6ebfbb3598..38ca0e308b0 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -442,11 +442,6 @@ def test_delete_user_with_resources_fails(): """Tests deleting a user with resources fails.""" zen_store = Client().zen_store - if zen_store.type != StoreType.SQL: - pytest.skip( - "Only SQL Zen Stores allow creating resources for other accounts." - ) - with UserContext(delete=False) as user: with ComponentContext( c_type=StackComponentType.ORCHESTRATOR, @@ -699,11 +694,6 @@ def test_delete_service_account_with_resources_fails(): """Tests deleting a service account with resources fails.""" zen_store = Client().zen_store - if zen_store.type != StoreType.SQL: - pytest.skip( - "Only SQL Zen Stores allow creating resources for other accounts." - ) - with ServiceAccountContext(delete=False) as service_account: with ComponentContext( c_type=StackComponentType.ORCHESTRATOR, From 0e8d890682696ffff8e751035c247e4d0a4d310d Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 14:46:47 +0100 Subject: [PATCH 078/103] Use IllegalOperationError instead of HTTP exception --- src/zenml/zen_server/rbac/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 16ca216e3e9..80d760751a5 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -36,6 +36,7 @@ Page, UserScopedResponse, ) +from zenml.exceptions import IllegalOperationError from zenml.models.base_models import BaseResponseModel, UserScopedResponseModel from zenml.zen_server.auth import get_auth_context from zenml.zen_server.rbac.models import Action, Resource, ResourceType @@ -326,7 +327,7 @@ def batch_verify_permissions( action: The action the user wants to perform. Raises: - HTTPException: If the user is not allowed to perform the action. + IllegalOperationError: If the user is not allowed to perform the action. RuntimeError: If the permission verification failed unexpectedly. """ if not server_config().rbac_enabled: @@ -349,7 +350,7 @@ def batch_verify_permissions( ) if not permissions[resource]: - raise HTTPException( + raise IllegalOperationError( status_code=403, detail=f"Insufficient permissions to {action.upper()} resource " f"'{resource}'.", From 34105ed3c2531377fd1234234fc7dc5baf6cd76b Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 15:29:20 +0100 Subject: [PATCH 079/103] Remove unused property --- src/zenml/models/base_models.py | 32 ------------------------------ src/zenml/zen_server/rbac/utils.py | 2 +- 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/src/zenml/models/base_models.py b/src/zenml/models/base_models.py index 2a9f6143893..07d3aae0957 100644 --- a/src/zenml/models/base_models.py +++ b/src/zenml/models/base_models.py @@ -17,8 +17,6 @@ TYPE_CHECKING, Any, Dict, - List, - Set, Type, TypeVar, Union, @@ -118,36 +116,6 @@ def get_analytics_metadata(self) -> Dict[str, Any]: metadata["entity_id"] = self.id return metadata - @property - def partial(self) -> bool: - """Returns if this model is incomplete. - - A model is incomplete if the user has no permissions to read the - model itself or any submodel contained in this model. - - Returns: - True if the model is incomplete, False otherwise. - """ - if self.missing_permissions: - return True - - def _helper(value: Any) -> bool: - if isinstance(value, BaseResponseModel): - return value.partial - elif isinstance(value, Dict): - return any(_helper(v) for v in value.values()) - elif isinstance(value, (List, Set, tuple)): - return any(_helper(v) for v in value) - else: - return False - - for field_name in self.__fields__.keys(): - value = getattr(self, field_name) - if _helper(value): - return True - - return False - class UserScopedResponseModel(BaseResponseModel): """Base user-owned domain model. diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 80d760751a5..8aaf6ea9307 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -31,12 +31,12 @@ from fastapi import HTTPException from pydantic import BaseModel +from zenml.exceptions import IllegalOperationError from zenml.models import ( BaseResponse, Page, UserScopedResponse, ) -from zenml.exceptions import IllegalOperationError from zenml.models.base_models import BaseResponseModel, UserScopedResponseModel from zenml.zen_server.auth import get_auth_context from zenml.zen_server.rbac.models import Action, Resource, ResourceType From 27aaf84ab813406d34c97a3069d849c2d1baf166 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 15:34:44 +0100 Subject: [PATCH 080/103] Raise different errors depending on permission denied and missing body --- src/zenml/models/v2/base/base.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/zenml/models/v2/base/base.py b/src/zenml/models/v2/base/base.py index 33ec7cc2691..234aec740ce 100644 --- a/src/zenml/models/v2/base/base.py +++ b/src/zenml/models/v2/base/base.py @@ -260,13 +260,20 @@ def get_body(self) -> AnyBody: Raises: IllegalOperationError: If the user lacks permission to access the entity represented by this response. + RuntimeError: If the body was not included in the response. """ - if not self.body: + if self.permission_denied: raise IllegalOperationError( f"Missing permissions to access {type(self).__name__} with " f"ID {self.id}." ) + if not self.body: + raise RuntimeError( + f"Missing response body for {type(self).__name__} with ID " + f"{self.id}." + ) + return self.body def get_metadata(self) -> "AnyMetadata": @@ -279,7 +286,7 @@ def get_metadata(self) -> "AnyMetadata": IllegalOperationError: If the user lacks permission to access this entity represented by this response. """ - if not self.body: + if self.permission_denied: raise IllegalOperationError( f"Missing permissions to access {type(self).__name__} with " f"ID {self.id}." From 241651bba949e2ce38e1d0efe2ef870aea66dd58 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 16:16:49 +0100 Subject: [PATCH 081/103] Fix wrong error args --- src/zenml/zen_server/rbac/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 8aaf6ea9307..18c9885f7d0 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -351,9 +351,8 @@ def batch_verify_permissions( if not permissions[resource]: raise IllegalOperationError( - status_code=403, - detail=f"Insufficient permissions to {action.upper()} resource " - f"'{resource}'.", + message=f"Insufficient permissions to {action.upper()} " + f"resource '{resource}'.", ) From dca2a0b05ac88bc10ee303d5900faccc506f436f Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 17:30:16 +0100 Subject: [PATCH 082/103] Don't fail early for service accounts --- src/zenml/zen_server/rbac/zenml_cloud_rbac.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py index 0e70ae20670..6337ead45aa 100644 --- a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -144,8 +144,6 @@ def check_permissions( A dictionary mapping resources to a boolean which indicates whether the user has permissions to perform the action on that resource. """ - assert user.external_user_id - if not resources: # No need to send a request if there are no resources return {} @@ -154,6 +152,10 @@ def check_permissions( # Service accounts have full permissions for now return {resource: True for resource in resources} + # At this point its a regular user, which in the ZenML cloud with RBAC + # enabled is always authenticated using external authentication + assert user.external_user_id + params = { "user_id": str(user.external_user_id), "resources": [ @@ -186,12 +188,13 @@ def list_allowed_resource_ids( the action on. """ assert not resource.id - assert user.external_user_id - if user.is_service_account: # Service accounts have full permissions for now return True, [] + # At this point its a regular user, which in the ZenML cloud with RBAC + # enabled is always authenticated using external authentication + assert user.external_user_id params = { "user_id": str(user.external_user_id), "resource": _convert_to_cloud_resource(resource), From dafdce3f7094ecd21ef83988eaa0fc66f5e88a24 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Wed, 22 Nov 2023 20:38:22 +0100 Subject: [PATCH 083/103] Fix stack and flavor schemas to ignore user and workspace during update --- src/zenml/zen_stores/schemas/flavor_schemas.py | 4 +++- src/zenml/zen_stores/schemas/stack_schemas.py | 11 +++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/zenml/zen_stores/schemas/flavor_schemas.py b/src/zenml/zen_stores/schemas/flavor_schemas.py index 11c5da98c4e..0ada269101a 100644 --- a/src/zenml/zen_stores/schemas/flavor_schemas.py +++ b/src/zenml/zen_stores/schemas/flavor_schemas.py @@ -92,7 +92,9 @@ def update(self, flavor_update: "FlavorUpdate") -> "FlavorSchema": Returns: The updated `FlavorSchema`. """ - for field, value in flavor_update.dict(exclude_unset=True).items(): + for field, value in flavor_update.dict( + exclude_unset=True, exclude={"workspace", "user"} + ).items(): if field == "config_schema": setattr(self, field, json.dumps(value)) else: diff --git a/src/zenml/zen_stores/schemas/stack_schemas.py b/src/zenml/zen_stores/schemas/stack_schemas.py index a56ca365c2c..687bef7f9cd 100644 --- a/src/zenml/zen_stores/schemas/stack_schemas.py +++ b/src/zenml/zen_stores/schemas/stack_schemas.py @@ -115,16 +115,11 @@ def update( Returns: The updated StackSchema. """ - for field, value in stack_update.dict(exclude_unset=True).items(): + for field, value in stack_update.dict( + exclude_unset=True, exclude={"workspace", "user"} + ).items(): if field == "components": self.components = components - - elif field == "user": - assert self.user_id == value - - elif field == "workspace": - assert self.workspace_id == value - else: setattr(self, field, value) From ea0eb25cf264a8929cc8e2506d9d62aca5fe0268 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 22 Nov 2023 21:46:06 +0100 Subject: [PATCH 084/103] Catch correct error --- src/zenml/zen_server/rbac/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 18c9885f7d0..97e906da0f8 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -182,7 +182,7 @@ def has_permissions_for_model(model: AnyResponseModel, action: str) -> bool: try: verify_permission_for_model(model=model, action=action) return True - except HTTPException: + except IllegalOperationError: return False From 54f5137cf96c22ad34ceec14acb835d09b7dba01 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 23 Nov 2023 09:39:17 +0100 Subject: [PATCH 085/103] Apply rbac to new endpoint --- src/zenml/zen_server/rbac/utils.py | 1 - src/zenml/zen_server/routers/artifacts_endpoints.py | 10 ++++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 97e906da0f8..3c6852b8ffa 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -28,7 +28,6 @@ ) from uuid import UUID -from fastapi import HTTPException from pydantic import BaseModel from zenml.exceptions import IllegalOperationError diff --git a/src/zenml/zen_server/routers/artifacts_endpoints.py b/src/zenml/zen_server/routers/artifacts_endpoints.py index 79ba42cdef0..208c0cf83ae 100644 --- a/src/zenml/zen_server/routers/artifacts_endpoints.py +++ b/src/zenml/zen_server/routers/artifacts_endpoints.py @@ -34,6 +34,7 @@ verify_permissions_and_delete_entity, verify_permissions_and_get_entity, verify_permissions_and_list_entities, + verify_permissions_and_update_entity, ) from zenml.zen_server.rbac.models import ResourceType from zenml.zen_server.utils import ( @@ -141,7 +142,7 @@ def get_artifact( def update_artifact( artifact_id: UUID, artifact_update: ArtifactUpdate, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), + _: AuthContext = Security(authorize), ) -> ArtifactResponse: """Update an artifact by ID. @@ -152,7 +153,12 @@ def update_artifact( Returns: The updated artifact. """ - return zen_store().update_artifact(artifact_id, artifact_update) + return verify_permissions_and_update_entity( + id=artifact_id, + update_model=artifact_update, + get_method=zen_store().get_artifact, + update_method=zen_store().update_artifact, + ) @router.delete( From 2c5e457b02300d9c9f5f264d9c82023135859024 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 23 Nov 2023 10:04:19 +0100 Subject: [PATCH 086/103] Fix alembic order --- .../migrations/versions/389046140cad_data_versioning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_stores/migrations/versions/389046140cad_data_versioning.py b/src/zenml/zen_stores/migrations/versions/389046140cad_data_versioning.py index c4968c65f3d..0a30d726650 100644 --- a/src/zenml/zen_stores/migrations/versions/389046140cad_data_versioning.py +++ b/src/zenml/zen_stores/migrations/versions/389046140cad_data_versioning.py @@ -1,7 +1,7 @@ """Data Versioning [389046140cad]. Revision ID: 389046140cad -Revises: 14d687c8fa1c +Revises: 86fa52918b54 Create Date: 2023-10-09 14:12:01.280877 """ @@ -11,7 +11,7 @@ # revision identifiers, used by Alembic. revision = "389046140cad" -down_revision = "14d687c8fa1c" +down_revision = "86fa52918b54" branch_labels = None depends_on = None From 5c53679e078682900d642488d8527dfef2636467 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 23 Nov 2023 10:59:29 +0100 Subject: [PATCH 087/103] Create default stack when creating workspace --- src/zenml/zen_stores/sql_zen_store.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index ab71c2b9ec4..498ac4d95ec 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -919,17 +919,13 @@ def _initialize(self) -> None: def _initialize_database(self) -> None: """Initialize the database on first use.""" - default_workspace = self._get_or_create_default_workspace() + self._get_or_create_default_workspace() config = ServerConfiguration.get_server_config() # If the auth scheme is external, don't create the default user if config.auth_scheme != AuthScheme.EXTERNAL: self._get_or_create_default_user() - self._get_or_create_default_stack( - workspace=default_workspace, - ) - def _create_mysql_database( self, url: URL, @@ -5486,7 +5482,10 @@ def create_workspace( # Explicitly refresh the new_workspace schema session.refresh(new_workspace) - return new_workspace.to_model(hydrate=True) + workspace_model = new_workspace.to_model(hydrate=True) + + self._get_or_create_default_stack(workspace=workspace_model) + return workspace_model def get_workspace( self, workspace_name_or_id: Union[str, UUID], hydrate: bool = True From 198545c754ac70f00da2e2361c5a1b5230524fbe Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 23 Nov 2023 13:18:38 +0100 Subject: [PATCH 088/103] Fix workspace statistics inconsistency issue --- src/zenml/models/model_models.py | 7 +-- src/zenml/models/v2/base/filter.py | 29 +++++---- src/zenml/models/v2/base/scoped.py | 13 ++-- src/zenml/models/v2/core/api_key.py | 8 +-- src/zenml/models/v2/core/service_account.py | 8 +-- src/zenml/models/v2/core/user.py | 7 +-- .../routers/workspaces_endpoints.py | 37 +++++++++-- src/zenml/zen_stores/sql_zen_store.py | 63 ++++++++++--------- 8 files changed, 99 insertions(+), 73 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index b8ab0110735..e625fa8319a 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -40,13 +40,12 @@ ) from zenml.models.model_base_model import ModelBaseModel from zenml.models.tag_models import TagResponseModel +from zenml.models.v2.base.filter import AnyQuery from zenml.models.v2.base.scoped import WorkspaceScopedFilter from zenml.models.v2.core.artifact import ArtifactResponse from zenml.models.v2.core.pipeline_run import PipelineRunResponse if TYPE_CHECKING: - from sqlmodel.sql.expression import Select, SelectOfScalar - from zenml.model.model_version import ModelVersion from zenml.zen_stores.schemas import BaseSchema @@ -97,9 +96,9 @@ def set_scope_model(self, model_name_or_id: Union[str, UUID]) -> None: def apply_filter( self, - query: Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"], + query: AnyQuery, table: Type["AnySchema"], - ) -> Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"]: + ) -> AnyQuery: """Applies the filter to a query. Args: diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index b580801e991..e0384699e25 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -46,7 +46,6 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList - from sqlmodel.sql.expression import Select, SelectOfScalar from zenml.zen_stores.schemas import BaseSchema @@ -55,6 +54,9 @@ logger = get_logger(__name__) +AnyQuery = TypeVar("AnyQuery", bound=Any) + + class Filter(BaseModel, ABC): """Filter for all fields. @@ -402,24 +404,22 @@ def configure_rbac( """ self._rbac_configuration = (authenticated_user_id, column_allowed_ids) - def apply_rbac_filter( + def generate_rbac_filter( self, - query: Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"], table: Type["AnySchema"], - ) -> Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"]: - """Applies the RBAC filter to a query. + ) -> Optional["BooleanClauseList[Any]"]: + """Generates an optional RBAC filter. Args: - query: The query to which to apply the filter. table: The query table. Returns: - The query with RBAC filter applied. + The RBAC filter. """ from sqlmodel import or_ if not self._rbac_configuration: - return query + return None expressions = [] @@ -442,9 +442,9 @@ def apply_rbac_filter( ) if expressions: - return query.where(or_(*expressions)) + return or_(*expressions) else: - return query + return None @classmethod def _generate_filter_list(cls, values: Dict[str, Any]) -> List[Filter]: @@ -798,9 +798,9 @@ def generate_filter( def apply_filter( self, - query: Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"], + query: AnyQuery, table: Type["AnySchema"], - ) -> Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"]: + ) -> AnyQuery: """Applies the filter to a query. Args: @@ -810,7 +810,10 @@ def apply_filter( Returns: The query with filter applied. """ - query = self.apply_rbac_filter(query, table=table) + rbac_filter = self.generate_rbac_filter(table=table) + + if rbac_filter is not None: + query = query.where(rbac_filter) filters = self.generate_filter(table=table) diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index 805d20ee90d..f2166ee6fa2 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -23,12 +23,10 @@ Optional, Type, TypeVar, - Union, ) from uuid import UUID from pydantic import Field -from sqlmodel.sql.expression import Select, SelectOfScalar from zenml.models.v2.base.base import ( BaseRequest, @@ -36,7 +34,7 @@ BaseResponseBody, BaseResponseMetadata, ) -from zenml.models.v2.base.filter import BaseFilter +from zenml.models.v2.base.filter import AnyQuery, BaseFilter if TYPE_CHECKING: from zenml.models.v2.core.user import UserResponse @@ -45,7 +43,6 @@ AnySchema = TypeVar("AnySchema", bound=BaseSchema) - # ---------------------- Request Models ---------------------- @@ -166,9 +163,9 @@ def set_scope_user(self, user_id: UUID) -> None: def apply_filter( self, - query: Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"], + query: AnyQuery, table: Type["AnySchema"], - ) -> Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"]: + ) -> AnyQuery: """Applies the filter to a query. Args: @@ -258,9 +255,9 @@ def set_scope_workspace(self, workspace_id: UUID) -> None: def apply_filter( self, - query: Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"], + query: AnyQuery, table: Type["AnySchema"], - ) -> Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"]: + ) -> AnyQuery: """Applies the filter to a query. Args: diff --git a/src/zenml/models/v2/core/api_key.py b/src/zenml/models/v2/core/api_key.py index 92373e16ca9..c4cddd5a634 100644 --- a/src/zenml/models/v2/core/api_key.py +++ b/src/zenml/models/v2/core/api_key.py @@ -31,13 +31,11 @@ BaseResponseBody, BaseResponseMetadata, ) -from zenml.models.v2.base.filter import BaseFilter +from zenml.models.v2.base.filter import AnyQuery, BaseFilter from zenml.models.v2.base.update import update_model from zenml.utils.string_utils import b64_decode, b64_encode if TYPE_CHECKING: - from sqlmodel.sql.expression import Select, SelectOfScalar - from zenml.models.v2.base.filter import AnySchema from zenml.models.v2.core.service_account import ServiceAccountResponse @@ -360,9 +358,9 @@ def set_service_account(self, service_account_id: UUID) -> None: def apply_filter( self, - query: Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"], + query: AnyQuery, table: Type["AnySchema"], - ) -> Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"]: + ) -> AnyQuery: """Override to apply the service account scope as an additional filter. Args: diff --git a/src/zenml/models/v2/core/service_account.py b/src/zenml/models/v2/core/service_account.py index d78a748284d..83aad8d30a0 100644 --- a/src/zenml/models/v2/core/service_account.py +++ b/src/zenml/models/v2/core/service_account.py @@ -24,12 +24,10 @@ BaseResponseBody, BaseResponseMetadata, ) -from zenml.models.v2.base.filter import BaseFilter +from zenml.models.v2.base.filter import AnyQuery, BaseFilter from zenml.models.v2.base.update import update_model if TYPE_CHECKING: - from sqlmodel.sql.expression import Select, SelectOfScalar - from zenml.models.v2.base.filter import AnySchema from zenml.models.v2.core.user import UserResponse @@ -183,9 +181,9 @@ class ServiceAccountFilter(BaseFilter): def apply_filter( self, - query: Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"], + query: AnyQuery, table: Type["AnySchema"], - ) -> Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"]: + ) -> AnyQuery: """Override to filter out user accounts from the query. Args: diff --git a/src/zenml/models/v2/core/user.py b/src/zenml/models/v2/core/user.py index 80490e6fef0..e3b695eb6c7 100644 --- a/src/zenml/models/v2/core/user.py +++ b/src/zenml/models/v2/core/user.py @@ -35,12 +35,11 @@ BaseResponseBody, BaseResponseMetadata, ) -from zenml.models.v2.base.filter import BaseFilter +from zenml.models.v2.base.filter import AnyQuery, BaseFilter from zenml.models.v2.base.update import update_model if TYPE_CHECKING: from passlib.context import CryptContext - from sqlmodel.sql.expression import Select, SelectOfScalar from zenml.models.v2.base.filter import AnySchema @@ -405,9 +404,9 @@ class UserFilter(BaseFilter): def apply_filter( self, - query: Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"], + query: AnyQuery, table: Type["AnySchema"], - ) -> Union["Select[AnySchema]", "SelectOfScalar[AnySchema]"]: + ) -> AnyQuery: """Override to filter out service accounts from the query. Args: diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 4d254a07a59..6cfd032d67d 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1025,7 +1025,7 @@ def create_code_repository( @handle_exceptions def get_workspace_statistics( workspace_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize), + auth_context: AuthContext = Security(authorize), ) -> Dict[str, int]: """Gets statistics of a workspace. @@ -1039,13 +1039,40 @@ def get_workspace_statistics( """ workspace = zen_store().get_workspace(workspace_name_or_id) + user_id = auth_context.user.id + component_filter = ComponentFilter(workspace_id=workspace.id) + component_filter.configure_rbac( + authenticated_user_id=user_id, + id=get_allowed_resource_ids( + resource_type=ResourceType.STACK_COMPONENT + ), + ) + + stack_filter = StackFilter(workspace_id=workspace.id) + stack_filter.configure_rbac( + authenticated_user_id=user_id, + id=get_allowed_resource_ids(resource_type=ResourceType.STACK), + ) + + run_filter = PipelineRunFilter(workspace_id=workspace.id) + run_filter.configure_rbac( + authenticated_user_id=user_id, + id=get_allowed_resource_ids(resource_type=ResourceType.PIPELINE_RUN), + ) + + pipeline_filter = PipelineFilter(workspace_id=workspace.id) + pipeline_filter.configure_rbac( + authenticated_user_id=user_id, + id=get_allowed_resource_ids(resource_type=ResourceType.PIPELINE), + ) + return { - "stacks": zen_store().count_stacks(workspace_id=workspace.id), + "stacks": zen_store().count_stacks(filter_model=stack_filter), "components": zen_store().count_stack_components( - workspace_id=workspace.id + filter_model=component_filter ), - "pipelines": zen_store().count_pipelines(workspace_id=workspace.id), - "runs": zen_store().count_runs(workspace_id=workspace.id), + "pipelines": zen_store().count_pipelines(filter_model=pipeline_filter), + "runs": zen_store().count_runs(filter_model=run_filter), } diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 498ac4d95ec..3b6168ea7c1 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -2032,17 +2032,19 @@ def delete_stack_component(self, component_id: UUID) -> None: session.commit() - def count_stack_components(self, workspace_id: Optional[UUID]) -> int: - """Count all components, optionally within a workspace scope. + def count_stack_components( + self, filter_model: Optional[ComponentFilter] = None + ) -> int: + """Count all components. Args: - workspace_id: The workspace to use for counting components + filter_model: The filter model to use for counting components. Returns: - The number of components in the workspace. + The number of components. """ return self._count_entity( - schema=StackComponentSchema, workspace_id=workspace_id + schema=StackComponentSchema, filter_model=filter_model ) @staticmethod @@ -2074,7 +2076,7 @@ def _fail_if_component_with_name_type_exists( ).first() if existing_domain_component is not None: raise StackComponentExistsError( - f"Unable to register '{component_type.value}' component " + f"Unable to register '{component_type}' component " f"with name '{name}': Found an existing " f"component with the same name and type in the same " f" workspace '{existing_domain_component.workspace.name}'." @@ -2661,17 +2663,17 @@ def list_pipelines( hydrate=hydrate, ) - def count_pipelines(self, workspace_id: Optional[UUID]) -> int: - """Count all pipelines, optionally within a workspace scope. + def count_pipelines(self, filter_model: Optional[PipelineFilter]) -> int: + """Count all pipelines. Args: - workspace_id: The workspace to use for counting pipelines + filter_model: The filter model to use for counting pipelines. Returns: - The number of pipelines in the workspace. + The number of pipelines. """ return self._count_entity( - schema=PipelineSchema, workspace_id=workspace_id + schema=PipelineSchema, filter_model=filter_model ) def update_pipeline( @@ -3128,17 +3130,17 @@ def get_or_create_run( except KeyError: return self.get_run(pipeline_run.name), False - def count_runs(self, workspace_id: Optional[UUID]) -> int: - """Count all pipeline runs, optionally within a workspace scope. + def count_runs(self, filter_model: Optional[PipelineRunFilter]) -> int: + """Count all pipeline runs. Args: - workspace_id: The workspace to use for counting pipeline runs + filter model: The filter model to filter the runs. Returns: - The number of pipeline runs in the workspace. + The number of pipeline runs. """ return self._count_entity( - schema=PipelineRunSchema, workspace_id=workspace_id + schema=PipelineRunSchema, filter_model=filter_model ) # ----------------------------- Run Metadata ----------------------------- @@ -4560,17 +4562,17 @@ def delete_stack(self, stack_id: UUID) -> None: session.commit() - def count_stacks(self, workspace_id: Optional[UUID]) -> int: - """Count all stacks, optionally within a workspace scope. + def count_stacks(self, filter_model: Optional[StackFilter]) -> int: + """Count all stacks. Args: - workspace_id: The workspace to use for counting stacks + filter model: The filter model to filter the stacks. Returns: - The number of stacks in the workspace. + The number of stacks. """ return self._count_entity( - schema=StackSchema, workspace_id=workspace_id + schema=StackSchema, filter_model=filter_model ) def _fail_if_stack_with_name_exists( @@ -5627,23 +5629,26 @@ def _get_or_create_default_workspace(self) -> WorkspaceResponse: # ======================= def _count_entity( - self, schema: Type[BaseSchema], workspace_id: Optional[UUID] + self, + schema: Type[BaseSchema], + filter_model: Optional[BaseFilter] = None, ) -> int: - """Return count of a given entity, optionally scoped to workspace. + """Return count of a given entity. Args: schema: Schema of the Entity - workspace_id: (Optional) ID of the workspace scope - + filter model: The filter model to filter the entity table. Returns: Count of the entity as integer. """ with Session(self.engine) as session: - query = session.query(func.count(schema.id)) - if workspace_id and hasattr(schema, "workspace_id"): - query = query.filter(schema.workspace_id == workspace_id) + query = select([func.count(schema.id)]) + + if filter_model: + query = filter_model.apply_filter(query=query, table=schema) + + entity_count = session.scalar(query) - entity_count = query.scalar() return int(entity_count) @staticmethod From ded578dd140a3fb56c2e63baa414bf04f99238d4 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 23 Nov 2023 16:07:42 +0100 Subject: [PATCH 089/103] Fix some docstrings --- src/zenml/zen_server/routers/workspaces_endpoints.py | 1 + src/zenml/zen_stores/sql_zen_store.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 6cfd032d67d..90e4d0a0e5f 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1033,6 +1033,7 @@ def get_workspace_statistics( Args: workspace_name_or_id: Name or ID of the workspace to get statistics for. + auth_context: Authentication context. Returns: All pipelines within the workspace. diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 3b6168ea7c1..3e9f979731d 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -3134,7 +3134,7 @@ def count_runs(self, filter_model: Optional[PipelineRunFilter]) -> int: """Count all pipeline runs. Args: - filter model: The filter model to filter the runs. + filter_model: The filter model to filter the runs. Returns: The number of pipeline runs. @@ -4566,7 +4566,7 @@ def count_stacks(self, filter_model: Optional[StackFilter]) -> int: """Count all stacks. Args: - filter model: The filter model to filter the stacks. + filter_model: The filter model to filter the stacks. Returns: The number of stacks. @@ -5637,7 +5637,8 @@ def _count_entity( Args: schema: Schema of the Entity - filter model: The filter model to filter the entity table. + filter_model: The filter model to filter the entity table. + Returns: Count of the entity as integer. """ From b09b7b07266bdc0c207946e7009b6cc539b23ab2 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 23 Nov 2023 16:37:05 +0100 Subject: [PATCH 090/103] Fix some tests --- .../functional/zen_stores/test_zen_store.py | 30 ++++++------------- .../functional/zen_stores/utils.py | 1 + 2 files changed, 10 insertions(+), 21 deletions(-) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 8cef3278c6d..621a04aa958 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -1970,23 +1970,15 @@ def test_count_stack_components(): if not isinstance(store, SqlZenStore): pytest.skip("Test only applies to SQL store") active_workspace = client.active_workspace + filter_model = ComponentFilter(scope_workspace=active_workspace.id) + count_before = store.list_stack_components(filter_model).total - count_before = store.list_stack_components( - ComponentFilter(scope_workspace=active_workspace.id) - ).total - - assert ( - store.count_stack_components(workspace_id=active_workspace.id) - == count_before - ) + assert store.count_stack_components(filter_model) == count_before with ComponentContext( StackComponentType.ARTIFACT_STORE, config={}, flavor="s3" ): - assert ( - store.count_stack_components(workspace_id=active_workspace.id) - == count_before + 1 - ) + assert store.count_stack_components(filter_model) == count_before + 1 # .-------------------------. @@ -2300,24 +2292,20 @@ def test_count_runs(): if not isinstance(store, SqlZenStore): pytest.skip("Test only applies to SQL store") active_workspace = client.active_workspace - - num_runs = store.list_runs( - PipelineRunFilter(scope_workspace=active_workspace.id) - ).total + filter_model = PipelineRunFilter(scope_workspace=active_workspace.id) + num_runs = store.list_runs(filter_model).total # At baseline this should be the same - assert store.count_runs(workspace_id=active_workspace.id) == num_runs + assert store.count_runs(filter_model) == num_runs with PipelineRunContext(5): assert ( - store.count_runs(workspace_id=active_workspace.id) + store.count_runs(filter_model) == store.list_runs( PipelineRunFilter(scope_workspace=active_workspace.id) ).total ) - assert ( - store.count_runs(workspace_id=active_workspace.id) == num_runs + 5 - ) + assert store.count_runs(filter_model) == num_runs + 5 def test_filter_runs_by_code_repo(mocker): diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index b2b6ab0a39a..f45c3d8fd55 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -863,6 +863,7 @@ def update_method( integration="", source="", config_schema="", + user=uuid.uuid4(), workspace=uuid.uuid4(), ), filter_model=FlavorFilter, From c613c589fcc7969a00ff161e1b335f4b5aa52b6b Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 23 Nov 2023 16:57:17 +0100 Subject: [PATCH 091/103] Use action enum in more places --- src/zenml/zen_server/rbac/rbac_interface.py | 6 +++--- src/zenml/zen_server/rbac/utils.py | 12 ++++++------ src/zenml/zen_server/rbac/zenml_cloud_rbac.py | 14 +++++++------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/zenml/zen_server/rbac/rbac_interface.py b/src/zenml/zen_server/rbac/rbac_interface.py index d94b3b32b48..2b9dfcbd94a 100644 --- a/src/zenml/zen_server/rbac/rbac_interface.py +++ b/src/zenml/zen_server/rbac/rbac_interface.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Dict, List, Set, Tuple -from zenml.zen_server.rbac.models import Resource +from zenml.zen_server.rbac.models import Action, Resource if TYPE_CHECKING: from zenml.models import UserResponse @@ -27,7 +27,7 @@ class RBACInterface(ABC): @abstractmethod def check_permissions( - self, user: "UserResponse", resources: Set[Resource], action: str + self, user: "UserResponse", resources: Set[Resource], action: Action ) -> Dict[Resource, bool]: """Checks if a user has permissions to perform an action on resources. @@ -43,7 +43,7 @@ def check_permissions( @abstractmethod def list_allowed_resource_ids( - self, user: "UserResponse", resource: Resource, action: str + self, user: "UserResponse", resource: Resource, action: Action ) -> Tuple[bool, List[str]]: """Lists all resource IDs of a resource type that a user can access. diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 3c6852b8ffa..e6017afd2cb 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -165,7 +165,7 @@ def _dehydrate_value( return value -def has_permissions_for_model(model: AnyResponseModel, action: str) -> bool: +def has_permissions_for_model(model: AnyResponseModel, action: Action) -> bool: """If the active user has permissions to perform the action on the model. Args: @@ -275,7 +275,7 @@ def get_permission_denied_model_v1( def batch_verify_permissions_for_models( models: Sequence[AnyResponseModel], - action: str, + action: Action, ) -> None: """Batch permission verification for models. @@ -304,7 +304,7 @@ def batch_verify_permissions_for_models( def verify_permission_for_model( model: AnyResponseModel, - action: str, + action: Action, ) -> None: """Verifies if a user has permission to perform an action on a model. @@ -317,7 +317,7 @@ def verify_permission_for_model( def batch_verify_permissions( resources: Set[Resource], - action: str, + action: Action, ) -> None: """Batch permission verification. @@ -357,7 +357,7 @@ def batch_verify_permissions( def verify_permission( resource_type: str, - action: str, + action: Action, resource_id: Optional[UUID] = None, ) -> None: """Verifies if a user has permission to perform an action on a resource. @@ -374,7 +374,7 @@ def verify_permission( def get_allowed_resource_ids( resource_type: str, - action: str = Action.READ, + action: Action = Action.READ, ) -> Optional[Set[UUID]]: """Get all resource IDs of a resource type that a user can access. diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py index 6337ead45aa..ff12c17fd90 100644 --- a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -18,7 +18,7 @@ import requests from pydantic import BaseModel, validator -from zenml.zen_server.rbac.models import Resource +from zenml.zen_server.rbac.models import Action, Resource from zenml.zen_server.rbac.rbac_interface import RBACInterface from zenml.zen_server.utils import server_config @@ -131,7 +131,7 @@ def __init__(self) -> None: self._session: Optional[requests.Session] = None def check_permissions( - self, user: "UserResponse", resources: Set[Resource], action: str + self, user: "UserResponse", resources: Set[Resource], action: Action ) -> Dict[Resource, bool]: """Checks if a user has permissions to perform an action on resources. @@ -152,7 +152,7 @@ def check_permissions( # Service accounts have full permissions for now return {resource: True for resource in resources} - # At this point its a regular user, which in the ZenML cloud with RBAC + # At this point it's a regular user, which in the ZenML cloud with RBAC # enabled is always authenticated using external authentication assert user.external_user_id @@ -161,7 +161,7 @@ def check_permissions( "resources": [ _convert_to_cloud_resource(resource) for resource in resources ], - "action": action, + "action": str(action), } response = self._get(endpoint=PERMISSIONS_ENDPOINT, params=params) value = response.json() @@ -170,7 +170,7 @@ def check_permissions( return {_convert_from_cloud_resource(k): v for k, v in value.items()} def list_allowed_resource_ids( - self, user: "UserResponse", resource: Resource, action: str + self, user: "UserResponse", resource: Resource, action: Action ) -> Tuple[bool, List[str]]: """Lists all resource IDs of a resource type that a user can access. @@ -192,13 +192,13 @@ def list_allowed_resource_ids( # Service accounts have full permissions for now return True, [] - # At this point its a regular user, which in the ZenML cloud with RBAC + # At this point it's a regular user, which in the ZenML cloud with RBAC # enabled is always authenticated using external authentication assert user.external_user_id params = { "user_id": str(user.external_user_id), "resource": _convert_to_cloud_resource(resource), - "action": action, + "action": str(action), } response = self._get( endpoint=ALLOWED_RESOURCE_IDS_ENDPOINT, params=params From 6331448afd6a58ba823eb8c4596d1c6d3c49ad65 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Thu, 23 Nov 2023 16:50:18 +0000 Subject: [PATCH 092/103] Auto-update of E2E template --- examples/e2e/.copier-answers.yml | 2 +- examples/e2e/pipelines/batch_inference.py | 17 ++++----------- examples/e2e/run.py | 13 ------------ .../e2e/steps/deployment/deployment_deploy.py | 6 ++---- .../steps/etl/inference_data_preprocessor.py | 7 +------ .../hp_tuning/hp_tuning_select_best_model.py | 4 +--- .../hp_tuning/hp_tuning_single_search.py | 4 ++-- .../e2e/steps/inference/inference_predict.py | 7 +++---- ...ute_performance_metrics_on_current_data.py | 10 +++------ .../promotion/promote_with_metric_compare.py | 2 +- examples/e2e/steps/training/model_trainer.py | 21 +++++++++++-------- 11 files changed, 30 insertions(+), 63 deletions(-) diff --git a/examples/e2e/.copier-answers.yml b/examples/e2e/.copier-answers.yml index 23f4b486ff0..ce44ee01f20 100644 --- a/examples/e2e/.copier-answers.yml +++ b/examples/e2e/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2023.11.14 +_commit: 2023.11.14-2-ga100ab7 _src_path: gh:zenml-io/template-e2e-batch data_quality_checks: true email: '' diff --git a/examples/e2e/pipelines/batch_inference.py b/examples/e2e/pipelines/batch_inference.py index ac93ae9a182..13f54f8232a 100644 --- a/examples/e2e/pipelines/batch_inference.py +++ b/examples/e2e/pipelines/batch_inference.py @@ -24,8 +24,7 @@ notify_on_success, ) -from zenml import pipeline -from zenml.artifacts.external_artifact import ExternalArtifact +from zenml import ExternalArtifact, pipeline from zenml.integrations.evidently.metrics import EvidentlyMetricConfig from zenml.integrations.evidently.steps import evidently_report_step from zenml.logger import get_logger @@ -46,24 +45,16 @@ def e2e_use_case_batch_inference(): # of one step as the input of the next step. ########## ETL stage ########## df_inference, target, _ = data_loader( - random_state=ExternalArtifact( - model_artifact_pipeline_name="e2e_use_case_training", - model_artifact_name="random_state", - ), - is_inference=True, + random_state=ExternalArtifact(name="random_state"), is_inference=True ) df_inference = inference_data_preprocessor( dataset_inf=df_inference, - preprocess_pipeline=ExternalArtifact( - model_artifact_name="preprocess_pipeline", - ), + preprocess_pipeline=ExternalArtifact(name="preprocess_pipeline"), target=target, ) ########## DataQuality stage ########## report, _ = evidently_report_step( - reference_dataset=ExternalArtifact( - model_artifact_name="dataset_trn", - ), + reference_dataset=ExternalArtifact(name="dataset_trn"), comparison_dataset=df_inference, ignored_cols=["target"], metrics=[ diff --git a/examples/e2e/run.py b/examples/e2e/run.py index 02d22a3c4fa..b36a8bb8c3c 100644 --- a/examples/e2e/run.py +++ b/examples/e2e/run.py @@ -26,7 +26,6 @@ e2e_use_case_training, ) -from zenml.artifacts.external_artifact import ExternalArtifact from zenml.logger import get_logger logger = get_logger(__name__) @@ -213,18 +212,6 @@ def main( **run_args_inference ) - artifact = ExternalArtifact( - model_artifact_name="predictions", - model_name="e2e_use_case", - model_version="staging", - model_artifact_version=None, # can be skipped - using latest artifact link - ) - logger.info( - "Batch inference pipeline finished successfully! " - "You can find predictions in Artifact Store using ID: " - f"`{str(artifact.get_artifact_id())}`." - ) - if __name__ == "__main__": main() diff --git a/examples/e2e/steps/deployment/deployment_deploy.py b/examples/e2e/steps/deployment/deployment_deploy.py index c82135006be..7fa359acc10 100644 --- a/examples/e2e/steps/deployment/deployment_deploy.py +++ b/examples/e2e/steps/deployment/deployment_deploy.py @@ -20,7 +20,7 @@ from typing_extensions import Annotated -from zenml import get_step_context, step +from zenml import ArtifactConfig, get_step_context, step from zenml.client import Client from zenml.integrations.mlflow.services.mlflow_deployment import ( MLFlowDeploymentService, @@ -29,7 +29,6 @@ mlflow_model_registry_deployer_step, ) from zenml.logger import get_logger -from zenml.model import EndpointArtifactConfig logger = get_logger(__name__) @@ -38,8 +37,7 @@ def deployment_deploy() -> ( Annotated[ Optional[MLFlowDeploymentService], - "mlflow_deployment", - EndpointArtifactConfig(), + ArtifactConfig(name="mlflow_deployment", is_endpoint_artifact=True), ] ): """Predictions step. diff --git a/examples/e2e/steps/etl/inference_data_preprocessor.py b/examples/e2e/steps/etl/inference_data_preprocessor.py index 605b7d6e983..01a6ca767cb 100644 --- a/examples/e2e/steps/etl/inference_data_preprocessor.py +++ b/examples/e2e/steps/etl/inference_data_preprocessor.py @@ -21,7 +21,6 @@ from typing_extensions import Annotated from zenml import step -from zenml.model import DataArtifactConfig @step @@ -29,11 +28,7 @@ def inference_data_preprocessor( dataset_inf: pd.DataFrame, preprocess_pipeline: Pipeline, target: str, -) -> Annotated[ - pd.DataFrame, - "dataset_inf", - DataArtifactConfig(overwrite=False, artifact_name="inference_dataset"), -]: +) -> Annotated[pd.DataFrame, "inference_dataset"]: """Data preprocessor step. This is an example of a data processor step that prepares the data so that diff --git a/examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py b/examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py index 5f524675213..0d16ad4be64 100644 --- a/examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py +++ b/examples/e2e/steps/hp_tuning/hp_tuning_select_best_model.py @@ -47,9 +47,7 @@ def hp_tuning_select_best_model( best_metric = -1 # consume artifacts attached to current model version in Model Control Plane for step_name in step_names: - hp_output = model_version.get_data_artifact( - step_name=step_name, name="hp_result" - ) + hp_output = model_version.get_data_artifact("hp_result") model: ClassifierMixin = hp_output.load() # fetch metadata we attached earlier metric = float(hp_output.run_metadata["metric"].value) diff --git a/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py b/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py index 067ace93d03..d7539d0fdbf 100644 --- a/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py +++ b/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py @@ -96,8 +96,8 @@ def hp_tuning_single_search( score = accuracy_score(y_tst, y_pred) # log score along with output artifact as metadata log_artifact_metadata( - output_name="hp_result", - metric=float(score), + metadata={"metric": float(score)}, + artifact_name="hp_result", ) ### YOUR CODE ENDS HERE ### return cv.best_estimator_ diff --git a/examples/e2e/steps/inference/inference_predict.py b/examples/e2e/steps/inference/inference_predict.py index f4a219ebc16..26ce1231801 100644 --- a/examples/e2e/steps/inference/inference_predict.py +++ b/examples/e2e/steps/inference/inference_predict.py @@ -26,7 +26,6 @@ MLFlowDeploymentService, ) from zenml.logger import get_logger -from zenml.model import DataArtifactConfig logger = get_logger(__name__) @@ -34,7 +33,7 @@ @step def inference_predict( dataset_inf: pd.DataFrame, -) -> Annotated[pd.Series, "predictions", DataArtifactConfig(overwrite=False)]: +) -> Annotated[pd.Series, "predictions"]: """Predictions step. This is an example of a predictions step that takes the data in and returns @@ -59,7 +58,7 @@ def inference_predict( # get predictor predictor_service: Optional[ MLFlowDeploymentService - ] = model_version.get_endpoint_artifact("mlflow_deployment").load() + ] = model_version.load_artifact("mlflow_deployment") if predictor_service is not None: # run prediction from service predictions = predictor_service.predict(request=dataset_inf) @@ -69,7 +68,7 @@ def inference_predict( "as the orchestrator is not local." ) # run prediction from memory - predictor = model_version.get_model_artifact("model").load() + predictor = model_version.load_artifact("model") predictions = predictor.predict(dataset_inf) predictions = pd.Series(predictions, name="predicted") diff --git a/examples/e2e/steps/promotion/compute_performance_metrics_on_current_data.py b/examples/e2e/steps/promotion/compute_performance_metrics_on_current_data.py index 5df519172b2..18beb5a73c1 100644 --- a/examples/e2e/steps/promotion/compute_performance_metrics_on_current_data.py +++ b/examples/e2e/steps/promotion/compute_performance_metrics_on_current_data.py @@ -23,7 +23,7 @@ from zenml import get_step_context, step from zenml.logger import get_logger -from zenml.model import ModelVersion +from zenml.model.model_version import ModelVersion logger = get_logger(__name__) @@ -75,12 +75,8 @@ def compute_performance_metrics_on_current_data( else: # Get predictors predictors = { - latest_version_number: latest_version.get_model_artifact( - "model" - ).load(), - current_version_number: current_version.get_model_artifact( - "model" - ).load(), + latest_version_number: latest_version.load_artifact("model"), + current_version_number: current_version.load_artifact("model"), } metrics = {} diff --git a/examples/e2e/steps/promotion/promote_with_metric_compare.py b/examples/e2e/steps/promotion/promote_with_metric_compare.py index bf5c200c2fd..ad02eadad39 100644 --- a/examples/e2e/steps/promotion/promote_with_metric_compare.py +++ b/examples/e2e/steps/promotion/promote_with_metric_compare.py @@ -19,7 +19,7 @@ from zenml import get_step_context, step from zenml.logger import get_logger -from zenml.model import ModelVersion +from zenml.model.model_version import ModelVersion logger = get_logger(__name__) diff --git a/examples/e2e/steps/training/model_trainer.py b/examples/e2e/steps/training/model_trainer.py index 42e39eb5a5a..80824b05a94 100644 --- a/examples/e2e/steps/training/model_trainer.py +++ b/examples/e2e/steps/training/model_trainer.py @@ -20,7 +20,7 @@ from sklearn.base import ClassifierMixin from typing_extensions import Annotated -from zenml import log_artifact_metadata, step +from zenml import ArtifactConfig, log_artifact_metadata, step from zenml.client import Client from zenml.integrations.mlflow.experiment_trackers import ( MLFlowExperimentTracker, @@ -29,7 +29,6 @@ mlflow_register_model_step, ) from zenml.logger import get_logger -from zenml.model import ModelArtifactConfig logger = get_logger(__name__) @@ -50,7 +49,9 @@ def model_trainer( model: ClassifierMixin, target: str, name: str, -) -> Annotated[ClassifierMixin, "model", ModelArtifactConfig()]: +) -> Annotated[ + ClassifierMixin, ArtifactConfig(name="model", is_model_artifact=True) +]: """Configure and train a model on the training dataset. This is an example of a model training step that takes in a dataset artifact @@ -98,12 +99,14 @@ def model_trainer( name=name, ) # keep track of mlflow version for future use - log_artifact_metadata( - output_name="model", - model_registry_version=Client() - .active_stack.model_registry.list_model_versions(name=name)[-1] - .version, - ) + model_registry = Client().active_stack.model_registry + if model_registry: + versions = model_registry.list_model_versions(name=name) + if versions: + log_artifact_metadata( + metadata={"model_registry_version": versions[-1].version}, + artifact_name="model", + ) ### YOUR CODE ENDS HERE ### return model From 4889df729baf5f14cbe0baca821f641713866a13 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 23 Nov 2023 17:57:00 +0100 Subject: [PATCH 093/103] Add rbac on user endpoints --- src/zenml/zen_server/rbac/models.py | 3 +- .../zen_server/routers/users_endpoints.py | 69 +++++++++++++++---- 2 files changed, 58 insertions(+), 14 deletions(-) diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index 43a20b155c4..4ed73358d93 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -31,7 +31,7 @@ class Action(StrEnum): READ_SECRET_VALUE = "read_secret_value" # Service connectors - CLIENT = "client" # TODO: rename + CLIENT = "client" # Models PROMOTE = "promote" @@ -57,6 +57,7 @@ class ResourceType(StrEnum): PIPELINE_DEPLOYMENT = "pipeline_deployment" PIPELINE_BUILD = "pipeline_build" RUN_METADATA = "run_metadata" + USER = "user" class Resource(BaseModel): diff --git a/src/zenml/zen_server/routers/users_endpoints.py b/src/zenml/zen_server/routers/users_endpoints.py index 58ad5c26e83..3c0435d6421 100644 --- a/src/zenml/zen_server/routers/users_endpoints.py +++ b/src/zenml/zen_server/routers/users_endpoints.py @@ -43,6 +43,16 @@ authorize, ) from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_create_entity, +) +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import ( + dehydrate_page, + dehydrate_response_model, + get_allowed_resource_ids, + verify_permission_for_model, +) from zenml.zen_server.utils import ( handle_exceptions, make_dependable, @@ -82,7 +92,7 @@ def list_users( user_filter_model: UserFilter = Depends(make_dependable(UserFilter)), hydrate: bool = False, - _: AuthContext = Security(authorize), + auth_context: AuthContext = Security(authorize), ) -> Page[UserResponse]: """Returns a list of all users. @@ -91,13 +101,24 @@ def list_users( pagination. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. + auth_context: Authentication context. Returns: A list of all users. """ - return zen_store().list_users( + allowed_ids = get_allowed_resource_ids(resource_type=ResourceType.USER) + if allowed_ids is not None: + # Make sure users can see themselves + allowed_ids.add(auth_context.user.id) + + user_filter_model.configure_rbac( + authenticated_user_id=auth_context.user.id, id=allowed_ids + ) + + page = zen_store().list_users( user_filter_model=user_filter_model, hydrate=hydrate ) + return dehydrate_page(page) # When the auth scheme is set to EXTERNAL, users cannot be created via the @@ -139,7 +160,12 @@ def create_user( token = user.generate_activation_token() else: user.active = True - new_user = zen_store().create_user(user) + + new_user = verify_permissions_and_create_entity( + request_model=user, + resource_type=ResourceType.USER, + create_method=zen_store().create_user, + ) # add back the original unhashed activation token, if generated, to # send it back to the client @@ -157,7 +183,7 @@ def create_user( def get_user( user_name_or_id: Union[str, UUID], hydrate: bool = True, - _: AuthContext = Security(authorize), + auth_context: AuthContext = Security(authorize), ) -> UserResponse: """Returns a specific user. @@ -165,13 +191,18 @@ def get_user( user_name_or_id: Name or ID of the user. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. + auth_context: Authentication context. Returns: A specific user. """ - return zen_store().get_user( + user = zen_store().get_user( user_name_or_id=user_name_or_id, hydrate=hydrate ) + if user.id != auth_context.user.id: + verify_permission_for_model(user, action=Action.READ) + + return dehydrate_response_model(user) # When the auth scheme is set to EXTERNAL, users cannot be updated via the @@ -191,23 +222,27 @@ def get_user( def update_user( user_name_or_id: Union[str, UUID], user_update: UserUpdate, - _: AuthContext = Security(authorize), + auth_context: AuthContext = Security(authorize), ) -> UserResponse: """Updates a specific user. Args: user_name_or_id: Name or ID of the user. user_update: the user to use for the update. + auth_context: Authentication context. Returns: The updated user. """ user = zen_store().get_user(user_name_or_id) + if user.id != auth_context.user.id: + verify_permission_for_model(user, action=Action.UPDATE) - return zen_store().update_user( + updated_user = zen_store().update_user( user_id=user.id, user_update=user_update, ) + return dehydrate_response_model(updated_user) @activation_router.put( "/{user_name_or_id}" + ACTIVATE, @@ -258,17 +293,20 @@ def activate_user( @handle_exceptions def deactivate_user( user_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize), + auth_context: AuthContext = Security(authorize), ) -> UserResponse: """Deactivates a user and generates a new activation token for it. Args: user_name_or_id: Name or ID of the user. + auth_context: Authentication context. Returns: The generated activation token. """ user = zen_store().get_user(user_name_or_id) + if user.id != auth_context.user.id: + verify_permission_for_model(user, action=Action.UPDATE) user_update = UserUpdate( name=user.name, @@ -280,7 +318,7 @@ def deactivate_user( ) # add back the original unhashed activation token user.get_body().activation_token = token - return user + return dehydrate_response_model(user) @router.delete( "/{user_name_or_id}", @@ -306,13 +344,16 @@ def delete_user( """ user = zen_store().get_user(user_name_or_id) - if auth_context.user.name == user.name: + if auth_context.user.id == user.id: raise IllegalOperationError( "You cannot delete the user account currently used to authenticate " "to the ZenML server. If you wish to delete this account, " "please authenticate with another account or contact your ZenML " "administrator." ) + else: + verify_permission_for_model(user, action=Action.DELETE) + zen_store().delete_user(user_name_or_id=user_name_or_id) @router.put( @@ -360,9 +401,10 @@ def email_opt_in_response( source="zenml server", ) - return zen_store().update_user( + updated_user = zen_store().update_user( user_id=user.id, user_update=user_update ) + return dehydrate_response_model(updated_user) else: raise AuthorizationException( "Users can not opt in on behalf of another user." @@ -386,7 +428,7 @@ def get_current_user( Returns: The model of the authenticated user. """ - return auth_context.user + return dehydrate_response_model(auth_context.user) # When the auth scheme is set to EXTERNAL, users cannot be managed via the @@ -416,6 +458,7 @@ def update_myself( Returns: The updated user. """ - return zen_store().update_user( + updated_user = zen_store().update_user( user_id=auth_context.user.id, user_update=user ) + return dehydrate_response_model(updated_user) From 8a5cf5ab8d55c2601e0b29e0e105db4ea57988a4 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 23 Nov 2023 20:53:49 +0100 Subject: [PATCH 094/103] Add early return for dehydration and more response models to mapping --- src/zenml/zen_server/rbac/utils.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index e6017afd2cb..39bbfb86cc0 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -141,8 +141,10 @@ def _dehydrate_value( value, action=Action.READ ) resource = get_resource_for_model(value) - has_permissions = resource and (permissions or {}).get(resource, False) + if not resource: + return dehydrate_response_model(value, permissions=permissions) + has_permissions = (permissions or {}).get(resource, False) if has_permissions or has_permissions_for_model( model=value, action=Action.READ ): @@ -468,10 +470,18 @@ def get_resource_type_for_model( ComponentResponse, FlavorResponse, ModelResponseModel, + PipelineBuildResponse, + PipelineDeploymentResponse, PipelineResponse, + PipelineRunResponse, + RunMetadataResponse, SecretResponseModel, + ServiceAccountResponse, ServiceConnectorResponse, StackResponse, + TagResponseModel, + UserResponse, + WorkspaceResponse, ) mapping: Dict[ @@ -487,6 +497,14 @@ def get_resource_type_for_model( SecretResponseModel: ResourceType.SECRET, ModelResponseModel: ResourceType.MODEL, ArtifactResponse: ResourceType.ARTIFACT, + WorkspaceResponse: ResourceType.WORKSPACE, + UserResponse: ResourceType.USER, + RunMetadataResponse: ResourceType.RUN_METADATA, + PipelineDeploymentResponse: ResourceType.PIPELINE_DEPLOYMENT, + PipelineBuildResponse: ResourceType.PIPELINE_BUILD, + PipelineRunResponse: ResourceType.PIPELINE_RUN, + TagResponseModel: ResourceType.TAG, + ServiceAccountResponse: ResourceType.SERVICE_ACCOUNT, } return mapping.get(type(model)) From d1ebd449ec69545059e43eecbc7c7b042a3e0382 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 23 Nov 2023 21:32:09 +0100 Subject: [PATCH 095/103] Don't filter by user ID when finding stack components during tests --- tests/harness/model/requirements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/harness/model/requirements.py b/tests/harness/model/requirements.py index e5d621a4fd5..9522649e6b7 100644 --- a/tests/harness/model/requirements.py +++ b/tests/harness/model/requirements.py @@ -100,7 +100,6 @@ def find_stack_component( components = depaginate( partial( client.list_stack_components, - user_id=client.active_user.id, name=self.name or None, type=self.type, flavor=self.flavor, From 9c4df6137efce5a13f33bc30b45d1ec3686e9d8b Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 24 Nov 2023 09:15:19 +0100 Subject: [PATCH 096/103] Fix some tests --- src/zenml/zen_stores/rest_zen_store.py | 2 +- tests/integration/functional/zen_stores/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index dfb800814ec..ba56bb97b25 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -3474,7 +3474,7 @@ def _list_paginated_resources( # So these items will be parsed into their correct types like here page_of_items.items = [ response_model.parse_obj(generic_item) # type: ignore[misc] - for generic_item in page_of_items.items + for generic_item in body["items"] ] return page_of_items diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index f45c3d8fd55..abd54268578 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -56,7 +56,6 @@ PipelineRequest, PipelineRunFilter, PipelineRunRequest, - PipelineUpdate, ResourceTypeModel, SecretFilterModel, SecretRequestModel, @@ -891,7 +890,8 @@ def update_method( version="1", version_hash="abc123", ), - update_model=PipelineUpdate(name=sample_name("updated_sample_pipeline")), + # Updating pipelines is not doing anything at the moment + # update_model=PipelineUpdate(name=sample_name("updated_sample_pipeline")), filter_model=PipelineFilter, entity_name="pipeline", ) From 5329e5e6477c4c1e32e30531ddc3e1cbb9e9f6f1 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 24 Nov 2023 09:34:00 +0100 Subject: [PATCH 097/103] Make model version test robust to existing models --- .../functional/zen_stores/test_zen_store.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 621a04aa958..f25148c8bd8 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -102,7 +102,6 @@ WorkspaceFilter, WorkspaceUpdate, ) -from zenml.models.model_models import ModelFilterModel from zenml.models.tag_models import ( TagFilterModel, TagRequestModel, @@ -3440,21 +3439,19 @@ def test_connector_validation(): class TestModel: def test_latest_version_properly_fetched(self): """Test that latest version can be properly fetched.""" - with ModelVersionContext() as model: + with ModelVersionContext() as created_model: zs = Client().zen_store - models = zs.list_models(ModelFilterModel()) - assert models[0].latest_version is None + assert zs.get_model(created_model.id).latest_version is None for name in ["great one", "yet another one"]: mv = zs.create_model_version( ModelVersionRequestModel( - user=model.user.id, - workspace=model.workspace.id, - model=model.id, + user=created_model.user.id, + workspace=created_model.workspace.id, + model=created_model.id, name=name, ) ) - models = zs.list_models(ModelFilterModel()) - assert models[0].latest_version == mv.name + assert zs.get_model(created_model.id).latest_version == mv.name time.sleep(1) # thanks to MySQL again! From 5f5b40a163285cc4bf2828a61b1dc0d9a2f0befb Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Fri, 24 Nov 2023 10:06:18 +0100 Subject: [PATCH 098/103] Fix tests that verify account deletion with owned resources --- .../functional/zen_stores/test_zen_store.py | 302 ++++++++++++++---- .../functional/zen_stores/utils.py | 48 ++- 2 files changed, 270 insertions(+), 80 deletions(-) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 621a04aa958..f525de041c3 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -442,42 +442,104 @@ def test_delete_user_with_resources_fails(): """Tests deleting a user with resources fails.""" zen_store = Client().zen_store - with UserContext(delete=False) as user: - with ComponentContext( + login = zen_store.type == StoreType.REST + + with UserContext(delete=False, login=login) as user: + component_context = ComponentContext( c_type=StackComponentType.ORCHESTRATOR, flavor="local", config={}, user_id=user.id, - ) as orchestrator: - with ComponentContext( - c_type=StackComponentType.ARTIFACT_STORE, - flavor="local", - config={}, - user_id=user.id, - ) as artifact_store: - components = { - StackComponentType.ORCHESTRATOR: [orchestrator.id], - StackComponentType.ARTIFACT_STORE: [artifact_store.id], - } - with StackContext(components=components, user_id=user.id): - with pytest.raises(IllegalOperationError): - zen_store.delete_user(user.id) - - with pytest.raises(IllegalOperationError): - zen_store.delete_user(user.id) - - with pytest.raises(IllegalOperationError): - zen_store.delete_user(user.id) + delete=False, + ) + with component_context as orchestrator: + # We only use the context as a shortcut to create the resource + pass + + # Can't delete because owned resources exist + with pytest.raises(IllegalOperationError): + zen_store.delete_user(user.id) + + component_context.cleanup() + + # Can delete because owned resources have been removed + with does_not_raise(): + zen_store.delete_user(user.id) + + with UserContext(delete=False, login=login) as user: + orchestrator_context = ComponentContext( + c_type=StackComponentType.ORCHESTRATOR, + flavor="local", + config={}, + user_id=user.id, + delete=False, + ) + artifact_store_context = ComponentContext( + c_type=StackComponentType.ARTIFACT_STORE, + flavor="local", + config={}, + user_id=user.id, + delete=False, + ) + + with orchestrator_context as orchestrator: + # We only use the context as a shortcut to create the resource + pass + with artifact_store_context as artifact_store: + # We only use the context as a shortcut to create the resource + pass + + components = { + StackComponentType.ORCHESTRATOR: [orchestrator.id], + StackComponentType.ARTIFACT_STORE: [artifact_store.id], + } + stack_context = StackContext( + components=components, user_id=user.id, delete=False + ) + with stack_context: + # We only use the context as a shortcut to create the resource + pass + + # Can't delete because owned resources exist + with pytest.raises(IllegalOperationError): + zen_store.delete_user(user.id) + + stack_context.cleanup() + artifact_store_context.cleanup() + orchestrator_context.cleanup() + # Can delete because owned resources have been removed + with does_not_raise(): + zen_store.delete_user(user.id) + + with UserContext(delete=False, login=login) as user: with SecretContext(user_id=user.id, delete=False): - # Secrets are deleted when the user is deleted pass - with CodeRepositoryContext(user_id=user.id): - with pytest.raises(IllegalOperationError): - zen_store.delete_user(user.id) + # Secrets are deleted when the user is deleted + with does_not_raise(): + zen_store.delete_user(user.id) - with ServiceConnectorContext( + with UserContext(delete=False, login=login) as user: + code_repo_context = CodeRepositoryContext( + user_id=user.id, delete=False + ) + with code_repo_context: + # We only use the context as a shortcut to create the resource + pass + + # Can't delete because owned resources exist + with pytest.raises(IllegalOperationError): + zen_store.delete_user(user.id) + + code_repo_context.cleanup() + + # Can delete because owned resources have been removed + with does_not_raise(): + zen_store.delete_user(user.id) + + with UserContext(delete=False, login=login) as user: + service_connector_context = ServiceConnectorContext( connector_type="cat'o'matic", auth_method="paw-print", resource_types=["cat"], @@ -487,14 +549,37 @@ def test_delete_user_with_resources_fails(): "foods": "tuna", }, user_id=user.id, - ): - with pytest.raises(IllegalOperationError): - zen_store.delete_user(user.id) + delete=False, + ) + with service_connector_context: + # We only use the context as a shortcut to create the resource + pass + + # Can't delete because owned resources exist + with pytest.raises(IllegalOperationError): + zen_store.delete_user(user.id) + + service_connector_context.cleanup() + + # Can delete because owned resources have been removed + with does_not_raise(): + zen_store.delete_user(user.id) + + with UserContext(delete=False, login=login) as user: + model_version_context = ModelVersionContext( + create_version=True, user_id=user.id, delete=False + ) + with model_version_context: + # We only use the context as a shortcut to create the resource + pass + + # Can't delete because owned resources exist + with pytest.raises(IllegalOperationError): + zen_store.delete_user(user.id) - with ModelVersionContext(create_version=True, user_id=user.id): - with pytest.raises(IllegalOperationError): - zen_store.delete_user(user.id) + model_version_context.cleanup() + # Can delete because owned resources have been removed with does_not_raise(): zen_store.delete_user(user.id) @@ -694,44 +779,104 @@ def test_delete_service_account_with_resources_fails(): """Tests deleting a service account with resources fails.""" zen_store = Client().zen_store - with ServiceAccountContext(delete=False) as service_account: - with ComponentContext( + login = zen_store.type == StoreType.REST + + with ServiceAccountContext(delete=False, login=login) as service_account: + component_context = ComponentContext( c_type=StackComponentType.ORCHESTRATOR, flavor="local", config={}, user_id=service_account.id, - ) as orchestrator: - with ComponentContext( - c_type=StackComponentType.ARTIFACT_STORE, - flavor="local", - config={}, - user_id=service_account.id, - ) as artifact_store: - components = { - StackComponentType.ORCHESTRATOR: [orchestrator.id], - StackComponentType.ARTIFACT_STORE: [artifact_store.id], - } - with StackContext( - components=components, user_id=service_account.id - ): - with pytest.raises(IllegalOperationError): - zen_store.delete_service_account(service_account.id) + delete=False, + ) + with component_context as orchestrator: + # We only use the context as a shortcut to create the resource + pass - with pytest.raises(IllegalOperationError): - zen_store.delete_service_account(service_account.id) + # Can't delete because owned resources exist + with pytest.raises(IllegalOperationError): + zen_store.delete_service_account(service_account.id) - with pytest.raises(IllegalOperationError): - zen_store.delete_service_account(service_account.id) + component_context.cleanup() + # Can delete because owned resources have been removed + with does_not_raise(): + zen_store.delete_service_account(service_account.id) + + with ServiceAccountContext(delete=False, login=login) as service_account: + orchestrator_context = ComponentContext( + c_type=StackComponentType.ORCHESTRATOR, + flavor="local", + config={}, + user_id=service_account.id, + delete=False, + ) + artifact_store_context = ComponentContext( + c_type=StackComponentType.ARTIFACT_STORE, + flavor="local", + config={}, + user_id=service_account.id, + delete=False, + ) + + with orchestrator_context as orchestrator: + # We only use the context as a shortcut to create the resource + pass + with artifact_store_context as artifact_store: + # We only use the context as a shortcut to create the resource + pass + + components = { + StackComponentType.ORCHESTRATOR: [orchestrator.id], + StackComponentType.ARTIFACT_STORE: [artifact_store.id], + } + stack_context = StackContext( + components=components, user_id=service_account.id, delete=False + ) + with stack_context: + # We only use the context as a shortcut to create the resource + pass + + # Can't delete because owned resources exist + with pytest.raises(IllegalOperationError): + zen_store.delete_service_account(service_account.id) + + stack_context.cleanup() + artifact_store_context.cleanup() + orchestrator_context.cleanup() + + # Can delete because owned resources have been removed + with does_not_raise(): + zen_store.delete_service_account(service_account.id) + + with ServiceAccountContext(delete=False, login=login) as service_account: with SecretContext(user_id=service_account.id, delete=False): - # Secrets are deleted when the user is deleted pass - with CodeRepositoryContext(user_id=service_account.id): - with pytest.raises(IllegalOperationError): - zen_store.delete_service_account(service_account.id) + # Secrets are deleted when the service_account is deleted + with does_not_raise(): + zen_store.delete_service_account(service_account.id) + + with ServiceAccountContext(delete=False, login=login) as service_account: + code_repo_context = CodeRepositoryContext( + user_id=service_account.id, delete=False + ) + with code_repo_context: + # We only use the context as a shortcut to create the resource + pass - with ServiceConnectorContext( + # Can't delete because owned resources exist + with pytest.raises(IllegalOperationError): + zen_store.delete_service_account(service_account.id) + + code_repo_context.cleanup() + + # Can delete because owned resources have been removed + with does_not_raise(): + zen_store.delete_service_account(service_account.id) + + with ServiceAccountContext(delete=False, login=login) as service_account: + service_connector_context = ServiceConnectorContext( connector_type="cat'o'matic", auth_method="paw-print", resource_types=["cat"], @@ -741,16 +886,37 @@ def test_delete_service_account_with_resources_fails(): "foods": "tuna", }, user_id=service_account.id, - ): - with pytest.raises(IllegalOperationError): - zen_store.delete_service_account(service_account.id) + delete=False, + ) + with service_connector_context: + # We only use the context as a shortcut to create the resource + pass - with ModelVersionContext( - create_version=True, user_id=service_account.id - ): - with pytest.raises(IllegalOperationError): - zen_store.delete_service_account(service_account.id) + # Can't delete because owned resources exist + with pytest.raises(IllegalOperationError): + zen_store.delete_service_account(service_account.id) + + service_connector_context.cleanup() + + # Can delete because owned resources have been removed + with does_not_raise(): + zen_store.delete_service_account(service_account.id) + + with ServiceAccountContext(delete=False, login=login) as service_account: + model_version_context = ModelVersionContext( + create_version=True, user_id=service_account.id, delete=False + ) + with model_version_context: + # We only use the context as a shortcut to create the resource + pass + + # Can't delete because owned resources exist + with pytest.raises(IllegalOperationError): + zen_store.delete_service_account(service_account.id) + + model_version_context.cleanup() + # Can delete because owned resources have been removed with does_not_raise(): zen_store.delete_service_account(service_account.id) diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index f45c3d8fd55..c026ce56562 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -283,11 +283,11 @@ def __exit__(self, exc_type, exc_value, exc_traceback): GlobalConfiguration._reset_instance(self.original_config) Client._reset_instance(self.original_client) _ = Client().zen_store + if not self.existing_account and self.delete: self.store.delete_api_key( self.created_service_account.id, self.api_key.id, ) - if not self.existing_account and self.delete: try: self.store.delete_service_account( self.created_service_account.id @@ -335,12 +335,14 @@ def __init__( components: Dict[StackComponentType, List[uuid.UUID]], stack_name: str = "aria", user_id: Optional[uuid.UUID] = None, + delete: bool = True, ): self.stack_name = sample_name(stack_name) self.user_id = user_id self.components = components self.client = Client() self.store = self.client.zen_store + self.delete = delete def __enter__(self): new_stack = StackRequest( @@ -352,12 +354,16 @@ def __enter__(self): self.created_stack = self.store.create_stack(new_stack) return self.created_stack - def __exit__(self, exc_type, exc_value, exc_traceback): + def cleanup(self): try: self.store.delete_stack(self.created_stack.id) except KeyError: pass + def __exit__(self, exc_type, exc_value, exc_traceback): + if self.delete: + self.cleanup() + class ComponentContext: def __init__( @@ -367,6 +373,7 @@ def __init__( flavor: str, component_name: str = "aria", user_id: Optional[uuid.UUID] = None, + delete: bool = True, ): self.component_name = sample_name(component_name) self.flavor = flavor @@ -375,6 +382,7 @@ def __init__( self.user_id = user_id self.client = Client() self.store = self.client.zen_store + self.delete = delete def __enter__(self): new_component = ComponentRequest( @@ -390,12 +398,16 @@ def __enter__(self): ) return self.created_component - def __exit__(self, exc_type, exc_value, exc_traceback): + def cleanup(self): try: self.store.delete_stack_component(self.created_component.id) except KeyError: pass + def __exit__(self, exc_type, exc_value, exc_traceback): + if self.delete: + self.cleanup() + class WorkspaceContext: def __init__( @@ -508,12 +520,15 @@ def __enter__(self): self.repo = self.store.create_code_repository(request) return self.repo + def cleanup(self): + try: + self.store.delete_code_repository(self.repo.id) + except KeyError: + pass + def __exit__(self, exc_type, exc_value, exc_traceback): if self.delete: - try: - self.store.delete_code_repository(self.repo.id) - except KeyError: - pass + self.cleanup() class ServiceConnectorContext: @@ -569,12 +584,15 @@ def __enter__(self): self.connector = self.store.create_service_connector(request) return self.connector + def cleanup(self): + try: + self.store.delete_service_connector(self.connector.id) + except KeyError: + pass + def __exit__(self, exc_type, exc_value, exc_traceback): if self.delete: - try: - self.store.delete_service_connector(self.connector.id) - except KeyError: - pass + self.cleanup() class ModelVersionContext: @@ -584,6 +602,7 @@ def __init__( create_artifacts: int = 0, create_prs: int = 0, user_id: Optional[uuid.UUID] = None, + delete: bool = True, ): client = Client() self.workspace = client.active_workspace.id @@ -597,6 +616,7 @@ def __init__( self.create_prs = create_prs self.prs = [] self.deployments = [] + self.delete = delete def __enter__(self): client = Client() @@ -676,7 +696,7 @@ def __enter__(self): else: return model - def __exit__(self, exc_type, exc_value, exc_traceback): + def cleanup(self): client = Client() try: client.delete_model(self.model) @@ -689,6 +709,10 @@ def __exit__(self, exc_type, exc_value, exc_traceback): for deployment in self.deployments: client.delete_deployment(str(deployment.id)) + def __exit__(self, exc_type, exc_value, exc_traceback): + if self.delete: + self.cleanup() + class CatClawMarks(AuthenticationConfig): """Cat claw marks authentication credentials.""" From d73747027bde4de601ef9de5a7ce79f667c7b6ad Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 24 Nov 2023 10:14:21 +0100 Subject: [PATCH 099/103] Require pipeline run permissions in api token endpoint --- src/zenml/zen_server/routers/auth_endpoints.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index f343a1b3ddd..e784bfff091 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -65,6 +65,8 @@ ) from zenml.zen_server.exceptions import error_response from zenml.zen_server.jwt import JWTToken +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import verify_permission from zenml.zen_server.utils import ( get_ip_location, handle_exceptions, @@ -506,6 +508,10 @@ def api_token( detail="Not authenticated.", ) + verify_permission( + resource_type=ResourceType.PIPELINE_RUN, action=Action.CREATE + ) + if not token.device_id: # If not authenticated with a device, the current API token is returned # as is, without any modifications. Issuing workload tokens is only From 0d08530a19635f97987e96373f5855a7375c991c Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Fri, 24 Nov 2023 10:21:48 +0100 Subject: [PATCH 100/103] Fix zen store tests after breaking them --- .../functional/zen_stores/test_zen_store.py | 13 ++++++++++--- tests/integration/functional/zen_stores/utils.py | 3 ++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index a1cf11d29a7..41188e2dc83 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -994,7 +994,9 @@ def test_list_service_accounts(): with ServiceAccountContext() as service_account_one: accounts = zen_store.list_service_accounts( - ServiceAccountFilter() + # TODO: we use a large size to get all accounts in one page, but + # the correct way to do this is to fetch all pages + ServiceAccountFilter(size=1000) ).items assert service_account_one.id in [account.id for account in accounts] @@ -1014,7 +1016,9 @@ def test_list_service_accounts(): with ServiceAccountContext() as service_account_two: accounts = zen_store.list_service_accounts( - ServiceAccountFilter() + # TODO: we use a large size to get all accounts in one page, but + # the correct way to do this is to fetch all pages + ServiceAccountFilter(size=1000) ).items assert service_account_one.id in [ account.id for account in accounts @@ -1046,6 +1050,7 @@ def test_list_service_accounts(): accounts = zen_store.list_service_accounts( ServiceAccountFilter( active=True, + size=1000, ) ).items assert service_account_one.id in [ @@ -1057,7 +1062,9 @@ def test_list_service_accounts(): with UserContext() as user: accounts = zen_store.list_service_accounts( - ServiceAccountFilter() + # TODO: we use a large size to get all accounts in one page, + # but the correct way to do this is to fetch all pages + ServiceAccountFilter(size=1000) ).items assert user.id not in [account.id for account in accounts] diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index d6182740999..3f65fe2f9ff 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -282,11 +282,12 @@ def __exit__(self, exc_type, exc_value, exc_traceback): GlobalConfiguration._reset_instance(self.original_config) Client._reset_instance(self.original_client) _ = Client().zen_store - if not self.existing_account and self.delete: + if self.existing_account or self.login and self.delete: self.store.delete_api_key( self.created_service_account.id, self.api_key.id, ) + if not self.existing_account and self.delete: try: self.store.delete_service_account( self.created_service_account.id From ade951c309a785062135cdf8ea2dc98ca485a99c Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Fri, 24 Nov 2023 10:26:33 +0000 Subject: [PATCH 101/103] Auto-update of E2E template --- examples/e2e/.copier-answers.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/e2e/.copier-answers.yml b/examples/e2e/.copier-answers.yml index ce44ee01f20..5d2316a71eb 100644 --- a/examples/e2e/.copier-answers.yml +++ b/examples/e2e/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2023.11.14-2-ga100ab7 +_commit: 2023.11.23 _src_path: gh:zenml-io/template-e2e-batch data_quality_checks: true email: '' From 286b4e556d00e485a13e9aeb0f5088aa71e9a0d0 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 24 Nov 2023 11:43:11 +0100 Subject: [PATCH 102/103] Prevent random failure on secret reference test --- tests/unit/utils/test_secret_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/utils/test_secret_utils.py b/tests/unit/utils/test_secret_utils.py index 0864f7a310b..040c6e2cceb 100644 --- a/tests/unit/utils/test_secret_utils.py +++ b/tests/unit/utils/test_secret_utils.py @@ -18,7 +18,7 @@ from zenml.utils import secret_utils -strategy = from_regex(r"[^.\s]{1,20}", fullmatch=True) +strategy = from_regex(r"[^.{}\s]{1,20}", fullmatch=True) @given(name=strategy, key=strategy) From bc4186ff29d559eca5a3741b0a0ed24c46fd56ff Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 24 Nov 2023 11:43:36 +0100 Subject: [PATCH 103/103] Exponential backoff when sending requests to cloud api --- src/zenml/zen_server/rbac/zenml_cloud_rbac.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py index ff12c17fd90..0d72fdcbff5 100644 --- a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -17,6 +17,7 @@ import requests from pydantic import BaseModel, validator +from requests.adapters import HTTPAdapter, Retry from zenml.zen_server.rbac.models import Action, Resource from zenml.zen_server.rbac.rbac_interface import RBACInterface @@ -236,8 +237,8 @@ def _get(self, endpoint: str, params: Dict[str, Any]) -> requests.Response: response.raise_for_status() except requests.HTTPError as e: raise RuntimeError( - "Failed while trying to contact RBAC service." - ) from e + f"Failed while trying to contact RBAC service: {e}" + ) return response @@ -253,6 +254,9 @@ def session(self) -> requests.Session: token = self._fetch_auth_token() self._session.headers.update({"Authorization": "Bearer " + token}) + retries = Retry(total=5, backoff_factor=0.1) + self._session.mount("https://", HTTPAdapter(max_retries=retries)) + return self._session def _clear_session(self) -> None: