Skip to content

Commit

Permalink
feat(plugins): add support for importing other modules within a plugin (
Browse files Browse the repository at this point in the history
  • Loading branch information
jamagalhaes authored Nov 14, 2024
1 parent 2ea51da commit ac077fe
Show file tree
Hide file tree
Showing 40 changed files with 802 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,15 @@

# Inherit from BaseProtocol to properly get registered for events
class Protocol(BaseProtocol):
"""
You should put a helpful description of this protocol's behavior here.
"""
"""You should put a helpful description of this protocol's behavior here."""

# Name the event type you wish to run in response to
RESPONDS_TO = EventType.Name(EventType.ASSESS_COMMAND__CONDITION_SELECTED)

NARRATIVE_STRING = "I was inserted from my plugin's protocol."

def compute(self) -> list[Effect]:
"""
This method gets called when an event of the type RESPONDS_TO is fired.
"""
"""This method gets called when an event of the type RESPONDS_TO is fired."""
# This class is initialized with several pieces of information you can
# access.
#
Expand Down
10 changes: 3 additions & 7 deletions plugin_runner/authentication.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
import os
from typing import cast

import arrow
from jwt import encode

from logger import log
from settings import PLUGIN_RUNNER_SIGNING_KEY

ONE_DAY_IN_MINUTES = 60 * 24

INSECURE_DEFAULT_SIGNING_KEY = "INSECURE_KEY"


def token_for_plugin(
plugin_name: str,
audience: str,
issuer: str = "plugin-runner",
jwt_signing_key: str = cast(
str, os.getenv("PLUGIN_RUNNER_SIGNING_KEY", INSECURE_DEFAULT_SIGNING_KEY)
),
jwt_signing_key: str = PLUGIN_RUNNER_SIGNING_KEY,
expiration_minutes: int = ONE_DAY_IN_MINUTES,
extra_kwargs: dict | None = None,
) -> str:
Expand All @@ -27,7 +23,7 @@ def token_for_plugin(
if not extra_kwargs:
extra_kwargs = {}

if jwt_signing_key == INSECURE_DEFAULT_SIGNING_KEY:
if not jwt_signing_key:
log.warning(
"Using an insecure JWT signing key for GraphQL access. Set the PLUGIN_RUNNER_SIGNING_KEY environment variable to avoid this message."
)
Expand Down
75 changes: 49 additions & 26 deletions plugin_runner/plugin_runner.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import asyncio
import importlib.util
import json
import os
import pathlib
import pkgutil
import signal
import sys
import time
import traceback
from collections import defaultdict
from types import FrameType
from typing import Any, AsyncGenerator, Optional, TypedDict, cast
from typing import Any, AsyncGenerator, Optional, TypedDict

import grpc
import statsd
Expand All @@ -30,17 +30,7 @@
from plugin_runner.authentication import token_for_plugin
from plugin_runner.plugin_synchronizer import publish_message
from plugin_runner.sandbox import Sandbox

ENV = os.getenv("ENV", "development")

IS_PRODUCTION = ENV == "production"

MANIFEST_FILE_NAME = "CANVAS_MANIFEST.json"

SECRETS_FILE_NAME = "SECRETS.json"

# specify a local plugin directory for development
PLUGIN_DIRECTORY = "/plugin-runner/custom-plugins" if IS_PRODUCTION else "./custom-plugins"
from settings import MANIFEST_FILE_NAME, PLUGIN_DIRECTORY, SECRETS_FILE_NAME

# when we import plugins we'll use the module name directly so we need to add the plugin
# directory to the path
Expand All @@ -51,7 +41,7 @@
LOADED_PLUGINS: dict = {}

# a global dictionary of events to protocol class names
EVENT_PROTOCOL_MAP: dict = {}
EVENT_PROTOCOL_MAP: dict[str, list] = defaultdict(list)


class DataAccess(TypedDict):
Expand Down Expand Up @@ -113,7 +103,7 @@ async def HandleEvent(
event_start_time = time.time()
event_type = request.type
event_name = EventType.Name(event_type)
relevant_plugins = EVENT_PROTOCOL_MAP.get(event_name, [])
relevant_plugins = EVENT_PROTOCOL_MAP[event_name]

if event_type in [EventType.PLUGIN_CREATED, EventType.PLUGIN_UPDATED]:
plugin_name = request.target
Expand Down Expand Up @@ -197,18 +187,50 @@ def handle_hup_cb(_signum: int, _frame: Optional[FrameType]) -> None:
load_plugins()


def sandbox_from_module_name(module_name: str) -> Any:
def find_modules(base_path: pathlib.Path, prefix: str | None = None) -> list[str]:
"""Find all modules in the specified package path."""
modules: list[str] = []

for file_finder, module_name, is_pkg in pkgutil.iter_modules(
[base_path.as_posix()],
):
if is_pkg:
modules = modules + find_modules(
base_path / module_name,
prefix=f"{prefix}.{module_name}" if prefix else module_name,
)
else:
modules.append(f"{prefix}.{module_name}" if prefix else module_name)

return modules


def sandbox_from_package(package_path: pathlib.Path) -> dict[str, Any]:
"""Sandbox the code execution."""
package_name = package_path.name
available_modules = find_modules(package_path)
sandboxes = {}

for module_name in available_modules:
result = sandbox_from_module(package_path, module_name)
full_module_name = f"{package_name}.{module_name}"
sandboxes[full_module_name] = result

return sandboxes


def sandbox_from_module(package_path: pathlib.Path, module_name: str) -> Any:
"""Sandbox the code execution."""
spec = importlib.util.find_spec(module_name)
module_path = package_path / str(module_name.replace(".", "/") + ".py")

if not spec or not spec.origin:
raise Exception(f'Could not load plugin "{module_name}"')
if not module_path.exists():
raise ModuleNotFoundError(f'Could not load module "{module_name}"')

origin = pathlib.Path(spec.origin)
source_code = origin.read_text()
source_code = module_path.read_text()

sandbox = Sandbox(source_code)
full_module_name = f"{package_path.name}.{module_name}"

sandbox = Sandbox(source_code, module_name=full_module_name)
return sandbox.execute()


Expand Down Expand Up @@ -244,6 +266,8 @@ def load_or_reload_plugin(path: pathlib.Path) -> None:
log.error(f'Unable to load plugin "{name}": {str(e)}')
return

results = sandbox_from_package(path)

for protocol in protocols:
# TODO add class colon validation to existing schema validation
# TODO when we encounter an exception here, disable the plugin in response
Expand All @@ -258,7 +282,7 @@ def load_or_reload_plugin(path: pathlib.Path) -> None:
if name_and_class in LOADED_PLUGINS:
log.info(f"Reloading plugin '{name_and_class}'")

result = sandbox_from_module_name(protocol_module)
result = results[protocol_module]

LOADED_PLUGINS[name_and_class]["active"] = True

Expand All @@ -268,7 +292,7 @@ def load_or_reload_plugin(path: pathlib.Path) -> None:
else:
log.info(f"Loading plugin '{name_and_class}'")

result = sandbox_from_module_name(protocol_module)
result = results[protocol_module]

LOADED_PLUGINS[name_and_class] = {
"active": True,
Expand All @@ -285,8 +309,7 @@ def load_or_reload_plugin(path: pathlib.Path) -> None:

def refresh_event_type_map() -> None:
"""Ensure the event subscriptions are up to date."""
global EVENT_PROTOCOL_MAP
EVENT_PROTOCOL_MAP = defaultdict(list)
EVENT_PROTOCOL_MAP.clear()

for name, plugin in LOADED_PLUGINS.items():
if hasattr(plugin["class"], "RESPONDS_TO"):
Expand Down
30 changes: 22 additions & 8 deletions plugin_runner/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,6 @@ def _is_known_module(name: str) -> bool:
return any(name.startswith(m) for m in ALLOWED_MODULES)


def _safe_import(name: str, *args: Any, **kwargs: Any) -> Any:
if not _is_known_module(name):
raise ImportError(f"{name!r} is not an allowed import.")
return __import__(name, *args, **kwargs)


def _unrestricted(_ob: Any, *args: Any, **kwargs: Any) -> Any:
"""Return the given object, unmodified."""
return _ob
Expand All @@ -96,6 +90,7 @@ class Sandbox:

source_code: str
namespace: str
module_name: str | None

class Transformer(RestrictingNodeTransformer):
"""A node transformer for customizing the sandbox compiler."""
Expand Down Expand Up @@ -204,20 +199,28 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
# Impossible Case only ctx Load, Store and Del are defined in ast.
raise NotImplementedError(f"Unknown ctx type: {type(node.ctx)}")

def __init__(self, source_code: str, namespace: str | None = None) -> None:
def __init__(
self, source_code: str, namespace: str | None = None, module_name: str | None = None
) -> None:
if source_code is None:
raise TypeError("source_code may not be None")
self.module_name = module_name
self.namespace = namespace or "protocols"
self.source_code = source_code

@cached_property
def package_name(self) -> str | None:
"""Return the root package name."""
return self.module_name.split(".")[0] if self.module_name else None

@cached_property
def scope(self) -> dict[str, Any]:
"""Return the scope used for evaluation."""
return {
"__builtins__": {
**safe_builtins.copy(),
**utility_builtins.copy(),
"__import__": _safe_import,
"__import__": self._safe_import,
"classmethod": builtins.classmethod,
"staticmethod": builtins.staticmethod,
"any": builtins.any,
Expand Down Expand Up @@ -263,6 +266,17 @@ def warnings(self) -> tuple[str, ...]:
"""Return warnings encountered when compiling the source code."""
return cast(tuple[str, ...], self.compile_result.warnings)

def _is_known_module(self, name: str) -> bool:
return bool(
_is_known_module(name)
or (self.package_name and name.split(".")[0] == self.package_name)
)

def _safe_import(self, name: str, *args: Any, **kwargs: Any) -> Any:
if not (self._is_known_module(name)):
raise ImportError(f"{name!r} is not an allowed import.")
return __import__(name, *args, **kwargs)

def execute(self) -> dict:
"""Execute the given code in a restricted sandbox."""
if self.errors:
Expand Down
Empty file added plugin_runner/tests/__init__.py
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"sdk_version": "0.1.4",
"plugin_version": "0.0.1",
"name": "example_plugin",
"description": "Edit the description in CANVAS_MANIFEST.json",
"components": {
"protocols": [
{
"class": "example_plugin.protocols.my_protocol:Protocol",
"description": "A protocol that does xyz...",
"data_access": {
"event": "",
"read": [],
"write": []
}
}
],
"commands": [],
"content": [],
"effects": [],
"views": []
},
"secrets": [],
"tags": {},
"references": [],
"license": "",
"diagram": false,
"readme": "./README.md"
}
12 changes: 12 additions & 0 deletions plugin_runner/tests/fixtures/plugins/example_plugin/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
==============
example_plugin
==============

## Description

A description of this plugin

### Important Note!

The CANVAS_MANIFEST.json is used when installing your plugin. Please ensure it
gets updated if you add, remove, or rename protocols.
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from canvas_sdk.effects import Effect, EffectType
from canvas_sdk.events import EventType
from canvas_sdk.protocols import BaseProtocol


class Protocol(BaseProtocol):
"""
You should put a helpful description of this protocol's behavior here.
"""

# Name the event type you wish to run in response to
RESPONDS_TO = EventType.Name(EventType.UNKNOWN)

NARRATIVE_STRING = "I was inserted from my plugin's protocol."

def compute(self) -> list[Effect]:
"""This method gets called when an event of the type RESPONDS_TO is fired."""
return [Effect(type=EffectType.LOG, payload="Hello, world!")]
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"sdk_version": "0.1.4",
"plugin_version": "0.0.1",
"name": "test_module_imports_outside_plugin_v1",
"description": "Edit the description in CANVAS_MANIFEST.json",
"components": {
"protocols": [
{
"class": "test_module_imports_outside_plugin_v1.protocols.my_protocol:Protocol",
"description": "A protocol that does xyz...",
"data_access": {
"event": "",
"read": [],
"write": []
}
}
],
"commands": [],
"content": [],
"effects": [],
"views": []
},
"secrets": [],
"tags": {},
"references": [],
"license": "",
"diagram": false,
"readme": "./README.md"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
==========================
test_module_imports_outside_plugin_v1
==========================

## Description

A description of this plugin

### Important Note!

The CANVAS_MANIFEST.json is used when installing your plugin. Please ensure it
gets updated if you add, remove, or rename protocols.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def import_me() -> str:
"""Test method."""
return "Successfully imported!"
Loading

0 comments on commit ac077fe

Please sign in to comment.