Skip to content

Commit

Permalink
misc: added some docstrings, types
Browse files Browse the repository at this point in the history
  • Loading branch information
WinPlay02 committed Nov 21, 2023
1 parent 9186587 commit 5d6cda5
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 44 deletions.
1 change: 1 addition & 0 deletions src/safeds_runner/server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Infrastructure for dynamically running Safe-DS pipelines and communication with the vscode extension."""
31 changes: 16 additions & 15 deletions src/safeds_runner/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from typing import Any

import simple_websocket
from flask import Flask
from flask_cors import CORS
from flask_sock import Sock
Expand All @@ -28,37 +29,37 @@


@sock.route("/WSMain")
def ws_run_program(ws):
logging.debug(f"Request to WSRunProgram")
def ws_run_program(ws: simple_websocket.Server) -> None:
logging.debug("Request to WSRunProgram")
set_new_websocket_target(ws)
while True:
# This would be a JSON message
received_message: str = ws.receive()
logging.debug(f"> Received Message: {received_message}")
logging.debug("> Received Message: %s", received_message)
try:
received_object: dict[str, Any] = json.loads(received_message)
except json.JSONDecodeError:
logging.warn(f"Invalid message received: {received_message}")
logging.warn("Invalid message received: %s", received_message)
ws.close(None)
return
if "type" not in received_object:
logging.warn(f"No message type specified in: {received_message}")
logging.warn("No message type specified in: %s", received_message)
ws.close(None)
return
if "id" not in received_object:
logging.warn(f"No message id specified in: {received_message}")
logging.warn("No message id specified in: %s", received_message)
ws.close(None)
return
if "data" not in received_object:
logging.warn(f"No message data specified in: {received_message}")
logging.warn("No message data specified in: %s", received_message)
ws.close(None)
return
if not isinstance(received_object["type"], str):
logging.warn(f"Message type is not a string: {received_message}")
logging.warn("Message type is not a string: %s", received_message)
ws.close(None)
return
if not isinstance(received_object["id"], str):
logging.warn(f"Message id is not a string: {received_message}")
logging.warn("Message id is not a string: %s", received_message)
ws.close(None)
return
request_data = received_object["data"]
Expand All @@ -68,7 +69,7 @@ def ws_run_program(ws):
case "program":
valid, invalid_message = messages.validate_program_message(request_data)
if not valid:
logging.warn(f"Invalid message data specified in: {received_message} ({invalid_message})")
logging.warn("Invalid message data specified in: %s (%s)", received_message, invalid_message)
ws.close(None)
return
code = request_data["code"]
Expand All @@ -78,7 +79,7 @@ def ws_run_program(ws):
case "placeholder_query":
valid, invalid_message = messages.validate_placeholder_query_message(request_data)
if not valid:
logging.warn(f"Invalid message data specified in: {received_message} ({invalid_message})")
logging.warn("Invalid message data specified in: %s (%s)", received_message, invalid_message)
ws.close(None)
return
placeholder_type, placeholder_value = get_placeholder(execution_id, request_data)
Expand All @@ -89,14 +90,14 @@ def ws_run_program(ws):
send_websocket_value(ws, request_data, "", "")
case _:
if message_type not in messages.message_types:
logging.warn(f"Invalid message type {message_type}")
logging.warn("Invalid message type: %s", message_type)


def send_websocket_value(connection, name: str, var_type: str, value: str):
def send_websocket_value(connection: simple_websocket.Server, name: str, var_type: str, value: str) -> None:
send_websocket_message(connection, "value", {"name": name, "type": var_type, "value": value})


def send_websocket_message(connection, msg_type: str, msg_data):
def send_websocket_message(connection: simple_websocket.Server, msg_type: str, msg_data) -> None:
message = {"type": msg_type, "data": msg_data}
connection.send(json.dumps(message))

Expand All @@ -115,6 +116,6 @@ def send_websocket_message(connection, msg_type: str, msg_data):
parser.add_argument('--port', type=int, default=5000, help='Port on which to run the python server.')
args = parser.parse_args()
setup_pipeline_execution()
logging.info(f"Starting Safe-DS Runner on port {args.port}")
logging.info("Starting Safe-DS Runner on port %s", str(args.port))
# Only bind to host=127.0.0.1. Connections from other devices should not be accepted
WSGIServer(('127.0.0.1', args.port), app).serve_forever()
5 changes: 3 additions & 2 deletions src/safeds_runner/server/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def create_runtime_progress_done() -> str:
return "done"


def validate_program_message(message_data: dict[str, typing.Any] | str) -> (bool, typing.Optional[str]):
def validate_program_message(message_data: dict[str, typing.Any] | str) -> typing.Tuple[bool, typing.Optional[str]]:
if not isinstance(message_data, dict):
return False, "Message data is not a JSON object"
if "code" not in message_data:
Expand Down Expand Up @@ -50,7 +50,8 @@ def validate_program_message(message_data: dict[str, typing.Any] | str) -> (bool
return True, None


def validate_placeholder_query_message(message_data: dict[str, typing.Any] | str) -> (bool, typing.Optional[str]):
def validate_placeholder_query_message(message_data: dict[str, typing.Any] | str) -> typing.Tuple[
bool, typing.Optional[str]]:
if not isinstance(message_data, str):
return False, "Message data is not a string"
return True, None
38 changes: 19 additions & 19 deletions src/safeds_runner/server/module_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib.abc
import typing
from abc import ABC
from importlib.machinery import ModuleSpec
import sys
Expand All @@ -17,10 +18,10 @@ def __init__(self, code_bytes: bytes, filename: str):
self.code_bytes = code_bytes
self.filename = filename

def get_data(self, path) -> bytes:
def get_data(self, path: bytes | str) -> bytes:
return self.code_bytes

def get_filename(self, fullname) -> str:
def get_filename(self, fullname: str) -> str:
return self.filename


Expand All @@ -33,18 +34,20 @@ def __init__(self, code: dict[str, dict[str, str]]):
"""
self.code = code
self.allowed_packages = {key for key in code.keys()}
self.imports_to_remove = set()
self.imports_to_remove: typing.Set[str] = set()
for key in code.keys():
while "." in key:
key = key.rpartition(".")[0]
self.allowed_packages.add(key)

