Skip to content

Commit

Permalink
Merge pull request #858 from openzim/fix_worker_ip_update
Browse files Browse the repository at this point in the history
Fix worker IP update to use one single DB session
  • Loading branch information
rgaudin authored Nov 21, 2023
2 parents e078943 + 2ddee9b commit e208419
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 25 deletions.
2 changes: 1 addition & 1 deletion dispatcher/backend/src/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
# using the following, it is possible to automate
# the update of a whitelist of workers IPs on Wasabi (S3 provider)
# enable this feature (default is off)
USES_WORKERS_IPS_WHITELIST = bool(os.getenv("USES_WORKERS_IPS_WHITELIST", ""))
USES_WORKERS_IPS_WHITELIST = bool(os.getenv("USES_WORKERS_IPS_WHITELIST"))
MAX_WORKER_IP_CHANGES_PER_DAY = 4
# wasabi URL with credentials to update policy
WASABI_URL = os.getenv("WASABI_URL", "")
Expand Down
15 changes: 12 additions & 3 deletions dispatcher/backend/src/common/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,11 @@
logger = logging.getLogger(__name__)


def update_workers_whitelist():
def update_workers_whitelist(session: so.Session):
"""update whitelist of workers on external services"""
update_wasabi_whitelist(build_workers_whitelist())
ExternalIpUpdater.update(build_workers_whitelist(session=session))


@dbsession
def build_workers_whitelist(session: so.Session) -> typing.List[str]:
"""list of worker IP adresses and networks (text) to use as whitelist"""
wl_networks = []
Expand Down Expand Up @@ -150,6 +149,16 @@ def get_statement():
)


class ExternalIpUpdater:
"""Class responsible to push IP updates to external system(s)
`update` is called with the new list of all workers IPs everytime
a change is detected.
By default, this class update our IPs whitelist in Wasabi"""

update = update_wasabi_whitelist


