diff --git a/.github/workflows/dockerized-tests.yml b/.github/workflows/dockerized-tests.yml index 7c4561a0..9a2f6adf 100644 --- a/.github/workflows/dockerized-tests.yml +++ b/.github/workflows/dockerized-tests.yml @@ -1,6 +1,6 @@ -name: Built docker image and run tests +name: Docker - build, test, push -# This workflow will built docker image and run tests inside the container. +# This workflow will build docker image and run tests inside the container. # This workflow is only executed if there is pull request with change in pyproject.toml dependencies, # or in Dockerfile, or in docker workflow. @@ -10,22 +10,118 @@ on: - '.github/workflows/docker-image.yml' - 'pyproject.toml' - 'Dockerfile' + + push: + branches: + - 'develop' + + release: + types: [published] # allows to manually start a workflow run from the GitHub UI or using the GitHub API. workflow_dispatch: + inputs: + push-image: + description: "Push image to docker hub" + required: false + type: boolean + default: false + push-description: + description: "Update docker hub description" + required: false + type: boolean + default: false + tag: + description: "Tag for the docker image" + required: false + default: "workflow-dispatch" + jobs: - built: + build: runs-on: ubuntu-latest - permissions: - packages: write + steps: + - uses: actions/checkout@v4 + # We do not bother with setup-qemu-action since we don't care about emulation right now + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Build + uses: docker/build-push-action@v5 + with: + context: . + file: ./Dockerfile + tags: aiod/metadata_catalogue:ci + outputs: type=docker,dest=/tmp/aiod_mc_image.tar + cache-from: type=gha + cache-to: type=gha,mode=min + # We store the image as an artifact, so it can be used by the `test` step + # and inspected manually if needed (download it through Github Actions UI) + - name: Store Image + uses: actions/upload-artifact@v4 + with: + name: aiod_mc_image + path: /tmp/aiod_mc_image.tar + test: + runs-on: ubuntu-latest + needs: [build] steps: - - uses: actions/checkout@v2 - - name: Build the docker image - run: docker build --tag aiod_metadata_catalogue:latest -f Dockerfile . - - - name: Run docker container and pytest tests + # We need to check out the repository, so that we have the `scripts` directory to mount. + # This is required to run the backup script tests. + - uses: actions/checkout@v4 + - name: Retrieve Image + uses: actions/download-artifact@v4 + with: + name: aiod_mc_image + path: /tmp + - name: Load Image + run: | + docker load --input /tmp/aiod_mc_image.tar + docker image ls -a + - name: Run pytest from docker run: | - docker run -v ./scripts:/scripts -e KEYCLOAK_CLIENT_SECRET="mocked_secret" --entrypoint "" aiod_metadata_catalogue sh -c "pip install \".[dev]\" && pytest tests -s" + docker run -v ./scripts:/scripts -e KEYCLOAK_CLIENT_SECRET="mocked_secret" --entrypoint "" aiod/metadata_catalogue:ci sh -c "pip install \".[dev]\" && pytest tests -s" + publish: + needs: [test] + runs-on: ubuntu-latest + if: github.event_name != 'pull_request' + steps: + # The correct tag depends on how this workflow was invoked, see also docker-description.md + - name: Set Develop Tag + if: github.ref == 'refs/heads/develop' + run: echo "IMAGE_TAGS=aiod/metadata_catalogue:develop" >> "$GITHUB_ENV" + - name: Set Release Tag + if: github.event_name == 'release' + run: echo "IMAGE_TAGS=aiod/metadata_catalogue:latest,aiod/metadata_catalogue:${{ github.event.release.tag_name }}" >> "$GITHUB_ENV" + - name: Set Dispatch Tag + if: github.event_name == 'workflow_dispatch' + run: echo "IMAGE_TAGS=aiod/metadata_catalogue:${{ inputs.tag }}" >> "$GITHUB_ENV" + - uses: actions/checkout@v4 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.AIOD_DOCKER_PAT }} + - name: Echo tags + run: echo $IMAGE_TAGS + - name: Build + if: (github.event_name != 'workflow_dispatch') || inputs.push-image + uses: docker/build-push-action@v5 + with: + push: true + context: . + file: ./Dockerfile + tags: ${{ env.IMAGE_TAGS }} + cache-from: type=gha + cache-to: type=gha,mode=min + - name: Update repository description + if: (github.event_name == 'release') || inputs.push-description + uses: peter-evans/dockerhub-description@v4 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.AIOD_DOCKER_PAT }} + repository: aiod/metadata_catalogue + readme-filepath: ./docker-description.md + short-description: "Metadata catalogue REST API for AI on Demand." diff --git a/.github/workflows/pytest-tests.yml b/.github/workflows/pytest-tests.yml index eb2d5d59..868c2293 100644 --- a/.github/workflows/pytest-tests.yml +++ b/.github/workflows/pytest-tests.yml @@ -41,10 +41,7 @@ jobs: source venv/bin/activate pre-commit run --all - - name: Test with pytest - run: | - source venv/bin/activate - pytest ./src/tests/ + diff --git a/README.md b/README.md index a2bf73de..bf6e16a4 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ Information on how to install Docker is found in [their documentation](https://d docker compose --profile examples up -d ``` -starts the MYSQL Server, the REST API, Keycloak for Identy and access management and Nginx for reverse proxing. \ +starts the MYSQL Server, the REST API, Keycloak for Identity and access management and Nginx for reverse proxying. \ Once started, you should be able to visit the REST API server at: http://localhost and Keycloak at http://localhost/aiod-auth \ To authenticate to the REST API swagger interface the predefined user is: user, and password: password \ To authenticate as admin to Keycloak the predefined user is: admin and password: password \ diff --git a/authentication/Dockerfile b/authentication/Dockerfile index 471835b6..eb385036 100644 --- a/authentication/Dockerfile +++ b/authentication/Dockerfile @@ -1,4 +1,4 @@ -FROM quay.io/keycloak/keycloak:latest as builder +FROM quay.io/keycloak/keycloak:24.0.4 as builder # Enable health and metrics support ENV KC_HEALTH_ENABLED=true @@ -12,7 +12,7 @@ WORKDIR /opt/keycloak #RUN keytool -genkeypair -storepass password -storetype PKCS12 -keyalg RSA -keysize 2048 -dname "CN=server" -alias server -ext "SAN:c=DNS:localhost,IP:127.0.0.1" -keystore conf/server.keystore #RUN /opt/keycloak/bin/kc.sh build -FROM quay.io/keycloak/keycloak:latest +FROM quay.io/keycloak/keycloak:24.0.4 COPY --from=builder /opt/keycloak/ /opt/keycloak/ # change these values to point to a running postgres instance diff --git a/docker-compose.yaml b/docker-compose.yaml index de3ba673..3b6e9c43 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -121,7 +121,7 @@ services: condition: service_healthy sqlserver: - image: mysql + image: mysql:8.3.0 container_name: sqlserver env_file: .env environment: @@ -137,7 +137,7 @@ services: retries: 30 keycloak: - image: quay.io/keycloak/keycloak + image: quay.io/keycloak/keycloak:24.0.4 container_name: keycloak env_file: .env environment: @@ -157,7 +157,7 @@ services: --import-realm nginx: - image: nginx + image: nginx:1.25.5 container_name: nginx restart: unless-stopped volumes: diff --git a/docker-description.md b/docker-description.md new file mode 100644 index 00000000..5d5676fd --- /dev/null +++ b/docker-description.md @@ -0,0 +1,11 @@ +# AIOD Metadata Catalogue + +Image for AI on Demand's (AIOD) metadata catalogue REST API, developed on [Github](https://github.com/aiondemand/AIOD-rest-api/). +This image requires a properly configured database setup to function. Additionally, to have all features working, authentication (via Keycloak) and search (through Elasticsearch) must also be configured. Please refer to the documentation available in the README of the ["AIOD-rest-api"](https://github.com/aiondemand/AIOD-rest-api) repository. + +The following tags are available: + + - `latest`: the latest official release + - `develop`: the head of the development branch + - `v*`, e.g. `v1.3.20240308`: that specific release + - a number of custom tags may be introduced for testing, but these are not intended for general use. diff --git a/pyproject.toml b/pyproject.toml index 39339db2..d1ca2ab8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "aiod_metadata_catalogue" description = "A Metadata Catalogue for AI on Demand " -version = "1.3.20240308" +version = "1.3.20240619" requires-python = ">=3.11" authors = [ { name = "Adrián Alcolea" }, @@ -47,7 +47,7 @@ dev = [ "pytest-asyncio==0.23.2", "pytest-dotenv==0.5.2", "pytest-xdist==3.5.0", - "pre-commit==3.6.0", + "pre-commit==3.7.0", "responses==0.24.1", "freezegun==1.4.0", ] diff --git a/scripts/backup.sh b/scripts/backup.sh old mode 100644 new mode 100755 diff --git a/scripts/mysql_dump.sh b/scripts/mysql_dump.sh old mode 100644 new mode 100755 diff --git a/scripts/mysql_restore.sh b/scripts/mysql_restore.sh old mode 100644 new mode 100755 diff --git a/scripts/realm_export.sh b/scripts/realm_export.sh old mode 100644 new mode 100755 index 58d3ea9a..bdb39799 --- a/scripts/realm_export.sh +++ b/scripts/realm_export.sh @@ -9,7 +9,7 @@ source .env DATA_PATH=$(realpath "$DATA_PATH") LOCAL_BACKUP_PATH="$DATA_PATH"/keycloak_realm -docker exec -i keycloak /bin/bash -c "/opt/keycloak/bin/kc.sh export --file /tmp/aiod.json --realm aiod" +docker exec -i keycloak /bin/bash -c "/opt/keycloak/bin/kc.sh export --file /tmp/aiod.json --realm aiod --users realm_file" if [ ! -d "$LOCAL_BACKUP_PATH" ]; then mkdir "$LOCAL_BACKUP_PATH" diff --git a/src/authentication.py b/src/authentication.py index 0496dff4..105ec7f8 100644 --- a/src/authentication.py +++ b/src/authentication.py @@ -17,6 +17,7 @@ performs a separate authorization request. The only downside is the overhead of the additional keycloak requests - if that becomes prohibitive in the future, we should reevaluate this design. """ + import logging import os @@ -55,15 +56,15 @@ def has_any_role(self, *roles: str) -> bool: return bool(set(roles) & self.roles) -async def get_current_user(token=Security(oidc)) -> User: +async def _get_user(token) -> User: """ - Use this function in Depends() to force authentication. Check the roles of the user for - authorization. + Check the roles of the user for authorization. Raises: - HTTPException with status 401 on missing token (unauthorized message), also status 401 on - any problem with the token (we don't want to leak information), status 500 on any - request if Keycloak is configured incorrectly. + NoTokenError on missing token (unauthorized message) and InvalidUserError on inactive user. + Also HTTPException with status 401 on any problem with the token + (we don't want to leak information), and status 500 on any request + if Keycloak is configured incorrectly. """ if not client_secret: raise HTTPException( @@ -73,11 +74,7 @@ async def get_current_user(token=Security(oidc)) -> User: "from a Keycloak Administrator of AIoD.", ) if not token: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="This endpoint requires authorization. You need to be logged in.", - headers={"WWW-Authenticate": "Bearer"}, - ) + raise NoTokenError("No token found") try: token = token.replace("Bearer ", "") # query the authorization server to determine the active state of this token and to @@ -86,10 +83,12 @@ async def get_current_user(token=Security(oidc)) -> User: if not userinfo.get("active", False): logging.error("Invalid userinfo or inactive user.") - raise RuntimeError("Invalid userinfo or inactive user.") # caught below + raise InvalidUserError("Invalid userinfo or inactive user") # caught below return User( name=userinfo["username"], roles=set(userinfo.get("realm_access", {}).get("roles", [])) ) + except InvalidUserError: + raise except Exception as e: logging.error(f"Error while checking the access token: '{e}'") raise HTTPException( @@ -97,3 +96,45 @@ async def get_current_user(token=Security(oidc)) -> User: detail="Invalid authentication token", headers={"WWW-Authenticate": "Bearer"}, ) + + +async def get_user_or_none(token=Security(oidc)) -> User | None: + """ + Use this function in Depends() to ask for authentication. + This method should be only used to get the current user + without raising exception when the token is not found, + or the user is not active, or the userinfo is invalid. + """ + try: + return await _get_user(token) + except (NoTokenError, InvalidUserError): + return None + + +async def get_user_or_raise(token=Security(oidc)) -> User: + """ + Use this function in Depends() to force authentication. Check the roles of the user for + authorization. + + Raises: + HTTPException with status 401 on missing token (unauthorized message), or invalid user. + It also raises a HTTPException with status 401 on + any problem with the token (we don't want to leak information), + status 500 on any request if Keycloak is configured incorrectly. + """ + try: + return await _get_user(token) + except (InvalidUserError, NoTokenError) as err: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"{err} - This endpoint requires authorization. You need to be logged in.", + headers={"WWW-Authenticate": "Bearer"}, + ) from err + + +class InvalidUserError(Exception): + """Raise an error on invalid userinfo or inactive user.""" + + +class NoTokenError(Exception): + """Raise an error when no token is found.""" diff --git a/src/database/model/agent/organisation.py b/src/database/model/agent/organisation.py index 5cefea81..60505e5f 100644 --- a/src/database/model/agent/organisation.py +++ b/src/database/model/agent/organisation.py @@ -50,7 +50,8 @@ class Organisation(OrganisationBase, Agent, table=True): # type: ignore [call-a class RelationshipConfig(Agent.RelationshipConfig): contact_details: int | None = OneToOne( - description="The contact details by which this organisation can be reached", + description="The identifier of the contact details by which this organisation " + "can be reached.", deserializer=FindByIdentifierDeserializer(Contact), _serializer=AttributeSerializer("identifier"), ) diff --git a/src/database/model/agent/person.py b/src/database/model/agent/person.py index 4408b70b..a8f783cc 100644 --- a/src/database/model/agent/person.py +++ b/src/database/model/agent/person.py @@ -58,7 +58,7 @@ class Person(PersonBase, Agent, table=True): # type: ignore [call-arg] class RelationshipConfig(Agent.RelationshipConfig): contact_details: int | None = OneToOne( - description="The contact details by which this person can be reached", + description="The identifier of the contact details by which this person can be reached", deserializer=FindByIdentifierDeserializer(Contact), _serializer=AttributeSerializer("identifier"), ) diff --git a/src/database/model/ai_resource/resource.py b/src/database/model/ai_resource/resource.py index fbfaf61d..5c04c757 100644 --- a/src/database/model/ai_resource/resource.py +++ b/src/database/model/ai_resource/resource.py @@ -180,14 +180,15 @@ class RelationshipConfig(AIoDConcept.RelationshipConfig): ) # TODO(jos): documentedIn - KnowledgeAsset. This should probably be defined on ResourceTable contact: list[int] = ManyToMany( - description="Contact information of persons/organisations that can be contacted about " - "this resource.", + description="The identifiers of the contact information of the persons and/or " + "organisations that can be contacted about this resource.", _serializer=AttributeSerializer("identifier"), deserializer=FindByIdentifierDeserializerList(Contact), default_factory_pydantic=list, ) creator: list[int] = ManyToMany( - description="Contact information of persons/organisations that created this resource.", + description="The identifiers of the contact information of the persons and/or " + "organisations that created this resource.", _serializer=AttributeSerializer("identifier"), deserializer=FindByIdentifierDeserializerList(Contact), default_factory_pydantic=list, diff --git a/src/database/model/platform/platform_names.py b/src/database/model/platform/platform_names.py index dadb61d4..c43fbc6e 100644 --- a/src/database/model/platform/platform_names.py +++ b/src/database/model/platform/platform_names.py @@ -9,6 +9,7 @@ class PlatformName(str, enum.Enum): """ aiod = "aiod" + ai4europe_cms = "ai4europe_cms" example = "example" openml = "openml" huggingface = "huggingface" diff --git a/src/database/session.py b/src/database/session.py index 49ba51f0..5f1955c6 100644 --- a/src/database/session.py +++ b/src/database/session.py @@ -38,14 +38,14 @@ def db_url(including_db=True): @contextmanager -def DbSession() -> Session: +def DbSession(autoflush: bool = True) -> Session: """ Returning a SQLModel session bound to the (configured) database engine. Alternatively, we could have used FastAPI Depends, but that only works for FastAPI - while the synchronization, for instance, also needs a Session, but doesn't use FastAPI. """ - session = Session(EngineSingleton().engine) + session = Session(EngineSingleton().engine, autoflush=autoflush) try: yield session finally: diff --git a/src/database/validators/huggingface_validators.py b/src/database/validators/huggingface_validators.py index 33203cfa..556c92ee 100644 --- a/src/database/validators/huggingface_validators.py +++ b/src/database/validators/huggingface_validators.py @@ -1,7 +1,4 @@ -import re - -REPO_ID_ILLEGAL_CHARACTERS = re.compile(r"[^0-9a-zA-Z-_./]+") -MSG_PREFIX = "The platform_resource_identifier for HuggingFace should be a valid repo_id. " +from huggingface_hub.utils import validate_repo_id def throw_error_on_invalid_identifier(platform_resource_identifier: str): @@ -11,26 +8,11 @@ def throw_error_on_invalid_identifier(platform_resource_identifier: str): Valid repo_ids: Between 1 and 96 characters. Either “repo_name” or “namespace/repo_name” - [a-zA-Z0-9] or ”-”, ”_”, ”.” - ”—” and ”..” are forbidden + [a-zA-Z0-9] or ”-”, ”_”, ”.”. + The following sequences ”--” and ”..” are forbidden. Refer to: https://huggingface.co/docs/huggingface_hub/package_reference/utilities#huggingface_hub.utils.validate_repo_id """ repo_id = platform_resource_identifier - if REPO_ID_ILLEGAL_CHARACTERS.search(repo_id): - msg = "A repo_id should only contain [a-zA-Z0-9] or ”-”, ”_”, ”.”" - raise ValueError(MSG_PREFIX + msg) - if not (1 < len(repo_id) < 96): - msg = "A repo_id should be between 1 and 96 characters." - raise ValueError(MSG_PREFIX + msg) - if repo_id.count("/") > 1: - msg = ( - "For new repositories, there should be a single forward slash in the repo_id (" - "namespace/repo_name). Legacy repositories are without a namespace. This repo_id has " - "too many forward slashes." - ) - raise ValueError(MSG_PREFIX + msg) - if ".." in repo_id: - msg = "A repo_id may not contain multiple consecutive dots." - raise ValueError(MSG_PREFIX + msg) + validate_repo_id(repo_id=repo_id) diff --git a/src/main.py b/src/main.py index cd816744..44acb904 100644 --- a/src/main.py +++ b/src/main.py @@ -4,6 +4,7 @@ Note: order matters for overloaded paths (https://fastapi.tiangolo.com/tutorial/path-params/#order-matters). """ + import argparse import pkg_resources @@ -12,7 +13,7 @@ from fastapi.responses import HTMLResponse from sqlmodel import select -from authentication import get_current_user, User +from authentication import get_user_or_raise, User from config import KEYCLOAK_CONFIG from database.deletion.triggers import add_delete_triggers from database.model.concept.concept import AIoDConcept @@ -62,7 +63,7 @@ def home() -> str: """ @app.get(url_prefix + "/authorization_test") - def test_authorization(user: User = Depends(get_current_user)) -> User: + def test_authorization(user: User = Depends(get_user_or_raise)) -> User: """ Returns the user, if authenticated correctly. """ diff --git a/src/routers/resource_router.py b/src/routers/resource_router.py index 84b0b389..47da64eb 100644 --- a/src/routers/resource_router.py +++ b/src/routers/resource_router.py @@ -2,8 +2,7 @@ import datetime import traceback from functools import partial -from typing import Literal, Union, Any, Annotated -from typing import TypeVar, Type +from typing import Annotated, Any, Literal, Sequence, Type, TypeVar, Union from wsgiref.handlers import format_date_time from fastapi import APIRouter, Depends, HTTPException, status, Query, Path @@ -14,7 +13,7 @@ from sqlmodel import SQLModel, Session, select, Field from starlette.responses import JSONResponse -from authentication import get_current_user, User +from authentication import User, get_user_or_none, get_user_or_raise from config import KEYCLOAK_CONFIG from converters.schema_converters.schema_converter import SchemaConverter from database.model.ai_resource.resource import AbstractAIResource @@ -52,6 +51,7 @@ class Pagination(BaseModel): RESOURCE = TypeVar("RESOURCE", bound=AbstractAIResource) RESOURCE_CREATE = TypeVar("RESOURCE_CREATE", bound=SQLModel) RESOURCE_READ = TypeVar("RESOURCE_READ", bound=SQLModel) +RESOURCE_MODEL = TypeVar("RESOURCE_MODEL", bound=SQLModel) class ResourceRouter(abc.ABC): @@ -103,7 +103,7 @@ def resource_name_plural(self) -> str: @property @abc.abstractmethod - def resource_class(self): + def resource_class(self) -> type[RESOURCE_MODEL]: pass @property @@ -204,42 +204,42 @@ def create(self, url_prefix: str) -> APIRouter: ) return router - def get_resources(self, schema: str, pagination: Pagination, platform: str | None = None): + def get_resources( + self, + schema: str, + pagination: Pagination, + user: User | None = None, + platform: str | None = None, + ): """Fetch all resources of this platform in given schema, using pagination""" _raise_error_on_invalid_schema(self._possible_schemas, schema) - with DbSession() as session: + with DbSession(autoflush=False) as session: try: convert_schema = ( partial(self.schema_converters[schema].convert, session) if schema != "aiod" else self.resource_class_read.from_orm ) - where_clause = and_( - is_(self.resource_class.date_deleted, None), - (self.resource_class.platform == platform) if platform is not None else True, - ) - query = ( - select(self.resource_class) - .where(where_clause) - .offset(pagination.offset) - .limit(pagination.limit) - ) - - return self._wrap_with_headers( - [convert_schema(resource) for resource in session.scalars(query).all()] + resources: Any = self._retrieve_resources_and_post_process( + session, pagination, user, platform ) + return self._wrap_with_headers([convert_schema(resource) for resource in resources]) except Exception as e: raise as_http_exception(e) - def get_resource(self, identifier: str, schema: str, platform: str | None = None): + def get_resource( + self, identifier: str, schema: str, user: User | None = None, platform: str | None = None + ): """ Get the resource identified by AIoD identifier (if platform is None) or by platform AND platform-identifier (if platform is not None), return in given schema. """ _raise_error_on_invalid_schema(self._possible_schemas, schema) try: - with DbSession() as session: - resource = self._retrieve_resource(session, identifier, platform=platform) + with DbSession(autoflush=False) as session: + resource: Any = self._retrieve_resource_and_post_process( + session, identifier, user, platform=platform + ) if schema != "aiod": return self.schema_converters[schema].convert(session, resource) return self._wrap_with_headers(self.resource_class_read.from_orm(resource)) @@ -256,8 +256,11 @@ def get_resources_func(self): def get_resources( pagination: Pagination = Depends(), schema: self._possible_schemas_type = "aiod", # type:ignore + user: User | None = Depends(get_user_or_none), ): - resources = self.get_resources(pagination=pagination, schema=schema, platform=None) + resources = self.get_resources( + pagination=pagination, schema=schema, user=user, platform=None + ) return resources return get_resources @@ -318,8 +321,11 @@ def get_resources( ], pagination: Annotated[Pagination, Depends(Pagination)], schema: self._possible_schemas_type = "aiod", # type:ignore + user: User | None = Depends(get_user_or_none), ): - resources = self.get_resources(pagination=pagination, schema=schema, platform=platform) + resources = self.get_resources( + pagination=pagination, schema=schema, user=user, platform=platform + ) return resources return get_resources @@ -334,8 +340,11 @@ def get_resource_func(self): def get_resource( identifier: str, schema: self._possible_schemas_type = "aiod", # type: ignore + user: User | None = Depends(get_user_or_none), ): - resource = self.get_resource(identifier=identifier, schema=schema, platform=None) + resource = self.get_resource( + identifier=identifier, schema=schema, user=user, platform=None + ) return self._wrap_with_headers(resource) return get_resource @@ -362,8 +371,11 @@ def get_resource( ), ], schema: self._possible_schemas_type = "aiod", # type:ignore + user: User | None = Depends(get_user_or_none), ): - return self.get_resource(identifier=identifier, schema=schema, platform=platform) + return self.get_resource( + identifier=identifier, schema=schema, user=user, platform=platform + ) return get_resource @@ -377,7 +389,7 @@ def register_resource_func(self): def register_resource( resource_create: clz_create, # type: ignore - user: User = Depends(get_current_user), + user: User = Depends(get_user_or_raise), ): if not user.has_any_role( KEYCLOAK_CONFIG.get("role"), @@ -421,7 +433,7 @@ def put_resource_func(self): def put_resource( identifier: int, resource_create_instance: clz_create, # type: ignore - user: User = Depends(get_current_user), + user: User = Depends(get_user_or_raise), ): if not user.has_any_role( KEYCLOAK_CONFIG.get("role"), @@ -435,7 +447,7 @@ def put_resource( with DbSession() as session: try: - resource = self._retrieve_resource(session, identifier) + resource: Any = self._retrieve_resource(session, identifier) for attribute_name in resource.schema()["properties"]: if hasattr(resource_create_instance, attribute_name): new_value = getattr(resource_create_instance, attribute_name) @@ -465,7 +477,7 @@ def delete_resource_func(self): def delete_resource( identifier: str, - user: User = Depends(get_current_user), + user: User = Depends(get_user_or_raise), ): with DbSession() as session: if not user.has_any_role( @@ -479,7 +491,7 @@ def delete_resource( ) try: # Raise error if it does not exist - resource = self._retrieve_resource(session, identifier) + resource: Any = self._retrieve_resource(session, identifier) if ( hasattr(self.resource_class, "__deletion_config__") and not self.resource_class.__deletion_config__["soft_delete"] @@ -495,7 +507,16 @@ def delete_resource( return delete_resource - def _retrieve_resource(self, session, identifier, platform=None): + def _retrieve_resource( + self, + session: Session, + identifier: int | str, + platform: str | None = None, + ) -> type[RESOURCE_MODEL]: + """ + Retrieve a resource from the database based on the provided identifier + and platform (if applicable). + """ if platform is None: query = select(self.resource_class).where(self.resource_class.identifier == identifier) else: @@ -525,6 +546,72 @@ def _retrieve_resource(self, session, identifier, platform=None): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"{name} {msg}") return resource + def _retrieve_resources( + self, + session: Session, + pagination: Pagination, + platform: str | None = None, + ) -> Sequence[type[RESOURCE_MODEL]]: + """ + Retrieve a sequence of resources from the database based on the provided identifier + and platform (if applicable). + """ + where_clause = and_( + is_(self.resource_class.date_deleted, None), + (self.resource_class.platform == platform) if platform is not None else True, + ) + query = ( + select(self.resource_class) + .where(where_clause) + .offset(pagination.offset) + .limit(pagination.limit) + ) + resources: Sequence = session.scalars(query).all() + return resources + + def _retrieve_resource_and_post_process( + self, + session: Session, + identifier: int | str, + user: User | None = None, + platform: str | None = None, + ) -> type[RESOURCE_MODEL]: + """ + Retrieve a resource from the database based on the provided identifier + and platform (if applicable). The user parameter can be used by subclasses to + implement further verification on user access to the resource. + """ + resource: type[RESOURCE_MODEL] = self._retrieve_resource(session, identifier, platform) + [processed_resource] = self._mask_or_filter([resource], session, user) + return processed_resource + + def _retrieve_resources_and_post_process( + self, + session: Session, + pagination: Pagination, + user: User | None = None, + platform: str | None = None, + ) -> Sequence[type[RESOURCE_MODEL]]: + """ + Retrieve a sequence of resources from the database based on the provided identifier + and platform (if applicable). The user parameter can be used by subclasses to + implement further verification on user access to the resource. + """ + resources: Sequence[type[RESOURCE_MODEL]] = self._retrieve_resources( + session, pagination, platform + ) + return self._mask_or_filter(resources, session, user) + + @staticmethod + def _mask_or_filter( + resources: Sequence[type[RESOURCE_MODEL]], session: Session, user: User | None + ) -> Sequence[type[RESOURCE_MODEL]]: + """ + Can be implemented in children to post process resources based on user roles + or something else. + """ + return resources + @property def _possible_schemas(self) -> list[str]: return ["aiod"] + list(self.schema_converters.keys()) diff --git a/src/routers/resource_routers/contact_router.py b/src/routers/resource_routers/contact_router.py index 6a4dadb1..e6a19b9c 100644 --- a/src/routers/resource_routers/contact_router.py +++ b/src/routers/resource_routers/contact_router.py @@ -1,8 +1,14 @@ +from typing import Sequence +from authentication import User from database.model.agent.contact import Contact +from database.model.agent.email import Email from database.model.agent.organisation import Organisation from database.model.agent.person import Person +from database.model.platform.platform_names import PlatformName from routers.resource_router import ResourceRouter +from sqlmodel import Session + class ContactRouter(ResourceRouter): def __init__(self): @@ -30,3 +36,20 @@ def resource_name_plural(self) -> str: @property def resource_class(self) -> type[Contact]: return Contact + + @staticmethod + def _mask_or_filter( + resources: Sequence[type[Contact]], session: Session, user: User | None + ) -> Sequence[type[Contact]]: + """ + Only authenticated users can see the contact email. + For the old ai4europe_cms platform, only users with "full_view_ai4europe_cms_resources" role + can view the contact emails. + """ + for contact in resources: + if not user or ( + (contact.platform == PlatformName.ai4europe_cms) + and not user.has_role("full_view_ai4europe_cms_resources") + ): + contact.email = [Email(name="******")] + return resources diff --git a/src/routers/resource_routers/person_router.py b/src/routers/resource_routers/person_router.py index f77768cf..e2d6b7a9 100644 --- a/src/routers/resource_routers/person_router.py +++ b/src/routers/resource_routers/person_router.py @@ -1,5 +1,9 @@ +from typing import Sequence +from sqlmodel import Session from database.model.agent.person import Person +from database.model.platform.platform_names import PlatformName from routers.resource_router import ResourceRouter +from authentication import User class PersonRouter(ResourceRouter): @@ -18,3 +22,20 @@ def resource_name_plural(self) -> str: @property def resource_class(self) -> type[Person]: return Person + + @staticmethod + def _mask_or_filter( + resources: Sequence[type[Person]], session: Session, user: User | None + ) -> Sequence[type[Person]]: + """ + For the old ai4europe_cms platform, only users with "full_view_ai4europe_cms_resources" + role can see the person's sensitive information. + """ + for person in resources: + if (person.platform == PlatformName.ai4europe_cms) and not ( + user and user.has_role("full_view_ai4europe_cms_resources") + ): + person.name = "******" + person.given_name = "******" + person.surname = "******" + return resources diff --git a/src/routers/search_routers/__init__.py b/src/routers/search_routers/__init__.py index 7dcceb07..6399ca85 100644 --- a/src/routers/search_routers/__init__.py +++ b/src/routers/search_routers/__init__.py @@ -1,4 +1,5 @@ from .search_router_datasets import SearchRouterDatasets +from .search_router_educational_resources import SearchRouterEducationalResources from .search_router_events import SearchRouterEvents from .search_router_experiments import SearchRouterExperiments from .search_router_ml_models import SearchRouterMLModels @@ -11,6 +12,7 @@ router_list: list[SearchRouter] = [ SearchRouterDatasets(), + SearchRouterEducationalResources(), SearchRouterEvents(), SearchRouterExperiments(), SearchRouterMLModels(), diff --git a/src/routers/search_routers/search_router_educational_resources.py b/src/routers/search_routers/search_router_educational_resources.py new file mode 100644 index 00000000..517c2503 --- /dev/null +++ b/src/routers/search_routers/search_router_educational_resources.py @@ -0,0 +1,20 @@ +from database.model.educational_resource.educational_resource import EducationalResource +from routers.search_router import SearchRouter + + +class SearchRouterEducationalResources(SearchRouter[EducationalResource]): + @property + def es_index(self) -> str: + return "educational_resource" + + @property + def resource_name_plural(self) -> str: + return "educational_resources" + + @property + def resource_class(self): + return EducationalResource + + @property + def indexed_fields(self): + return {"name", "description_plain", "description_html"} diff --git a/src/routers/uploader_routers/upload_router_huggingface.py b/src/routers/uploader_routers/upload_router_huggingface.py index 1a847e44..a31c95ad 100644 --- a/src/routers/uploader_routers/upload_router_huggingface.py +++ b/src/routers/uploader_routers/upload_router_huggingface.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends from fastapi import File, Query, UploadFile -from authentication import User, get_current_user +from authentication import User, get_user_or_raise from routers.uploader_router import UploaderRouter from uploaders.hugging_face_uploader import HuggingfaceUploader @@ -24,7 +24,7 @@ def huggingface_upload( username: str = Query( ..., title="Huggingface username", description="The username of HuggingFace" ), - user: User = Depends(get_current_user), + user: User = Depends(get_user_or_raise), ) -> int: """ Use this endpoint to upload a file (content) to Hugging Face using diff --git a/src/routers/uploader_routers/upload_router_zenodo.py b/src/routers/uploader_routers/upload_router_zenodo.py index 80278624..b96b5dc8 100644 --- a/src/routers/uploader_routers/upload_router_zenodo.py +++ b/src/routers/uploader_routers/upload_router_zenodo.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends from fastapi import File, Query, UploadFile, Path -from authentication import User, get_current_user +from authentication import User, get_user_or_raise from uploaders.zenodo_uploader import ZenodoUploader from routers.uploader_router import UploaderRouter @@ -32,7 +32,7 @@ def zenodo_upload( ), ] = False, token: str = Query(title="Zenodo Token", description="The access token of Zenodo"), - user: User = Depends(get_current_user), + user: User = Depends(get_user_or_raise), ) -> int: """ Use this endpoint to upload a file (content) to Zenodo using diff --git a/src/setup/es_setup/definitions.py b/src/setup/es_setup/definitions.py index d37a807c..7fa48416 100755 --- a/src/setup/es_setup/definitions.py +++ b/src/setup/es_setup/definitions.py @@ -4,6 +4,7 @@ "date_modified": {"type": "date"}, "identifier": {"type": "long"}, "name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, + "platform": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "description_plain": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, "description_html": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, } diff --git a/src/setup/logstash_setup/templates/sql_init.py b/src/setup/logstash_setup/templates/sql_init.py index 7c16dd72..0e594110 100755 --- a/src/setup/logstash_setup/templates/sql_init.py +++ b/src/setup/logstash_setup/templates/sql_init.py @@ -1,6 +1,7 @@ TEMPLATE_SQL_INIT = """SELECT {{entity_name}}.identifier, {{entity_name}}.name, + {{entity_name}}.platform, text.plain as 'description_plain', text.html as 'description_html', aiod_entry.date_modified{{extra_fields}} diff --git a/src/setup/logstash_setup/templates/sql_rm.py b/src/setup/logstash_setup/templates/sql_rm.py index 335f9da3..7cfce100 100755 --- a/src/setup/logstash_setup/templates/sql_rm.py +++ b/src/setup/logstash_setup/templates/sql_rm.py @@ -1,4 +1,6 @@ -TEMPLATE_SQL_RM = """SELECT {{entity_name}}.identifier +TEMPLATE_SQL_RM = """SELECT + {{entity_name}}.identifier, + {{entity_name}}.date_deleted FROM aiod.{{entity_name}} WHERE aiod.{{entity_name}}.date_deleted IS NOT NULL AND aiod.{{entity_name}}.date_deleted > :sql_last_value diff --git a/src/setup/logstash_setup/templates/sql_sync.py b/src/setup/logstash_setup/templates/sql_sync.py index d2e532e1..1475143c 100755 --- a/src/setup/logstash_setup/templates/sql_sync.py +++ b/src/setup/logstash_setup/templates/sql_sync.py @@ -1,9 +1,10 @@ TEMPLATE_SQL_SYNC = """SELECT {{entity_name}}.identifier, {{entity_name}}.name, + {{entity_name}}.platform, text.plain as 'description_plain', text.html as 'description_html', - aiod_entry.date_modified{{extra_fields}} + aiod_entry.date_modified as 'date_modified'{{extra_fields}} FROM aiod.{{entity_name}} INNER JOIN aiod.aiod_entry ON aiod.{{entity_name}}.aiod_entry_identifier=aiod.aiod_entry.identifier LEFT JOIN aiod.text ON aiod.{{entity_name}}.description_identifier=aiod.text.identifier diff --git a/src/tests/resources/elasticsearch/dataset_search.json b/src/tests/resources/elasticsearch/dataset_search.json index b797082d..e7c2012a 100644 --- a/src/tests/resources/elasticsearch/dataset_search.json +++ b/src/tests/resources/elasticsearch/dataset_search.json @@ -22,6 +22,7 @@ "identifier" : 1, "date_modified" : "2023-09-01T00:00:00.000Z", "name" : "A name.", + "platform" : "A platform.", "description_plain" : "A plain text description.", "issn" : "20493630", "type" : "dataset", diff --git a/src/tests/resources/elasticsearch/educational_resource_search.json b/src/tests/resources/elasticsearch/educational_resource_search.json new file mode 100644 index 00000000..8e6cbf8e --- /dev/null +++ b/src/tests/resources/elasticsearch/educational_resource_search.json @@ -0,0 +1,36 @@ +{ + "took" : 1, + "timed_out" : false, + "_shards" : { + "total" : 1, + "successful" : 1, + "skipped" : 0, + "failed" : 0 + }, + "hits" : { + "total" : { + "value" : 1, + "relation" : "eq" + }, + "max_score" : null, + "hits" : [ + { + "_index" : "educational_resource", + "_id" : "educational_resource_1", + "_score" : null, + "_source" : { + "identifier" : 1, + "date_modified" : "2023-09-01T00:00:00.000Z", + "name" : "A name.", + "platform" : "A platform.", + "description_plain" : "A plain text description.", + "description_html" : "An html description.", + "type" : "educational_resource" + }, + "sort" : [ + 1 + ] + } + ] + } +} diff --git a/src/tests/resources/elasticsearch/event_search.json b/src/tests/resources/elasticsearch/event_search.json index 13fd6135..add5a17f 100644 --- a/src/tests/resources/elasticsearch/event_search.json +++ b/src/tests/resources/elasticsearch/event_search.json @@ -22,6 +22,7 @@ "identifier" : 1, "date_modified" : "2023-09-01T00:00:00.000Z", "name" : "A name.", + "platform" : "A platform.", "description_plain" : "A plain text description.", "type" : "event", "description_html" : "An html description." diff --git a/src/tests/resources/elasticsearch/experiment_search.json b/src/tests/resources/elasticsearch/experiment_search.json index 5c4a815b..e2f4e657 100644 --- a/src/tests/resources/elasticsearch/experiment_search.json +++ b/src/tests/resources/elasticsearch/experiment_search.json @@ -22,6 +22,7 @@ "identifier" : 1, "date_modified" : "2023-09-01T00:00:00.000Z", "name" : "A name.", + "platform" : "A platform.", "description_plain" : "A plain text description.", "type" : "experiment", "description_html" : "An html description." diff --git a/src/tests/resources/elasticsearch/ml_model_search.json b/src/tests/resources/elasticsearch/ml_model_search.json index de60e1b0..9d28884d 100644 --- a/src/tests/resources/elasticsearch/ml_model_search.json +++ b/src/tests/resources/elasticsearch/ml_model_search.json @@ -22,6 +22,7 @@ "identifier" : 1, "date_modified" : "2023-09-01T00:00:00.000Z", "name" : "A name.", + "platform" : "A platform.", "description_plain" : "A plain text description.", "type" : "ml_model", "description_html" : "An html description." diff --git a/src/tests/resources/elasticsearch/news_search.json b/src/tests/resources/elasticsearch/news_search.json index d3d8f0ac..098ade52 100644 --- a/src/tests/resources/elasticsearch/news_search.json +++ b/src/tests/resources/elasticsearch/news_search.json @@ -23,6 +23,7 @@ "headline" : "A headline.", "date_modified" : "2023-09-01T00:00:00.000Z", "name" : "A name.", + "platform" : "A platform.", "description_plain" : "A plain text description.", "type" : "news", "description_html" : "An html description.", diff --git a/src/tests/resources/elasticsearch/organisation_search.json b/src/tests/resources/elasticsearch/organisation_search.json index a0540021..ebf21542 100644 --- a/src/tests/resources/elasticsearch/organisation_search.json +++ b/src/tests/resources/elasticsearch/organisation_search.json @@ -22,6 +22,7 @@ "identifier" : 1, "date_modified" : "2023-09-01T00:00:00.000Z", "name" : "A name.", + "platform" : "A platform.", "description_plain" : "A plain text description.", "type" : "organisation", "description_html" : "An html description.", diff --git a/src/tests/resources/elasticsearch/project_search.json b/src/tests/resources/elasticsearch/project_search.json index f16a0422..04166363 100644 --- a/src/tests/resources/elasticsearch/project_search.json +++ b/src/tests/resources/elasticsearch/project_search.json @@ -22,6 +22,7 @@ "identifier" : 1, "date_modified" : "2023-09-01T00:00:00.000Z", "name" : "A name.", + "platform" : "A platform.", "description_plain" : "A plain text description.", "type" : "project", "description_html" : "An html description." diff --git a/src/tests/resources/elasticsearch/publication_search.json b/src/tests/resources/elasticsearch/publication_search.json index 98b61157..56afa65f 100644 --- a/src/tests/resources/elasticsearch/publication_search.json +++ b/src/tests/resources/elasticsearch/publication_search.json @@ -22,6 +22,7 @@ "identifier" : 1, "date_modified" : "2023-09-01T00:00:00.000Z", "name" : "A name.", + "platform" : "A platform.", "description_plain" : "A plain text description.", "issn" : "20493630", "type" : "publication", diff --git a/src/tests/resources/elasticsearch/service_search.json b/src/tests/resources/elasticsearch/service_search.json index 796f57c0..50f119ee 100644 --- a/src/tests/resources/elasticsearch/service_search.json +++ b/src/tests/resources/elasticsearch/service_search.json @@ -23,6 +23,7 @@ "slogan" : "A slogan.", "date_modified" : "2023-09-01T00:00:00.000Z", "name" : "A name.", + "platform" : "A platform.", "description_plain" : "A plain text description.", "type" : "service", "description_html" : "An html description." diff --git a/src/tests/routers/generic/test_authentication.py b/src/tests/routers/generic/test_authentication.py index 13f35649..3325993e 100644 --- a/src/tests/routers/generic/test_authentication.py +++ b/src/tests/routers/generic/test_authentication.py @@ -124,7 +124,8 @@ def test_post_unauthenticated(client_test_resource: TestClient): assert response.status_code == 401, response.json() response_json = response.json() assert ( - response_json["detail"] == "This endpoint requires authorization. You need to be logged in." + response_json["detail"] + == "No token found - This endpoint requires authorization. You need to be logged in." ) @@ -173,5 +174,6 @@ def test_put_unauthenticated(client_test_resource: TestClient): assert response.status_code == 401, response.json() response_json = response.json() assert ( - response_json["detail"] == "This endpoint requires authorization. You need to be logged in." + response_json["detail"] + == "No token found - This endpoint requires authorization. You need to be logged in." ) diff --git a/src/tests/routers/resource_routers/test_router_contact.py b/src/tests/routers/resource_routers/test_router_contact.py index d0d725c4..5931ab4d 100644 --- a/src/tests/routers/resource_routers/test_router_contact.py +++ b/src/tests/routers/resource_routers/test_router_contact.py @@ -1,9 +1,15 @@ import copy +import pytest from unittest.mock import Mock from starlette.testclient import TestClient from authentication import keycloak_openid +from database.model.agent.contact import Contact +from database.model.agent.email import Email +from database.model.platform.platform import Platform +from database.session import DbSession +from tests.testutils.default_instances import _create_class_with_body def test_happy_path(client: TestClient, mocked_privileged_token: Mock, body_asset: dict): @@ -27,7 +33,7 @@ def test_happy_path(client: TestClient, mocked_privileged_token: Mock, body_asse response = client.post("/contacts/v1", json=body, headers={"Authorization": "Fake token"}) assert response.status_code == 200, response.json() - response = client.get("/contacts/v1/1") + response = client.get("/contacts/v1/1", headers={"Authorization": "Fake token"}) assert response.status_code == 200, response.json() response_json = response.json() @@ -59,11 +65,11 @@ def test_post_duplicate_email( response = client.post("/contacts/v1", json=body2, headers={"Authorization": "Fake token"}) assert response.status_code == 200, response.json() - contact = client.get("/contacts/v1/2").json() + contact = client.get("/contacts/v1/2", headers={"Authorization": "Fake token"}).json() assert set(contact["email"]) == {"b@example.com", "c@example.com"} body3 = {"email": ["d@example.com", "b@example.com"]} client.put("/contacts/v1/1", json=body3, headers={"Authorization": "Fake token"}) - contact = client.get("/contacts/v1/2").json() + contact = client.get("/contacts/v1/2", headers={"Authorization": "Fake token"}).json() msg = "changing emails of contact 1 should not change emails of contact 2." assert set(contact["email"]) == {"b@example.com", "c@example.com"}, msg @@ -78,3 +84,140 @@ def test_person_and_organisation_both_specified(client: TestClient, mocked_privi response = client.post("/contacts/v1", json=body, headers=headers) assert response.status_code == 400, response.json() assert response.json()["detail"] == "Person and organisation cannot be both filled." + + +@pytest.fixture +def contact2(body_concept) -> Contact: + body = copy.copy(body_concept) + body["platform_resource_identifier"] = "fake:100" + body["email"] = ["fake@email.com", "fake2@email.com"] + return _create_class_with_body(Contact, body) + + +@pytest.fixture( + params=[ + "/contacts/v1", + "/contacts/v1/1", + "/platforms/example/contacts/v1", + "/platforms/example/contacts/v1/fake:100", + ] +) +def endpoint_from_fixture1(request) -> str: + return request.param + + +def test_email_mask_for_not_authenticated_user( + client: TestClient, + mocked_privileged_token: Mock, + contact: Contact, + contact2: Contact, + endpoint_from_fixture1: str, +): + keycloak_openid.introspect = mocked_privileged_token + + with DbSession() as session: + session.add(contact) + session.add(contact2) + session.commit() + + guest_response = client.get(endpoint_from_fixture1) + assert guest_response.status_code == 200, guest_response.json() + guest_response_json = guest_response.json() + if not isinstance(guest_response_json, list): + guest_response_json = [guest_response_json] + assert len(guest_response_json) > 0, guest_response_json + for contact_json in guest_response_json: + assert contact_json["email"] == ["******"] + + +def test_email_mask_for_authenticated_user( + client: TestClient, + mocked_privileged_token: Mock, + contact: Contact, + contact2: Contact, +): + keycloak_openid.introspect = mocked_privileged_token + headers = {"Authorization": "Fake token"} + + with DbSession() as session: + session.add(contact) + session.add(contact2) + session.commit() + + response = client.get("/contacts/v1", headers=headers) + response_json = response.json() + assert response.status_code == 200, response_json + assert len(response_json) == 2, response_json + assert response_json[0]["email"] == ["a@b.com"] + assert set(response_json[1]["email"]) == {"fake2@email.com", "fake@email.com"} + + response = client.get("/contacts/v1/2", headers=headers) + assert response.status_code == 200, response.json() + response_json = response.json() + assert set(response_json["email"]) == {"fake2@email.com", "fake@email.com"} + + response = client.get("/platforms/example/contacts/v1", headers=headers) + response_json = response.json() + assert response.status_code == 200, response_json + assert len(response_json) == 2, response_json + assert response_json[0]["email"] == ["a@b.com"] + assert set(response_json[1]["email"]) == {"fake2@email.com", "fake@email.com"} + + response = client.get("/platforms/example/contacts/v1/fake:100", headers=headers) + response_json = response.json() + assert response.status_code == 200, response_json + assert set(response_json["email"]) == {"fake2@email.com", "fake@email.com"} + + +@pytest.fixture( + params=[ + "/contacts/v1", + "/contacts/v1/1", + "/platforms/ai4europe_cms/contacts/v1", + "/platforms/ai4europe_cms/contacts/v1/fake:100", + ] +) +def endpoint_from_fixture2(request) -> str: + return request.param + + +def test_email_privacy_for_ai4europe_cms( + client: TestClient, + mocked_privileged_token: Mock, + mocked_ai4europe_cms_token: Mock, + contact: Contact, + platform: Platform, + endpoint_from_fixture2: str, +): + + with DbSession() as session: + contact.platform = "ai4europe_cms" + contact.platform_resource_identifier = "fake:100" + email = Email(name="fake@email.com") + another_email = Email(name="fake2@email.com") + contact.email = [email, another_email] + session.add(contact) + session.commit() + + keycloak_openid.introspect = mocked_privileged_token + headers = {"Authorization": "Fake token"} + + response = client.get(endpoint_from_fixture2, headers=headers) + response_json = response.json() + if isinstance(response_json, list): + response_json = response_json[0] + + assert response.status_code == 200, response_json + assert len(response_json) > 0, response_json + assert response_json["email"] == ["******"] + + keycloak_openid.introspect = mocked_ai4europe_cms_token + + response = client.get(endpoint_from_fixture2, headers=headers) + response_json = response.json() + if isinstance(response_json, list): + response_json = response_json[0] + + assert response.status_code == 200, response_json + assert len(response_json) > 0, response_json + assert response_json["email"] == ["fake@email.com", "fake2@email.com"] diff --git a/src/tests/routers/resource_routers/test_router_dataset.py b/src/tests/routers/resource_routers/test_router_dataset.py index c92ae95e..9c82f818 100644 --- a/src/tests/routers/resource_routers/test_router_dataset.py +++ b/src/tests/routers/resource_routers/test_router_dataset.py @@ -59,14 +59,15 @@ def test_post_invalid_huggingface_identifier( ): keycloak_openid.userinfo = mocked_privileged_token - body = {"name": "name", "platform": "huggingface", "platform_resource_identifier": "a"} + body = {"name": "name", "platform": "huggingface", "platform_resource_identifier": ""} response = client.post("/datasets/v1", json=body, headers={"Authorization": "Fake token"}) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json() assert ( response.json()["detail"][0]["msg"] - == "The platform_resource_identifier for HuggingFace should be a valid repo_id. A repo_id " - "should be between 1 and 96 characters." + == "Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are" + " forbidden, '-' and '.' cannot start or end the name, max length is 96:" + f" '{body['platform_resource_identifier']}'." ) diff --git a/src/tests/routers/resource_routers/test_router_dataset_generic_fields.py b/src/tests/routers/resource_routers/test_router_dataset_generic_fields.py index fb354d2c..9f00cf0e 100644 --- a/src/tests/routers/resource_routers/test_router_dataset_generic_fields.py +++ b/src/tests/routers/resource_routers/test_router_dataset_generic_fields.py @@ -134,8 +134,8 @@ def test_happy_path( date_created = dateutil.parser.parse(response_json["aiod_entry"]["date_created"] + "Z") date_modified = dateutil.parser.parse(response_json["aiod_entry"]["date_modified"] + "Z") - assert 0 < (date_created - datetime_create_request).total_seconds() < 0.1 - assert 0 < (date_modified - datetime_update_request).total_seconds() < 0.1 + assert 0 < (date_created - datetime_create_request).total_seconds() < 0.2 + assert 0 < (date_modified - datetime_update_request).total_seconds() < 0.2 assert response_json["platform"] == "example" assert response_json["platform_resource_identifier"] == "2" diff --git a/src/tests/routers/resource_routers/test_router_person.py b/src/tests/routers/resource_routers/test_router_person.py index 5bf048d8..5a810294 100644 --- a/src/tests/routers/resource_routers/test_router_person.py +++ b/src/tests/routers/resource_routers/test_router_person.py @@ -1,11 +1,13 @@ import copy from unittest.mock import Mock +import pytest from starlette.testclient import TestClient from authentication import keycloak_openid from database.model.agent.contact import Contact from database.model.agent.person import Person +from database.model.platform.platform import Platform from database.session import DbSession @@ -48,3 +50,62 @@ def test_happy_path( assert response_json["price_per_hour_euro"] == 10.50 assert response_json["wants_to_be_contacted"] assert response_json["contact_details"] == 1 + + +@pytest.fixture( + params=[ + "/persons/v1", + "/persons/v1/1", + "/platforms/ai4europe_cms/persons/v1", + "/platforms/ai4europe_cms/persons/v1/2", + ] +) +def endpoint(request) -> str: + return request.param + + +def test_privacy_for_ai4europe_cms( + client: TestClient, + mocked_privileged_token: Mock, + mocked_ai4europe_cms_token: Mock, + platform: Platform, + person: Person, + contact: Contact, + endpoint: str, +): + """Test to ensure that only authenticated users with "full_view_ai4europe_cms_resources" role + can visualise fields such as name, given_name and surname of a person migrated from + the old ai4europe_cms platform. + """ + + with DbSession() as session: + person.platform = "ai4europe_cms" + person.platform_resource_identifier = "2" + person.name = "Joe Doe" + person.given_name = "Joe" + person.surname = "Doe" + session.add(person) + session.add(contact) + session.commit() + + headers = {"Authorization": "Fake token"} + keycloak_openid.introspect = mocked_privileged_token + + response = client.get(endpoint, headers=headers) + response_json = response.json() + response_json = [response_json] if isinstance(response_json, dict) else response_json + assert response.status_code == 200, response_json + for person_dict in response_json: + assert person_dict["name"] == "******" + assert person_dict["given_name"] == "******" + assert person_dict["surname"] == "******" + + keycloak_openid.introspect = mocked_ai4europe_cms_token + response = client.get(endpoint, headers=headers) + response_json = response.json() + response_json = [response_json] if isinstance(response_json, dict) else response_json + assert response.status_code == 200, response_json + for person_dict in response_json: + assert person_dict["name"] == "Joe Doe" + assert person_dict["given_name"] == "Joe" + assert person_dict["surname"] == "Doe" diff --git a/src/tests/routers/search_routers/test_search_routers.py b/src/tests/routers/search_routers/test_search_routers.py index 39deadf3..810e72e3 100644 --- a/src/tests/routers/search_routers/test_search_routers.py +++ b/src/tests/routers/search_routers/test_search_routers.py @@ -24,6 +24,7 @@ def test_search_happy_path(client: TestClient, search_router): assert resource["identifier"] == 1 assert resource["name"] == "A name." + assert resource["platform"] == "A platform." assert resource["description"]["plain"] == "A plain text description." assert resource["description"]["html"] == "An html description." assert resource["aiod_entry"]["date_modified"] == "2023-09-01T00:00:00+00:00" diff --git a/src/tests/test_authentication.py b/src/tests/test_authentication.py index 00ccba37..799d91fd 100644 --- a/src/tests/test_authentication.py +++ b/src/tests/test_authentication.py @@ -1,4 +1,5 @@ -"""Unittests for the behaviour of get_current_user().""" +"""Unittests for the behaviour of get_user_or_raise().""" + import inspect from unittest.mock import Mock @@ -8,14 +9,14 @@ from starlette import status -from authentication import get_current_user, keycloak_openid, User +from authentication import get_user_or_raise, keycloak_openid, User from tests.testutils.mock_keycloak import MockedKeycloak, TestUserType @pytest.mark.asyncio async def test_happy_path(): with MockedKeycloak() as _: - user = await get_current_user(token="Bearer mocked") + user = await get_user_or_raise(token="Bearer mocked") assert user.name == "user" assert set(user.roles) == {"offline_access", "uma_authorization", "default-roles-aiod"} @@ -23,7 +24,7 @@ async def test_happy_path(): @pytest.mark.asyncio async def test_happy_path_privileged(): with MockedKeycloak(type_=TestUserType.privileged) as _: - user = await get_current_user(token="Bearer mocked") + user = await get_user_or_raise(token="Bearer mocked") assert user.name == "user" assert set(user.roles) == { "offline_access", @@ -33,14 +34,14 @@ async def test_happy_path_privileged(): } -def test_get_current_user_leaks_no_information(): +def test_get_user_or_none_leaks_no_information(): """ Make sure an error is thrown if you change the fields on User. There may be good reasons to make a change, but please be very careful: we don't want to expose sensitive information to our application if it is not necessary. Moreover, the User class is returned by the authorization_test endpoint. """ - assert inspect.signature(get_current_user).return_annotation == User + assert inspect.signature(get_user_or_raise).return_annotation == User assert set(inspect.get_annotations(User)) == {"name", "roles"} @@ -48,20 +49,23 @@ def test_get_current_user_leaks_no_information(): async def test_inactive_user(): with MockedKeycloak(type_=TestUserType.inactive) as _: with pytest.raises(HTTPException) as exception_info: - await get_current_user(token="Bearer mocked") + await get_user_or_raise(token="Bearer mocked") assert exception_info.value.status_code == status.HTTP_401_UNAUTHORIZED - assert exception_info.value.detail == "Invalid authentication token" + assert exception_info.value.detail == ( + "Invalid userinfo or inactive user - " + "This endpoint requires authorization. You need to be logged in." + ) @pytest.mark.asyncio async def test_unauthenticated(): with pytest.raises(HTTPException) as exception_info: - await get_current_user(token=None) + await get_user_or_raise(token=None) assert exception_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert ( exception_info.value.detail - == "This endpoint requires authorization. You need to be logged in." + == "No token found - This endpoint requires authorization. You need to be logged in." ) @@ -79,6 +83,6 @@ async def test_keycloak_error(): ) ) with pytest.raises(HTTPException) as exception_info: - await get_current_user(token="Bearer mocked") + await get_user_or_raise(token="Bearer mocked") assert exception_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exception_info.value.detail == "Invalid authentication token" diff --git a/src/tests/testutils/default_instances.py b/src/tests/testutils/default_instances.py index c6f3f15d..6cb8a9cc 100644 --- a/src/tests/testutils/default_instances.py +++ b/src/tests/testutils/default_instances.py @@ -3,6 +3,7 @@ This way you have easy access to, for instance, an AIoDDataset filled with default values. """ + import copy import json @@ -16,6 +17,7 @@ from database.model.dataset.dataset import Dataset from database.model.knowledge_asset.publication import Publication from database.model.models_and_experiments.experiment import Experiment +from database.model.platform.platform import Platform from database.model.resource_read_and_create import resource_create from database.model.serializers import deserialize_resource_relationships from database.session import DbSession @@ -116,6 +118,12 @@ def experiment(body_asset) -> Experiment: return _create_class_with_body(Experiment, body) +@pytest.fixture +def platform() -> Platform: + body = {"name": "aiod"} + return _create_class_with_body(Platform, body) + + def _create_class_with_body(clz, body: dict): pydantic_class = resource_create(clz) res_create = pydantic_class(**body) diff --git a/src/tests/testutils/default_sqlalchemy.py b/src/tests/testutils/default_sqlalchemy.py index c893bdc2..22f55f34 100644 --- a/src/tests/testutils/default_sqlalchemy.py +++ b/src/tests/testutils/default_sqlalchemy.py @@ -141,3 +141,14 @@ def mocked_token(request: SubRequest) -> Mock: def mocked_privileged_token() -> Mock: roles = ["offline_access", "uma_authorization", "default-roles-aiod", "edit_aiod_resources"] return Mock(return_value=_user_with_roles(*roles)) + + +@pytest.fixture() +def mocked_ai4europe_cms_token() -> Mock: + roles = [ + "offline_access", + "uma_authorization", + "default-roles-aiod", + "full_view_ai4europe_cms_resources", + ] + return Mock(return_value=_user_with_roles(*roles)) diff --git a/src/tests/uploader/huggingface/test_dataset_uploader.py b/src/tests/uploader/huggingface/test_dataset_uploader.py index fc86f36d..f1a923b2 100644 --- a/src/tests/uploader/huggingface/test_dataset_uploader.py +++ b/src/tests/uploader/huggingface/test_dataset_uploader.py @@ -196,8 +196,9 @@ def test_wrong_platform(client: TestClient, mocked_privileged_token: Mock, datas status_code=status.HTTP_400_BAD_REQUEST, detail=( ERROR_MSG_PREFIX - + "The platform_resource_identifier for HuggingFace should be a valid repo_id. " - "A repo_id should only contain [a-zA-Z0-9] or ”-”, ”_”, ”.”" + + "Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are " + "forbidden, '-' and '.' cannot start or end the name, max length is 96: " + "'user/Test name with ?'." ), ), ), @@ -221,10 +222,8 @@ def test_wrong_platform(client: TestClient, mocked_privileged_token: Mock, datas status_code=status.HTTP_400_BAD_REQUEST, detail=( ERROR_MSG_PREFIX - + "The platform_resource_identifier for HuggingFace should be a valid repo_id. " - "For new repositories, there should be a single forward slash in the repo_id " - "(namespace/repo_name). Legacy repositories are without a namespace. This " - "repo_id has too many forward slashes." + + "Repo id must be in the form 'repo_name' or 'namespace/repo_name': " + "'user/data/set'. Use `repo_type` argument if needed." ), ), ), @@ -235,8 +234,8 @@ def test_wrong_platform(client: TestClient, mocked_privileged_token: Mock, datas status_code=status.HTTP_400_BAD_REQUEST, detail=( ERROR_MSG_PREFIX - + "The namespace (the first part of the platform_resource_identifier) should be" - " equal to the username, but wrong-namespace != user." + + "The namespace (the first part of the platform_resource_identifier) " + "should be equal to the username, but wrong-namespace != user." ), ), ), @@ -247,8 +246,9 @@ def test_wrong_platform(client: TestClient, mocked_privileged_token: Mock, datas status_code=status.HTTP_400_BAD_REQUEST, detail=( ERROR_MSG_PREFIX - + "The platform_resource_identifier for HuggingFace should be a valid repo_id. " - "A repo_id should be between 1 and 96 characters." + + "Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are " + "forbidden, '-' and '.' cannot start or end the name, max length is 96: " + "'user/" + "a" * 200 + "'." ), ), ), diff --git a/src/tests/validators/test_huggingface_validators.py b/src/tests/validators/test_huggingface_validators.py index 95b83ebd..a6b60dba 100644 --- a/src/tests/validators/test_huggingface_validators.py +++ b/src/tests/validators/test_huggingface_validators.py @@ -12,24 +12,24 @@ ( "user/data/set", ValueError( - "The platform_resource_identifier for HuggingFace should be a valid repo_id. For " - "new repositories, there should be a single forward slash in the repo_id " - "(namespace/repo_name). Legacy repositories are without a namespace. This repo_id " - "has too many forward slashes." + "Repo id must be in the form 'repo_name' or 'namespace/repo_name': " + "'user/data/set'. Use `repo_type` argument if needed." ), ), ( - "a", + "", ValueError( - "The platform_resource_identifier for HuggingFace should be a valid repo_id. A " - "repo_id should be between 1 and 96 characters." + "Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are " + "forbidden, '-' and '.' cannot start or end the name, max length is 96: " + "''." ), ), ( "user/" + "a" * 200, ValueError( - "The platform_resource_identifier for HuggingFace should be a valid repo_id. A " - "repo_id should be between 1 and 96 characters." + "Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are " + "forbidden, '-' and '.' cannot start or end the name, max length is 96: " + "'user/" + "a" * 200 + "'." ), ), ],