Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add security linting, request timeouts #376

Merged
merged 4 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ repos:
rev: 22.12.0
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.1
hooks:
- id: ruff # for now, lint only the security ("S") rules selected in pyproject.toml
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.991
hooks:
Expand Down
1 change: 0 additions & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ services:
- ${AIOD_REST_PORT}:8000
volumes:
- ./src:/app:ro
stdin_open: true # docker run -i
command: >
python main.py
--build-db if-absent
Expand Down
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ py-modules = []
[tool.black]
line-length = 100

[tool.ruff]
exclude = [
"src/tests",
]

[tool.ruff.lint]
select = ["S"]

[tool.pytest.ini_options]
filterwarnings = [
"ignore::FutureWarning"
Expand Down
1 change: 1 addition & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@

DB_CONFIG = CONFIG.get("database", {})
KEYCLOAK_CONFIG = CONFIG.get("keycloak", {})
REQUEST_TIMEOUT = CONFIG.get("dev", {}).get("request_timeout", None)
1 change: 1 addition & 0 deletions src/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ password = "ok"
# Additional options for development
[dev]
reload = true
request_timeout = 10 # seconds; timeout for outgoing HTTP requests (if omitted, requests wait indefinitely)

# Authentication and authorization
[keycloak]
Expand Down
3 changes: 2 additions & 1 deletion src/connectors/huggingface/huggingface_dataset_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from huggingface_hub import list_datasets
from huggingface_hub.hf_api import DatasetInfo

from config import REQUEST_TIMEOUT
from connectors.abstract.resource_connector_on_start_up import ResourceConnectorOnStartUp
from connectors.record_error import RecordError
from connectors.resource_with_relations import ResourceWithRelations
Expand Down Expand Up @@ -37,7 +38,7 @@ def platform_name(self) -> PlatformName:

@staticmethod
def _get(url: str, dataset_id: str) -> typing.List[typing.Dict[str, typing.Any]]:
response = requests.get(url, params={"dataset": dataset_id})
response = requests.get(url, params={"dataset": dataset_id}, timeout=REQUEST_TIMEOUT)
response_json = response.json()
if not response.ok:
msg = response_json["error"]
Expand Down
7 changes: 4 additions & 3 deletions src/connectors/openml/openml_dataset_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sqlmodel import SQLModel
from typing import Iterator

from config import REQUEST_TIMEOUT
from connectors.abstract.resource_connector_by_id import ResourceConnectorById
from connectors.record_error import RecordError
from database.model import field_length
Expand Down Expand Up @@ -40,7 +41,7 @@ def platform_name(self) -> PlatformName:

def retry(self, identifier: int) -> SQLModel | RecordError:
url_qual = f"https://www.openml.org/api/v1/json/data/qualities/{identifier}"
response = requests.get(url_qual)
response = requests.get(url_qual, timeout=REQUEST_TIMEOUT)
if not response.ok:
msg = response.json()["error"]["message"]
return RecordError(
Expand All @@ -54,7 +55,7 @@ def fetch_record(
self, identifier: int, qualities: list[dict[str, str]]
) -> SQLModel | RecordError:
url_data = f"https://www.openml.org/api/v1/json/data/{identifier}"
response = requests.get(url_data)
response = requests.get(url_data, timeout=REQUEST_TIMEOUT)
if not response.ok:
msg = response.json()["error"]["message"]
return RecordError(
Expand Down Expand Up @@ -105,7 +106,7 @@ def fetch(self, offset: int, from_identifier: int) -> Iterator[SQLModel | Record
"https://www.openml.org/api/v1/json/data/list/"
f"limit/{self.limit_per_iteration}/offset/{offset}"
)
response = requests.get(url_data)
response = requests.get(url_data, timeout=REQUEST_TIMEOUT)
if not response.ok:
status_code = response.status_code
msg = response.json()["error"]["message"]
Expand Down
5 changes: 3 additions & 2 deletions src/connectors/openml/openml_mlmodel_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sqlmodel import SQLModel
from typing import Iterator, Any

from config import REQUEST_TIMEOUT
from connectors.abstract.resource_connector_by_id import ResourceConnectorById
from connectors.record_error import RecordError
from database.model import field_length
Expand Down Expand Up @@ -45,7 +46,7 @@ def retry(self, identifier: int) -> ResourceWithRelations[SQLModel] | RecordErro

def fetch_record(self, identifier: int) -> ResourceWithRelations[MLModel] | RecordError:
url_mlmodel = f"https://www.openml.org/api/v1/json/flow/{identifier}"
response = requests.get(url_mlmodel)
response = requests.get(url_mlmodel, timeout=REQUEST_TIMEOUT)
if not response.ok:
msg = response.json()["error"]["message"]
return RecordError(
Expand Down Expand Up @@ -101,7 +102,7 @@ def fetch(
"https://www.openml.org/api/v1/json/flow/list/"
f"limit/{self.limit_per_iteration}/offset/{offset}"
)
response = requests.get(url_mlmodel)
response = requests.get(url_mlmodel, timeout=REQUEST_TIMEOUT)

if not response.ok:
status_code = response.status_code
Expand Down
5 changes: 4 additions & 1 deletion src/connectors/zenodo/zenodo_dataset_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from starlette import status
from typing import Iterator, Tuple

from config import REQUEST_TIMEOUT
from connectors.abstract.resource_connector_by_date import ResourceConnectorByDate
from connectors.record_error import RecordError
from connectors.resource_with_relations import ResourceWithRelations
Expand Down Expand Up @@ -65,7 +66,9 @@ def _error_msg_bad_format(field) -> str:
@limits(calls=GLOBAL_MAX_CALLS_MINUTE, period=ONE_MINUTE)
@limits(calls=GLOBAL_MAX_CALLS_HOUR, period=ONE_HOUR)
def _get_record(id_number: str) -> requests.Response:
response = requests.get(f"https://zenodo.org/api/records/{id_number}/files")
response = requests.get(
f"https://zenodo.org/api/records/{id_number}/files", timeout=REQUEST_TIMEOUT
)
return response

def _dataset_from_record(
Expand Down
8 changes: 4 additions & 4 deletions src/database/deletion/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def create_deletion_trigger_one_to_one(
DELETE FROM {delete_name}
WHERE {delete_name}.{to_delete_identifier} = OLD.{trigger_identifier_link};
END;
"""
""" # noqa: S608 # never user input
)
event.listen(trigger.metadata, "after_create", ddl)

Expand Down Expand Up @@ -97,7 +97,7 @@ def create_deletion_trigger_many_to_one(
WHERE {trigger_name}.{trigger_identifier_link} = OLD.{trigger_identifier_link}
);
END;
"""
""" # noqa: S608 # never user input
)
event.listen(trigger.metadata, "after_create", ddl)

Expand Down Expand Up @@ -139,7 +139,7 @@ def create_deletion_trigger_many_to_many(
SELECT 1 FROM {link_name}
WHERE {link_name}.{link_to_identifier} = {delete_name}.{to_delete_identifier}
)
"""
""" # noqa: S608 # never user input
for link_name in link_names
)
ddl = DDL(
Expand All @@ -153,6 +153,6 @@ def create_deletion_trigger_many_to_many(
DELETE FROM {delete_name}
WHERE {links_clause};
END;
"""
""" # noqa: S608 # never user input
)
event.listen(trigger.metadata, "after_create", ddl)
7 changes: 6 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,12 @@ def create_app() -> FastAPI:
def main():
"""Run the application. Placed in a separate function, to avoid having global variables"""
args = _parse_args()
uvicorn.run("main:create_app", host="0.0.0.0", reload=args.reload, factory=True)
uvicorn.run(
"main:create_app",
host="0.0.0.0", # noqa: S104 # required to make the interface available outside of docker
reload=args.reload,
factory=True,
)


if __name__ == "__main__":
Expand Down
5 changes: 4 additions & 1 deletion src/tests/connectors/zenodo/test_get_datasets_zenodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ratelimit.exception import RateLimitException
from requests.exceptions import HTTPError

from config import REQUEST_TIMEOUT
from connectors.record_error import RecordError
from connectors.zenodo import zenodo_dataset_connector
from connectors.zenodo.zenodo_dataset_connector import ZenodoDatasetConnector
Expand Down Expand Up @@ -168,7 +169,9 @@ def test_fetch_records_rate_limit(monkeypatch):
@staticmethod
@limits(calls=1, period=60)
def mock_check(id_number):
response = requests.get(f"https://zenodo.org/api/records/{id_number}/files")
response = requests.get(
f"https://zenodo.org/api/records/{id_number}/files", timeout=REQUEST_TIMEOUT
)
return response

monkeypatch.setattr(ZenodoDatasetConnector, "_get_record", mock_check)
Expand Down
20 changes: 15 additions & 5 deletions src/uploaders/zenodo_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi import UploadFile, HTTPException, status

from authentication import User
from config import REQUEST_TIMEOUT

from database.model.agent.contact import Contact
from database.model.ai_asset.license import License
Expand Down Expand Up @@ -142,6 +143,7 @@ def _create_repo(self, metadata: dict, token: str) -> dict:
self.BASE_URL,
params=params,
json={"metadata": metadata},
timeout=REQUEST_TIMEOUT,
)
except Exception as exc:
raise as_http_exception(exc)
Expand All @@ -158,7 +160,7 @@ def _get_metadata_from_zenodo(self, repo_id: str, token: str) -> dict:
"""
params = {"access_token": token}
try:
res = requests.get(f"{self.BASE_URL}/{repo_id}", params=params)
res = requests.get(f"{self.BASE_URL}/{repo_id}", params=params, timeout=REQUEST_TIMEOUT)
except Exception as exc:
raise as_http_exception(exc)

Expand All @@ -179,6 +181,7 @@ def _update_zenodo_metadata(self, metadata: dict, repo_id: str, token: str) -> N
params={"access_token": token},
data=json.dumps({"metadata": metadata}),
headers=headers,
timeout=REQUEST_TIMEOUT,
)
except Exception as exc:
raise as_http_exception(exc)
Expand All @@ -194,7 +197,10 @@ def _upload_file(self, repo_url: str, file: UploadFile, token: str) -> None:
params = {"access_token": token}
try:
res = requests.put(
f"{repo_url}/{file.filename}", data=io.BufferedReader(file.file), params=params
f"{repo_url}/{file.filename}",
data=io.BufferedReader(file.file),
params=params,
timeout=REQUEST_TIMEOUT,
)
except Exception as exc:
raise as_http_exception(exc)
Expand All @@ -209,7 +215,9 @@ def _publish_resource(self, repo_id: str, token: str) -> dict:
"""
params = {"access_token": token}
try:
res = requests.post(f"{self.BASE_URL}/{repo_id}/actions/publish", params=params)
res = requests.post(
f"{self.BASE_URL}/{repo_id}/actions/publish", params=params, timeout=REQUEST_TIMEOUT
)
except Exception as exc:
raise as_http_exception(exc)

Expand All @@ -229,7 +237,7 @@ def _get_distribution(
url = public_url or f"{self.BASE_URL}/{repo_id}"

try:
res = requests.get(f"{url}/files", params=params)
res = requests.get(f"{url}/files", params=params, timeout=REQUEST_TIMEOUT)
except Exception as exc:
raise as_http_exception(exc)

Expand Down Expand Up @@ -261,7 +269,9 @@ def _get_and_validate_license(self, license_: License | None) -> str:
Checks if the provided license is valid for uploading content to Zenodo.
"""
try:
res = requests.get("https://zenodo.org/api/vocabularies/licenses?q=&tags=data")
res = requests.get(
"https://zenodo.org/api/vocabularies/licenses?q=&tags=data", timeout=REQUEST_TIMEOUT
)
except Exception as exc:
raise as_http_exception(exc)
if res.status_code != status.HTTP_200_OK:
Expand Down
Loading