@dbsession
def advertise_books_to_cms(task_id: UUID, session: so.Session):
"""inform openZIM CMS of all created ZIMs in the farm for this task
Expand Down
14 changes: 14 additions & 0 deletions dispatcher/backend/src/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ def inner(*args, **kwargs):
return inner


def dbsession_manual(func):
"""Decorator to create an SQLAlchemy ORM session object and wrap the function
inside the session. A `session` argument is automatically set. Transaction must
be managed by the developer (e.g. perform a commit / rollback).
"""

def inner(*args, **kwargs):
with Session() as session:
kwargs["session"] = session
return func(*args, **kwargs)

return inner


def count_from_stmt(session: OrmSession, stmt: SelectBase) -> int:
"""Count all records returned by any statement `stmt` passed as parameter"""
return session.execute(
Expand Down
43 changes: 25 additions & 18 deletions dispatcher/backend/src/routes/requested_tasks/requested_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,8 @@
from marshmallow import ValidationError

import db.models as dbm
from common import WorkersIpChangesCounts, getnow
from common.constants import (
ENABLED_SCHEDULER,
MAX_WORKER_IP_CHANGES_PER_DAY,
USES_WORKERS_IPS_WHITELIST,
)
from common import WorkersIpChangesCounts, constants, getnow
from common.constants import ENABLED_SCHEDULER, MAX_WORKER_IP_CHANGES_PER_DAY
from common.external import update_workers_whitelist
from common.schemas.orms import RequestedTaskFullSchema, RequestedTaskLightSchema
from common.schemas.parameters import (
Expand All @@ -24,8 +20,8 @@
WorkerRequestedTaskSchema,
)
from common.utils import task_event_handler
from db import count_from_stmt, dbsession
from errors.http import InvalidRequestJSON, TaskNotFound, WorkerNotFound
from db import count_from_stmt, dbsession, dbsession_manual
from errors.http import HTTPBase, InvalidRequestJSON, TaskNotFound, WorkerNotFound
from routes import auth_info_if_supplied, authenticate, require_perm, url_uuid
from routes.base import BaseRoute
from routes.errors import NotFound
Expand All @@ -35,14 +31,14 @@
logger = logging.getLogger(__name__)


def record_ip_change(worker_name):
def record_ip_change(session: so.Session, worker_name: str):
"""record that this worker changed its IP and trigger whitelist changes"""
today = datetime.date.today()
# counts and limits are per-day so reset it if date changed
if today != WorkersIpChangesCounts.today:
WorkersIpChangesCounts.reset()
if WorkersIpChangesCounts.add(worker_name) <= MAX_WORKER_IP_CHANGES_PER_DAY:
update_workers_whitelist()
update_workers_whitelist(session)
else:
logger.error(
f"Worker {worker_name} IP changes for {today} "
Expand Down Expand Up @@ -208,7 +204,7 @@ class RequestedTasksForWorkers(BaseRoute):
methods = ["GET"]

@authenticate
@dbsession
@dbsession_manual
def get(self, session: so.Session, token: AccessToken.Payload):
"""list of requested tasks to be retrieved by workers, auth-only"""

Expand All @@ -229,15 +225,26 @@ def get(self, session: so.Session, token: AccessToken.Payload):
worker = dbm.Worker.get(session, worker_name, WorkerNotFound)
if worker.user.username == token.username:
worker.last_seen = getnow()
previous_ip = str(worker.last_ip)
worker.last_ip = worker_ip

# flush to DB so that record_ip_change has access to updated IP
session.flush()

# IP changed since last encounter
if USES_WORKERS_IPS_WHITELIST and previous_ip != worker_ip:
record_ip_change(worker_name)
if str(worker.last_ip) != worker_ip:
logger.info(
f"Worker IP changed detected for {worker_name}: "
f"IP changed from {worker.last_ip} to {worker_ip}"
)
worker.last_ip = worker_ip
# commit explicitely since we are not using an explicit transaction,
# and do it before calling Wasabi so that changes are propagated
# quickly and transaction is not blocking
session.commit()
if constants.USES_WORKERS_IPS_WHITELIST:
try:
record_ip_change(session=session, worker_name=worker_name)
except Exception:
raise HTTPBase(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
error="Recording IP changes failed",
)

request_args = WorkerRequestedTaskSchema().load(request_args)

Expand Down
12 changes: 12 additions & 0 deletions dispatcher/backend/src/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Generator

import pytest
from sqlalchemy.orm import Session as OrmSession

from db import Session


@pytest.fixture
def dbsession() -> Generator[OrmSession, None, None]:
with Session.begin() as session:
yield session
136 changes: 133 additions & 3 deletions dispatcher/backend/src/tests/integration/routes/workers/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import List

import pytest

from common.external import build_workers_whitelist
from common import constants
from common.external import ExternalIpUpdater, build_workers_whitelist


class TestWorkersCommon:
def test_build_workers_whitelist(self, workers):
whitelist = build_workers_whitelist()
def test_build_workers_whitelist(self, workers, dbsession):
whitelist = build_workers_whitelist(session=dbsession)
# - 4 because:
# 2 workers have a duplicate IP
# 1 worker has an IP missing
Expand Down Expand Up @@ -206,3 +209,130 @@ def test_checkin_another_user(
# response.get_json()["error"]
# == "worker with same name already exists for another user"
# )


class TestWorkerRequestedTasks:
def test_requested_task_worker_as_admin(self, client, access_token, worker):
response = client.get(
"/requested-tasks/worker",
query_string={
"worker": worker["name"],
"avail_cpu": 4,
"avail_memory": 2048,
"avail_disk": 4096,
},
headers={"Authorization": access_token},
)
assert response.status_code == 200

def test_requested_task_worker_as_worker(self, client, make_access_token, worker):
response = client.get(
"/requested-tasks/worker",
query_string={
"worker": worker["name"],
"avail_cpu": 4,
"avail_memory": 2048,
"avail_disk": 4096,
},
headers={"Authorization": make_access_token(worker["username"], "worker")},
)
assert response.status_code == 200

@pytest.mark.parametrize(
"prev_ip, new_ip, external_update_enabled, external_update_fails,"
" external_update_called",
[
("77.77.77.77", "88.88.88.88", False, False, False), # ip update disabled
("77.77.77.77", "77.77.77.77", True, False, False), # ip did not changed
("77.77.77.77", "88.88.88.88", True, False, True), # ip should be updated
("77.77.77.77", "88.88.88.88", True, True, False), # ip update fails
],
)
def test_requested_task_worker_update_ip_whitelist(
self,
client,
make_access_token,
worker,
prev_ip,
new_ip,
external_update_enabled,
external_update_fails,
external_update_called,
):
# call it once to set prev_ip
response = client.get(
"/requested-tasks/worker",
query_string={
"worker": worker["name"],
"avail_cpu": 4,
"avail_memory": 2048,
"avail_disk": 4096,
},
headers={
"Authorization": make_access_token(worker["username"], "worker"),
"X-Forwarded-For": prev_ip,
},
)
assert response.status_code == 200

# check prev_ip has been set
response = client.get("/workers/")
assert response.status_code == 200
response_data = response.get_json()
for item in response_data["items"]:
if item["name"] != worker["name"]:
continue
assert item["last_ip"] == prev_ip

# setup custom ip updater to intercept Wasabi operations
updater = IpUpdaterAndChecker(should_fail=external_update_fails)
assert new_ip not in updater.ip_addresses
ExternalIpUpdater.update = updater.ip_update
constants.USES_WORKERS_IPS_WHITELIST = external_update_enabled

# call it once to set next_ip
response = client.get(
"/requested-tasks/worker",
query_string={
"worker": worker["name"],
"avail_cpu": 4,
"avail_memory": 2048,
"avail_disk": 4096,
},
headers={
"Authorization": make_access_token(worker["username"], "worker"),
"X-Forwarded-For": new_ip,
},
)
if external_update_fails:
assert response.status_code == 503
else:
assert response.status_code == 200
assert updater.ips_updated == external_update_called
if external_update_called:
assert new_ip in updater.ip_addresses

# check new_ip has been set (even if ip update is disabled or has failed)
response = client.get("/workers/")
assert response.status_code == 200
response_data = response.get_json()
for item in response_data["items"]:
if item["name"] != worker["name"]:
continue
assert item["last_ip"] == new_ip


class IpUpdaterAndChecker:
"""Helper class to intercept Wasabi operations and perform assertions"""

def __init__(self, should_fail: bool) -> None:
self.ips_updated = False
self.should_fail = should_fail
self.ip_addresses = []

def ip_update(self, ip_addresses: List):
if self.should_fail:
raise Exception()
else:
self.ips_updated = True
self.ip_addresses = ip_addresses

0 comments on commit e208419

Please sign in to comment.