def find_spec(self, fullname: str, path=None, target: types.ModuleType | None = None) -> ModuleSpec | None:
logging.debug(f"Find Spec: {fullname} {path} {target}")
def find_spec(self, fullname: str, path: typing.Sequence[bytes | str] = None,
target: types.ModuleType | None = None) -> ModuleSpec | None:
logging.debug("Find Spec: %s %s %s", fullname, path, target)
if fullname in self.allowed_packages:
parent_package = importlib.util.spec_from_loader(fullname, loader=InMemoryLoader("".encode("utf-8"),
fullname.replace(".",
"/")))
parent_package = importlib.util.spec_from_loader(
fullname, loader=InMemoryLoader("".encode("utf-8"), fullname.replace(".", "/")))
if parent_package is None:
return None
if parent_package.submodule_search_locations is None:
parent_package.submodule_search_locations = []
parent_package.submodule_search_locations.append(fullname.replace(".", "/"))
Expand All @@ -53,19 +56,16 @@ def find_spec(self, fullname: str, path=None, target: types.ModuleType | None =
module_path = fullname.split(".")
if len(module_path) == 1 and "" in self.code and fullname in self.code[""]:
self.imports_to_remove.add(fullname)
return importlib.util.spec_from_loader(fullname,
loader=InMemoryLoader(self.code[""][fullname].encode("utf-8"),
fullname.replace(".", "/")),
origin="")
parent_package = ".".join(module_path[:-1])
return importlib.util.spec_from_loader(
fullname, loader=InMemoryLoader(self.code[""][fullname].encode("utf-8"), fullname.replace(".", "/")),
origin="")
parent_package_path = ".".join(module_path[:-1])
submodule_name = module_path[-1]
if parent_package in self.code and submodule_name in self.code[parent_package]:
if parent_package_path in self.code and submodule_name in self.code[parent_package_path]:
self.imports_to_remove.add(fullname)
return importlib.util.spec_from_loader(fullname,
loader=InMemoryLoader(
self.code[parent_package][submodule_name].encode("utf-8"),
fullname.replace(".", "/")),
origin=parent_package)
return importlib.util.spec_from_loader(
fullname, loader=InMemoryLoader(self.code[parent_package_path][submodule_name].encode("utf-8"),
fullname.replace(".", "/")), origin=parent_package_path)
return None

def attach(self) -> None:
Expand Down
20 changes: 12 additions & 8 deletions src/safeds_runner/server/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@
import json
import typing
import runpy
from multiprocessing.managers import SyncManager

import simple_websocket
import stack_data
import logging

from safeds_runner.server.module_manager import InMemoryFinder

# Multiprocessing
multiprocessing_manager = None
global_placeholder_map = {}
multiprocessing_manager: SyncManager | None = None
global_placeholder_map: dict = {}
global_messages_queue: queue.Queue | None = None
# Message Queue
websocket_target = None
messages_queue_thread = None
websocket_target: simple_websocket.Server | None = None
messages_queue_thread: threading.Thread | None = None


def setup_pipeline_execution() -> None:
Expand All @@ -35,15 +39,15 @@ def setup_pipeline_execution() -> None:
messages_queue_thread.start()


def _handle_queue_messages():
def _handle_queue_messages() -> None:
global websocket_target
while True:
message = global_messages_queue.get()
if websocket_target is not None:
websocket_target.send(json.dumps(message))


def set_new_websocket_target(ws) -> None:
def set_new_websocket_target(ws: simple_websocket.Server) -> None:
"""
Inform the message queue handling thread that the websocket connection has changed.
:param ws: New websocket connection
Expand Down Expand Up @@ -91,7 +95,7 @@ def save_placeholder(self, placeholder_name: str, value: typing.Any) -> None:
self.placeholder_map[placeholder_name] = value

def _execute(self) -> None:
logging.info(f"Executing {self.sdspackage}.{self.sdsmodule}.{self.sdspipeline}...")
logging.info("Executing %s.%s.%s...", self.sdspackage, self.sdsmodule, self.sdspipeline)
pipeline_finder = InMemoryFinder(self.code)
pipeline_finder.attach()
main_module = f"gen_{self.sdsmodule}_{self.sdspipeline}"
Expand Down Expand Up @@ -159,7 +163,7 @@ def _get_placeholder_type(value: typing.Any):
return "Any"


def get_placeholder(exec_id: str, placeholder_name: str) -> (str | None, typing.Any):
def get_placeholder(exec_id: str, placeholder_name: str) -> typing.Tuple[str | None, typing.Any]:
"""
Gets a placeholder type and value for an execution id and placeholder name
:param exec_id: Unique id identifying execution
Expand Down

0 comments on commit 5d6cda5

Please sign in to comment.