Skip to content

Commit

Permalink
Merge pull request #8988 from OpenMined/tauquir/uvicorn-hotreload-2
Browse files Browse the repository at this point in the history
🔥 Hot-reload Syft nodes in your notebooks
  • Loading branch information
rasswanth-s authored Jul 3, 2024
2 parents 316ff6a + 0148fe0 commit 9426b61
Showing 1 changed file with 85 additions and 115 deletions.
200 changes: 85 additions & 115 deletions packages/syft/src/syft/node/server.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
# stdlib
import asyncio
from collections.abc import Callable
from enum import Enum
import multiprocessing
import os
from pathlib import Path
import platform
import signal
import subprocess # nosec
import sys
import time
from typing import Any

# third party
from fastapi import APIRouter
from fastapi import FastAPI
from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict
import requests
from starlette.middleware.cors import CORSMiddleware
import uvicorn

# relative
from ..abstract_node import NodeSideType
from ..client.client import API_PATH
from ..util.autoreload import enable_autoreload
from ..util.constants import DEFAULT_TIMEOUT
from ..util.util import os_name
from .domain import Domain
Expand All @@ -35,134 +39,96 @@
WAIT_TIME_SECONDS = 20


def make_app(name: str, router: APIRouter) -> FastAPI:
app = FastAPI(
title=name,
)
class AppSettings(BaseSettings):
name: str
node_type: NodeType = NodeType.DOMAIN
node_side_type: NodeSideType = NodeSideType.HIGH_SIDE
processes: int = 1
reset: bool = False
dev_mode: bool = False
enable_warnings: bool = False
in_memory_workers: bool = True
queue_port: int | None = None
create_producer: bool = False
n_consumers: int = 0
association_request_auto_approval: bool = False
background_tasks: bool = False

model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None")


def app_factory() -> FastAPI:
settings = AppSettings()

worker_classes = {
NodeType.DOMAIN: Domain,
NodeType.GATEWAY: Gateway,
NodeType.ENCLAVE: Enclave,
}
if settings.node_type not in worker_classes:
raise NotImplementedError(f"node_type: {settings.node_type} is not supported")
worker_class = worker_classes[settings.node_type]

kwargs = settings.model_dump()
if settings.dev_mode:
print(
f"\nWARNING: private key is based on node name: {settings.name} in dev_mode. "
"Don't run this in production."
)
worker = worker_class.named(**kwargs)
else:
del kwargs["reset"] # Explicitly remove reset from kwargs for non-dev mode
worker = worker_class(**kwargs)

app = FastAPI(title=settings.name)
router = make_routes(worker=worker)
api_router = APIRouter()

api_router.include_router(router)
app.include_router(api_router, prefix="/api/v2")

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

return app


worker_classes = {
NodeType.DOMAIN: Domain,
NodeType.GATEWAY: Gateway,
NodeType.ENCLAVE: Enclave,
}


def run_uvicorn(
name: str,
node_type: Enum,
host: str,
port: int,
processes: int,
reset: bool,
dev_mode: bool,
node_side_type: str,
enable_warnings: bool,
in_memory_workers: bool,
queue_port: int | None,
create_producer: bool,
association_request_auto_approval: bool,
n_consumers: int,
background_tasks: bool,
) -> None:
async def _run_uvicorn(
name: str,
node_type: NodeType,
host: str,
port: int,
reset: bool,
dev_mode: bool,
node_side_type: Enum,
) -> None:
if node_type not in worker_classes:
raise NotImplementedError(f"node_type: {node_type} is not supported")
worker_class = worker_classes[node_type]
if dev_mode:
print(
f"\nWARNING: private key is based on node name: {name} in dev_mode. "
"Don't run this in production."
)

worker = worker_class.named(
name=name,
processes=processes,
reset=reset,
local_db=True,
node_type=node_type,
node_side_type=node_side_type,
enable_warnings=enable_warnings,
migrate=True,
in_memory_workers=in_memory_workers,
queue_port=queue_port,
create_producer=create_producer,
n_consumers=n_consumers,
association_request_auto_approval=association_request_auto_approval,
background_tasks=background_tasks,
)
else:
worker = worker_class(
name=name,
processes=processes,
local_db=True,
node_type=node_type,
node_side_type=node_side_type,
enable_warnings=enable_warnings,
migrate=True,
in_memory_workers=in_memory_workers,
queue_port=queue_port,
create_producer=create_producer,
n_consumers=n_consumers,
association_request_auto_approval=association_request_auto_approval,
background_tasks=background_tasks,
)
router = make_routes(worker=worker)
app = make_app(worker.name, router=router)

if reset:
try:
python_pids = find_python_processes_on_port(port)
for pid in python_pids:
print(f"Stopping process on port: {port}")
kill_process(pid)
time.sleep(1)
except Exception: # nosec
print(f"Failed to kill python process on port: {port}")

config = uvicorn.Config(app, host=host, port=port, reload=dev_mode)
server = uvicorn.Server(config)

await server.serve()
asyncio.get_running_loop().stop()

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(
_run_uvicorn(
name,
node_type,
host,
port,
reset,
dev_mode,
node_side_type,
)
def run_uvicorn(host: str, port: int, **kwargs: Any) -> None:
if kwargs.get("reset"):
try:
python_pids = find_python_processes_on_port(port)
for pid in python_pids:
print(f"Stopping process on port: {port}")
kill_process(pid)
time.sleep(1)
except Exception: # nosec
print(f"Failed to kill python process on port: {port}")

# Set up all kwargs as environment variables so that they can be accessed in the app_factory function.
env_prefix = AppSettings.model_config.get("env_prefix", "")
for key, value in kwargs.items():
key_with_prefix = f"{env_prefix}{key.upper()}"
os.environ[key_with_prefix] = str(value)

# The `serve_node` function calls `run_uvicorn` in a separate process using `multiprocessing.Process`.
# When the child process is created, it inherits the file descriptors from the parent process.
# If the parent process has a file descriptor open for sys.stdin, the child process will also have a file descriptor
# open for sys.stdin. This can cause an OSError in uvicorn when it tries to access sys.stdin in the child process.
# To prevent this, we set sys.stdin to None in the child process. This is safe because we don't actually need
# sys.stdin while running uvicorn programmatically.
sys.stdin = None # type: ignore

# Finally, run the uvicorn server.
uvicorn.run(
"syft.node.server:app_factory",
host=host,
port=port,
factory=True,
reload=kwargs.get("dev_mode"),
reload_dirs=[Path(__file__).parent.parent] if kwargs.get("dev_mode") else None,
)
loop.close()


def serve_node(
Expand All @@ -183,6 +149,10 @@ def serve_node(
association_request_auto_approval: bool = False,
background_tasks: bool = False,
) -> tuple[Callable, Callable]:
# Enable IPython autoreload if dev_mode is enabled.
if dev_mode:
enable_autoreload()

server_process = multiprocessing.Process(
target=run_uvicorn,
kwargs={
Expand Down

0 comments on commit 9426b61

Please sign in to comment.