Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: reloading relative imports on plugin updates #323

Merged
merged 3 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 5 additions & 25 deletions plugin_runner/plugin_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,32 +261,15 @@ def find_modules(base_path: pathlib.Path, prefix: str | None = None) -> list[str
return modules


def sandbox_from_package(package_path: pathlib.Path) -> dict[str, Any]:
def sandbox_from_module(base_path: pathlib.Path, module_name: str) -> Any:
"""Sandbox the code execution."""
package_name = package_path.name
available_modules = find_modules(package_path)
sandboxes = {}

for module_name in available_modules:
result = sandbox_from_module(package_path, module_name)
full_module_name = f"{package_name}.{module_name}"
sandboxes[full_module_name] = result

return sandboxes


def sandbox_from_module(package_path: pathlib.Path, module_name: str) -> Any:
"""Sandbox the code execution."""
module_path = package_path / str(module_name.replace(".", "/") + ".py")
module_path = base_path / str(module_name.replace(".", "/") + ".py")

if not module_path.exists():
raise ModuleNotFoundError(f'Could not load module "{module_name}"')

source_code = module_path.read_text()
sandbox = Sandbox(module_path, namespace=module_name)

full_module_name = f"{package_path.name}.{module_name}"

sandbox = Sandbox(source_code, namespace=full_module_name)
return sandbox.execute()


Expand Down Expand Up @@ -320,7 +303,6 @@ def load_or_reload_plugin(path: pathlib.Path) -> None:
handlers = manifest_json["components"].get("protocols", []) + manifest_json[
"components"
].get("applications", [])
results = sandbox_from_package(path)
except Exception as e:
log.error(f'Unable to load plugin "{name}": {str(e)}')
return
Expand All @@ -336,11 +318,11 @@ def load_or_reload_plugin(path: pathlib.Path) -> None:
continue

try:
result = sandbox_from_module(path.parent, handler_module)

if name_and_class in LOADED_PLUGINS:
log.info(f"Reloading plugin '{name_and_class}'")

result = results[handler_module]

LOADED_PLUGINS[name_and_class]["active"] = True

LOADED_PLUGINS[name_and_class]["class"] = result[handler_class]
Expand All @@ -349,8 +331,6 @@ def load_or_reload_plugin(path: pathlib.Path) -> None:
else:
log.info(f"Loading plugin '{name_and_class}'")

result = results[handler_module]

LOADED_PLUGINS[name_and_class] = {
"active": True,
"class": result[handler_class],
Expand Down
105 changes: 97 additions & 8 deletions plugin_runner/sandbox.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import ast
import builtins
import importlib
import sys
from _ast import AnnAssign
from functools import cached_property
from pathlib import Path
from typing import Any, cast

from RestrictedPython import (
Expand Down Expand Up @@ -88,6 +91,20 @@ def _apply(_ob: Any, *args: Any, **kwargs: Any) -> Any:
return _ob(*args, **kwargs)


def _find_folder_in_path(file_path: Path, target_folder_name: str) -> Path | None:
"""Recursively search for a folder with the specified name in the hierarchy of the given file path."""
file_path = file_path.resolve()

if file_path.name == target_folder_name:
return file_path

# If we've reached the root of the file system, return None
if file_path.parent == file_path:
return None

return _find_folder_in_path(file_path.parent, target_folder_name)


class Sandbox:
"""A restricted sandbox for safely executing arbitrary Python code."""

Expand Down Expand Up @@ -199,16 +216,28 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
# Impossible Case only ctx Load, Store and Del are defined in ast.
raise NotImplementedError(f"Unknown ctx type: {type(node.ctx)}")

def __init__(self, source_code: str, namespace: str | None = None) -> None:
def __init__(
self,
source_code: str | Path,
namespace: str | None = None,
evaluated_modules: dict[str, bool] | None = None,
) -> None:
if source_code is None:
raise TypeError("source_code may not be None")
self.namespace = namespace or "protocols"
self.source_code = source_code

@cached_property
def package_name(self) -> str | None:
"""Return the root package name."""
return self.namespace.split(".")[0] if self.namespace else None
self.namespace = namespace or "protocols"
self.package_name = self.namespace.split(".")[0]

if isinstance(source_code, Path):
if not source_code.exists():
raise FileNotFoundError(f"File not found: {source_code}")
self.source_code = source_code.read_text()
package_path = _find_folder_in_path(source_code, self.package_name)
self.base_path = package_path.parent if package_path else None
self._evaluated_modules: dict[str, bool] = evaluated_modules or {}
else:
self.source_code = source_code
self.base_path = None

@cached_property
def scope(self) -> dict[str, Any]:
Expand Down Expand Up @@ -266,12 +295,72 @@ def warnings(self) -> tuple[str, ...]:
def _is_known_module(self, name: str) -> bool:
return bool(
_is_known_module(name)
or (self.package_name and name.split(".")[0] == self.package_name)
or (self.package_name and name.split(".")[0] == self.package_name and self.base_path)
)

def _get_module(self, module_name: str) -> Path:
"""Get the module path for the given module name."""
module_relative_path = module_name.replace(".", "/")
module = Path(cast(Path, self.base_path) / f"{module_relative_path}.py")

if not module.exists():
module = Path(cast(Path, self.base_path) / f"{module_relative_path}/__init__.py")

return module

def _evaluate_module(self, module_name: str) -> None:
"""Evaluate the given module in the sandbox.
If the module to import belongs to the same package as the current module, evaluate it inside a sandbox.
"""
if not module_name.startswith(self.package_name) or module_name in self._evaluated_modules:
return # Skip modules outside the package or already evaluated.

module = self._get_module(module_name)
self._evaluate_implicit_imports(module)

# Re-check after evaluating implicit imports to avoid duplicate evaluations.
if module_name not in self._evaluated_modules:
Sandbox(
module, namespace=module_name, evaluated_modules=self._evaluated_modules
).execute()
self._evaluated_modules[module_name] = True

# Reload the module if already imported to ensure the latest version is used.
if sys.modules.get(module_name):
importlib.reload(sys.modules[module_name])

def _evaluate_implicit_imports(self, module: Path) -> None:
"""Evaluate implicit imports in the sandbox."""
# Determine the parent module to check for implicit imports.
parent = module.parent.parent if module.name == "__init__.py" else module.parent
base_path = cast(Path, self.base_path)

# Skip evaluation if the parent module is outside the base path or already the source code root.
if not parent.is_relative_to(base_path) or parent == base_path:
return

module_name = parent.relative_to(base_path).as_posix().replace("/", ".")
init_file = parent / "__init__.py"

if module_name not in self._evaluated_modules:
if init_file.exists():
# Mark as evaluated to prevent infinite recursion.
self._evaluated_modules[module_name] = True
Sandbox(
init_file, namespace=module_name, evaluated_modules=self._evaluated_modules
).execute()
else:
# Mark as evaluated even if no init file exists to prevent redundant checks.
self._evaluated_modules[module_name] = True

self._evaluate_implicit_imports(parent)

def _safe_import(self, name: str, *args: Any, **kwargs: Any) -> Any:
if not (self._is_known_module(name)):
raise ImportError(f"{name!r} is not an allowed import.")

self._evaluate_module(name)

return __import__(name, *args, **kwargs)

def execute(self) -> dict:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"sdk_version": "0.1.4",
"plugin_version": "0.0.1",
"name": "test_implicit_imports_plugin",
"description": "Edit the description in CANVAS_MANIFEST.json",
"components": {
"protocols": [
{
"class": "test_implicit_imports_plugin.protocols.my_protocol:Forbidden",
"description": "A protocol that does xyz...",
"data_access": {
"event": "",
"read": [],
"write": []
}
},
{
"class": "test_implicit_imports_plugin.protocols.my_protocol:Allowed",
"description": "A protocol that does xyz...",
"data_access": {
"event": "",
"read": [],
"write": []
}
}
],
"commands": [],
"content": [],
"effects": [],
"views": []
},
"secrets": [],
"tags": {},
"references": [],
"license": "",
"diagram": false,
"readme": "./README.md"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
test_forbiden_implicit_imports_plugin
=====================================

## Description

A description of this plugin

### Important Note!

The CANVAS_MANIFEST.json is used when installing your plugin. Please ensure it
gets updated if you add, remove, or rename protocols.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from canvas_sdk.effects import Effect
from canvas_sdk.events import EventType
from canvas_sdk.protocols import BaseProtocol
from logger import log


class Forbidden(BaseProtocol):
"""You should put a helpful description of this protocol's behavior here."""

# Name the event type you wish to run in response to
RESPONDS_TO = EventType.Name(EventType.UNKNOWN)

def compute(self) -> list[Effect]:
"""This method gets called when an event of the type RESPONDS_TO is fired."""
from test_implicit_imports_plugin.utils.base import OtherClass

OtherClass()

return []


class Allowed(BaseProtocol):
"""You should put a helpful description of this protocol's behavior here."""

RESPONDS_TO = EventType.Name(EventType.UNKNOWN)

def compute(self) -> list[Effect]:
"""This method gets called when an event of the type RESPONDS_TO is fired."""
from test_implicit_imports_plugin.templates import Template

log.info(Template().render())

return []
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from test_implicit_imports_plugin.templates.base import Template

__all__ = ("Template",)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class Template:
"""A template class for testing implicit imports."""

def render(self) -> str:
"""Renders the template."""
return "Hello, World!"
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os

from logger import log

log.info(f"os list dir: {os.listdir('.')}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class OtherClass:
"""This class is used to test implicit imports."""

pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"sdk_version": "0.1.4",
"plugin_version": "0.0.1",
"name": "test_module_forbidden_imports_plugin",
"description": "Edit the description in CANVAS_MANIFEST.json",
"components": {
"protocols": [
{
"class": "test_module_forbidden_imports_plugin.protocols.my_protocol:Protocol",
"description": "A protocol that does xyz...",
"data_access": {
"event": "",
"read": [],
"write": []
}
}
],
"commands": [],
"content": [],
"effects": [],
"views": []
},
"secrets": [],
"tags": {},
"references": [],
"license": "",
"diagram": false,
"readme": "./README.md"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
==========================
test_module_forbidden_imports_plugin
==========================

## Description

A description of this plugin

### Important Note!

The CANVAS_MANIFEST.json is used when installing your plugin. Please ensure it
gets updated if you add, remove, or rename protocols.
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import os

from logger import log

log.info(f"This is a forbidden import. {os}")


def import_me() -> str:
"""Test method."""
return "Successfully imported!"
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from test_module_forbidden_imports_plugin.other_module.base import import_me

from canvas_sdk.effects import Effect, EffectType
from canvas_sdk.events import EventType
from canvas_sdk.protocols import BaseProtocol


class Protocol(BaseProtocol):
"""
You should put a helpful description of this protocol's behavior here.
"""

# Name the event type you wish to run in response to
RESPONDS_TO = EventType.Name(EventType.UNKNOWN)

def compute(self) -> list[Effect]:
"""This method gets called when an event of the type RESPONDS_TO is fired."""
return [Effect(type=EffectType.LOG, payload=import_me())]
Loading
Loading