diff --git a/plugin_runner/plugin_runner.py b/plugin_runner/plugin_runner.py index 63776325..6b3ba99c 100644 --- a/plugin_runner/plugin_runner.py +++ b/plugin_runner/plugin_runner.py @@ -262,32 +262,15 @@ def find_modules(base_path: pathlib.Path, prefix: str | None = None) -> list[str return modules -def sandbox_from_package(package_path: pathlib.Path) -> dict[str, Any]: +def sandbox_from_module(base_path: pathlib.Path, module_name: str) -> Any: """Sandbox the code execution.""" - package_name = package_path.name - available_modules = find_modules(package_path) - sandboxes = {} - - for module_name in available_modules: - result = sandbox_from_module(package_path, module_name) - full_module_name = f"{package_name}.{module_name}" - sandboxes[full_module_name] = result - - return sandboxes - - -def sandbox_from_module(package_path: pathlib.Path, module_name: str) -> Any: - """Sandbox the code execution.""" - module_path = package_path / str(module_name.replace(".", "/") + ".py") + module_path = base_path / str(module_name.replace(".", "/") + ".py") if not module_path.exists(): raise ModuleNotFoundError(f'Could not load module "{module_name}"') - source_code = module_path.read_text() + sandbox = Sandbox(module_path, namespace=module_name) - full_module_name = f"{package_path.name}.{module_name}" - - sandbox = Sandbox(source_code, namespace=full_module_name) return sandbox.execute() @@ -321,7 +304,6 @@ def load_or_reload_plugin(path: pathlib.Path) -> None: handlers = manifest_json["components"].get("protocols", []) + manifest_json[ "components" ].get("applications", []) - results = sandbox_from_package(path) except Exception as e: log.error(f'Unable to load plugin "{name}": {str(e)}') return @@ -337,11 +319,11 @@ def load_or_reload_plugin(path: pathlib.Path) -> None: continue try: + result = sandbox_from_module(path.parent, handler_module) + if name_and_class in LOADED_PLUGINS: log.info(f"Reloading plugin '{name_and_class}'") - result = results[handler_module] - LOADED_PLUGINS[name_and_class]["active"] = True LOADED_PLUGINS[name_and_class]["class"] = result[handler_class] @@ -350,8 +332,6 @@ def load_or_reload_plugin(path: pathlib.Path) -> None: else: log.info(f"Loading plugin '{name_and_class}'") - result = results[handler_module] - LOADED_PLUGINS[name_and_class] = { "active": True, "class": result[handler_class], diff --git a/plugin_runner/sandbox.py b/plugin_runner/sandbox.py index c70b987c..6726e31d 100644 --- a/plugin_runner/sandbox.py +++ b/plugin_runner/sandbox.py @@ -1,7 +1,10 @@ import ast import builtins +import importlib +import sys from _ast import AnnAssign from functools import cached_property +from pathlib import Path from typing import Any, cast from RestrictedPython import ( @@ -88,6 +91,20 @@ def _apply(_ob: Any, *args: Any, **kwargs: Any) -> Any: return _ob(*args, **kwargs) +def _find_folder_in_path(file_path: Path, target_folder_name: str) -> Path | None: + """Recursively search for a folder with the specified name in the hierarchy of the given file path.""" + file_path = file_path.resolve() + + if file_path.name == target_folder_name: + return file_path + + # If we've reached the root of the file system, return None + if file_path.parent == file_path: + return None + + return _find_folder_in_path(file_path.parent, target_folder_name) + + class Sandbox: """A restricted sandbox for safely executing arbitrary Python code.""" @@ -199,16 +216,23 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST: # Impossible Case only ctx Load, Store and Del are defined in ast. raise NotImplementedError(f"Unknown ctx type: {type(node.ctx)}") - def __init__(self, source_code: str, namespace: str | None = None) -> None: + def __init__(self, source_code: str | Path, namespace: str | None = None) -> None: if source_code is None: raise TypeError("source_code may not be None") - self.namespace = namespace or "protocols" - self.source_code = source_code - @cached_property - def package_name(self) -> str | None: - """Return the root package name.""" - return self.namespace.split(".")[0] if self.namespace else None + self.namespace = namespace or "protocols" + self.package_name = self.namespace.split(".")[0] + + if isinstance(source_code, Path): + if not source_code.exists(): + raise FileNotFoundError(f"File not found: {source_code}") + self.source_code = source_code.read_text() + module_path = _find_folder_in_path(source_code, self.package_name) + self.base_path = module_path.parent if module_path else None + self._evaluated_modules: dict[str, bool] = {} + else: + self.source_code = source_code + self.base_path = None @cached_property def scope(self) -> dict[str, Any]: @@ -266,12 +290,27 @@ def warnings(self) -> tuple[str, ...]: def _is_known_module(self, name: str) -> bool: return bool( _is_known_module(name) - or (self.package_name and name.split(".")[0] == self.package_name) + or (self.package_name and name.split(".")[0] == self.package_name and self.base_path) ) + def _evaluate_module(self, name: str) -> None: + """Evaluate the given module in the sandbox. + If the module to import belongs to the same package as the current module, evaluate it inside a sandbox. + """ + if name.startswith(self.package_name) and name not in self._evaluated_modules: + code = Path(cast(Path, self.base_path) / f"{name.replace('.', '/')}.py").read_text() + Sandbox(code, namespace=name).execute() + self._evaluated_modules[name] = True + if sys.modules.get(name): + # if the module was already imported, reload it to make sure the latest version is used + importlib.reload(sys.modules[name]) + def _safe_import(self, name: str, *args: Any, **kwargs: Any) -> Any: if not (self._is_known_module(name)): raise ImportError(f"{name!r} is not an allowed import.") + + self._evaluate_module(name) + return __import__(name, *args, **kwargs) def execute(self) -> dict: diff --git a/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/CANVAS_MANIFEST.json b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/CANVAS_MANIFEST.json new file mode 100644 index 00000000..0c049dd7 --- /dev/null +++ b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/CANVAS_MANIFEST.json @@ -0,0 +1,29 @@ +{ + "sdk_version": "0.1.4", + "plugin_version": "0.0.1", + "name": "test_module_forbidden_imports_plugin", + "description": "Edit the description in CANVAS_MANIFEST.json", + "components": { + "protocols": [ + { + "class": "test_module_forbidden_imports_plugin.protocols.my_protocol:Protocol", + "description": "A protocol that does xyz...", + "data_access": { + "event": "", + "read": [], + "write": [] + } + } + ], + "commands": [], + "content": [], + "effects": [], + "views": [] + }, + "secrets": [], + "tags": {}, + "references": [], + "license": "", + "diagram": false, + "readme": "./README.md" +} diff --git a/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/README.md b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/README.md new file mode 100644 index 00000000..0e61f047 --- /dev/null +++ b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/README.md @@ -0,0 +1,12 @@ +========================== +test_module_forbidden_imports_plugin +========================== + +## Description + +A description of this plugin + +### Important Note! + +The CANVAS_MANIFEST.json is used when installing your plugin. Please ensure it +gets updated if you add, remove, or rename protocols. diff --git a/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/other_module/__init__.py b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/other_module/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/other_module/base.py b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/other_module/base.py new file mode 100644 index 00000000..f7b97d1d --- /dev/null +++ b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/other_module/base.py @@ -0,0 +1,10 @@ +import os + +from logger import log + +log.info(f"This is a forbidden import. {os}") + + +def import_me() -> str: + """Test method.""" + return "Successfully imported!" diff --git a/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/protocols/__init__.py b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/protocols/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/protocols/my_protocol.py b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/protocols/my_protocol.py new file mode 100644 index 00000000..27951a8b --- /dev/null +++ b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_plugin/protocols/my_protocol.py @@ -0,0 +1,18 @@ +from test_module_forbidden_imports_plugin.other_module.base import import_me + +from canvas_sdk.effects import Effect, EffectType +from canvas_sdk.events import EventType +from canvas_sdk.protocols import BaseProtocol + + +class Protocol(BaseProtocol): + """ + You should put a helpful description of this protocol's behavior here. + """ + + # Name the event type you wish to run in response to + RESPONDS_TO = EventType.Name(EventType.UNKNOWN) + + def compute(self) -> list[Effect]: + """This method gets called when an event of the type RESPONDS_TO is fired.""" + return [Effect(type=EffectType.LOG, payload=import_me())] diff --git a/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/CANVAS_MANIFEST.json b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/CANVAS_MANIFEST.json new file mode 100644 index 00000000..8ba4ab99 --- /dev/null +++ b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/CANVAS_MANIFEST.json @@ -0,0 +1,29 @@ +{ + "sdk_version": "0.1.4", + "plugin_version": "0.0.1", + "name": "test_module_forbidden_imports_runtime_plugin", + "description": "Edit the description in CANVAS_MANIFEST.json", + "components": { + "protocols": [ + { + "class": "test_module_forbidden_imports_runtime_plugin.protocols.my_protocol:Protocol", + "description": "A protocol that does xyz...", + "data_access": { + "event": "", + "read": [], + "write": [] + } + } + ], + "commands": [], + "content": [], + "effects": [], + "views": [] + }, + "secrets": [], + "tags": {}, + "references": [], + "license": "", + "diagram": false, + "readme": "./README.md" +} diff --git a/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/README.md b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/README.md new file mode 100644 index 00000000..c446cd3f --- /dev/null +++ b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/README.md @@ -0,0 +1,12 @@ +========================== +test_module_forbidden_imports_runtime_plugin +========================== + +## Description + +A description of this plugin + +### Important Note! + +The CANVAS_MANIFEST.json is used when installing your plugin. Please ensure it +gets updated if you add, remove, or rename protocols. diff --git a/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/other_module/__init__.py b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/other_module/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/other_module/base.py b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/other_module/base.py new file mode 100644 index 00000000..f7b97d1d --- /dev/null +++ b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/other_module/base.py @@ -0,0 +1,10 @@ +import os + +from logger import log + +log.info(f"This is a forbidden import. {os}") + + +def import_me() -> str: + """Test method.""" + return "Successfully imported!" diff --git a/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/protocols/__init__.py b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/protocols/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/protocols/my_protocol.py b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/protocols/my_protocol.py new file mode 100644 index 00000000..a2f8d36b --- /dev/null +++ b/plugin_runner/tests/fixtures/plugins/test_module_forbidden_imports_runtime_plugin/protocols/my_protocol.py @@ -0,0 +1,18 @@ +from canvas_sdk.effects import Effect, EffectType +from canvas_sdk.events import EventType +from canvas_sdk.protocols import BaseProtocol + + +class Protocol(BaseProtocol): + """ + You should put a helpful description of this protocol's behavior here. + """ + + # Name the event type you wish to run in response to + RESPONDS_TO = EventType.Name(EventType.UNKNOWN) + + def compute(self) -> list[Effect]: + """This method gets called when an event of the type RESPONDS_TO is fired.""" + from test_module_forbidden_imports_runtime_plugin.other_module.base import import_me + + return [Effect(type=EffectType.LOG, payload=import_me())] diff --git a/plugin_runner/tests/test_plugin_runner.py b/plugin_runner/tests/test_plugin_runner.py index 66164394..16f2aff2 100644 --- a/plugin_runner/tests/test_plugin_runner.py +++ b/plugin_runner/tests/test_plugin_runner.py @@ -1,3 +1,4 @@ +import logging import shutil from collections.abc import Generator from pathlib import Path @@ -7,13 +8,13 @@ from canvas_generated.messages.effects_pb2 import EffectType from canvas_generated.messages.plugins_pb2 import ReloadPluginsRequest -from canvas_sdk.events import EventRequest, EventType +from canvas_sdk.events import Event, EventRequest, EventType from plugin_runner.plugin_runner import ( EVENT_HANDLER_MAP, LOADED_PLUGINS, PluginRunner, + load_or_reload_plugin, load_plugins, - sandbox_from_package, ) @@ -111,11 +112,53 @@ async def test_load_plugins_with_plugin_that_imports_other_modules_within_plugin indirect=True, ) def test_load_plugins_with_plugin_that_imports_other_modules_outside_plugin_package( - setup_test_plugin: Path, + setup_test_plugin: Path, caplog: pytest.LogCaptureFixture ) -> None: """Test loading plugins with an invalid plugin that imports other modules outside the current plugin package.""" - with pytest.raises(ImportError, match="is not an allowed import"): - sandbox_from_package(setup_test_plugin) + with caplog.at_level(logging.ERROR): + load_or_reload_plugin(setup_test_plugin) + + assert any( + "Error importing module" in record.message for record in caplog.records + ), "log.error() was not called with the expected message." + + +@pytest.mark.parametrize( + "setup_test_plugin", + [ + "test_module_forbidden_imports_plugin", + ], + indirect=True, +) +def test_load_plugins_with_plugin_that_imports_forbidden_modules( + setup_test_plugin: Path, caplog: pytest.LogCaptureFixture +) -> None: + """Test loading plugins with an invalid plugin that imports forbidden modules.""" + with caplog.at_level(logging.ERROR): + load_or_reload_plugin(setup_test_plugin) + + assert any( + "Error importing module" in record.message for record in caplog.records + ), "log.error() was not called with the expected message." + + +@pytest.mark.parametrize( + "setup_test_plugin", + [ + "test_module_forbidden_imports_runtime_plugin", + ], + indirect=True, +) +def test_load_plugins_with_plugin_that_imports_forbidden_modules_at_runtime( + setup_test_plugin: Path, +) -> None: + """Test loading plugins with an invalid plugin that imports forbidden modules at runtime.""" + with pytest.raises(ImportError, match="is not an allowed import."): + load_or_reload_plugin(setup_test_plugin) + class_handler = LOADED_PLUGINS[ + "test_module_forbidden_imports_runtime_plugin:test_module_forbidden_imports_runtime_plugin.protocols.my_protocol:Protocol" + ]["class"] + class_handler(Event(EventRequest(type=EventType.UNKNOWN))).compute() @pytest.mark.parametrize("setup_test_plugin", ["example_plugin"], indirect=True) @@ -205,3 +248,44 @@ async def test_reload_plugins_import_error(plugin_runner: PluginRunner) -> None: assert len(responses) == 1 assert responses[0].success is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize("setup_test_plugin", ["test_module_imports_plugin"], indirect=True) +async def test_changes_to_plugin_modules_should_be_reflected_after_reload( + setup_test_plugin: Path, plugin_runner: PluginRunner +) -> None: + """Test that changes to plugin modules are reflected after reloading the plugin.""" + load_plugins() + + event = EventRequest(type=EventType.UNKNOWN) + + result = [] + async for response in plugin_runner.HandleEvent(event, None): + result.append(response) + + assert len(result) == 1 + assert result[0].success is True + assert len(result[0].effects) == 1 + assert result[0].effects[0].type == EffectType.LOG + assert result[0].effects[0].payload == "Successfully imported!" + + NEW_CODE = """ +def import_me() -> str: + return "Successfully changed!" +""" + file_path = setup_test_plugin / "other_module" / "base.py" + file_path.write_text(NEW_CODE, encoding="utf-8") + + # Reload the plugin + load_plugins() + + result = [] + async for response in plugin_runner.HandleEvent(event, None): + result.append(response) + + assert len(result) == 1 + assert result[0].success is True + assert len(result[0].effects) == 1 + assert result[0].effects[0].type == EffectType.LOG + assert result[0].effects[0].payload == "Successfully changed!" diff --git a/plugin_runner/tests/test_sandbox.py b/plugin_runner/tests/test_sandbox.py index c4a2ce30..9ad5031c 100644 --- a/plugin_runner/tests/test_sandbox.py +++ b/plugin_runner/tests/test_sandbox.py @@ -23,9 +23,9 @@ builtins = {} """ -SOURCE_CODE_MODULE_OS = """ -import os -result = os.listdir('.') +SOURCE_CODE_MODULE = """ +import module.b +result = module.b """ @@ -98,16 +98,8 @@ def test_print_collector() -> None: assert "Hello, Sandbox!" in scope["_print"].txt, "Print output should be captured." -def test_sandbox_module_name_imports_within_package() -> None: - """Test that modules within the same package can be imported.""" - sandbox_module_a = Sandbox(source_code=SOURCE_CODE_MODULE_OS, namespace="os.a") - result = sandbox_module_a.execute() - - assert "os" in result - - def test_sandbox_denies_module_name_import_outside_package() -> None: """Test that modules outside the root package cannot be imported.""" - sandbox_module_a = Sandbox(source_code=SOURCE_CODE_MODULE_OS, namespace="module.a") - with pytest.raises(ImportError, match="os' is not an allowed import."): + sandbox_module_a = Sandbox(source_code=SOURCE_CODE_MODULE, namespace="other_module.a") + with pytest.raises(ImportError, match="module.b' is not an allowed import."): sandbox_module_a.execute()