Skip to content

Commit

Permalink
feat: refactor ProtocolBaseHandler and add target_type to Events (#236)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamagalhaes authored Jan 2, 2025
1 parent c2ec60c commit d90dcd8
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ def compute(self) -> list[Effect]:
# `self.event` is the event object that caused this method to be
# called.
#
# `self.target` is an identifier for the object that is the subject of
# `self.event.target.id` is an identifier for the object that is the subject of
# the event. In this case, it would be the identifier of the assess
# command. If this was a patient create event, it would be the
# identifier of the patient. If this was a task update event, it would
# be the identifier of the task. Etc, etc.
# If the targeted model is already supported by the SDK,
# you can retrieve the instance using `self.event.target.instance`
#
# `self.context` is a python dictionary of additional data that was
# `self.event.context` is a python dictionary of additional data that was
# given with the event. The information given here depends on the
# event type.
#
Expand Down
12 changes: 6 additions & 6 deletions canvas_generated/messages/events_pb2.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions canvas_generated/messages/events_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1301,14 +1301,16 @@ PLUGIN_CREATED: EventType
PLUGIN_UPDATED: EventType

class Event(_message.Message):
__slots__ = ("type", "target", "context")
__slots__ = ("type", "target", "context", "target_type")
TYPE_FIELD_NUMBER: _ClassVar[int]
TARGET_FIELD_NUMBER: _ClassVar[int]
CONTEXT_FIELD_NUMBER: _ClassVar[int]
TARGET_TYPE_FIELD_NUMBER: _ClassVar[int]
type: EventType
target: str
context: str
def __init__(self, type: _Optional[_Union[EventType, str]] = ..., target: _Optional[str] = ..., context: _Optional[str] = ...) -> None: ...
target_type: str
def __init__(self, type: _Optional[_Union[EventType, str]] = ..., target: _Optional[str] = ..., context: _Optional[str] = ..., target_type: _Optional[str] = ...) -> None: ...

class EventResponse(_message.Message):
__slots__ = ("success", "effects")
Expand Down
7 changes: 5 additions & 2 deletions canvas_sdk/events/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from canvas_generated.messages.events_pb2 import Event, EventResponse, EventType
from canvas_generated.messages.events_pb2 import Event as EventRequest
from canvas_generated.messages.events_pb2 import EventResponse, EventType

__all__ = ("Event", "EventResponse", "EventType")
from .base import Event

__all__ = ("EventRequest", "EventResponse", "EventType", "Event")
50 changes: 50 additions & 0 deletions canvas_sdk/events/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import dataclasses
import json
from functools import cached_property
from typing import Any

from django.apps import apps
from django.db import models

from canvas_generated.messages.events_pb2 import Event as EventRequest
from canvas_generated.messages.events_pb2 import EventType


@dataclasses.dataclass
class TargetType:
"""The target of the event."""

id: str
type: type[models.Model] | None

@cached_property
def instance(self) -> models.Model | None:
"""Return the instance of the target."""
return self.type._default_manager.filter(id=self.id).first() if self.type else None


class Event:
"""An event that occurs in the Canvas environment."""

name: str
type: EventType
context: dict[str, Any]
target: TargetType

def __init__(self, event_request: EventRequest) -> None:
try:
target_model = apps.get_model(
app_label="canvas_sdk", model_name=event_request.target_type
)
except LookupError:
target_model = None

try:
context = json.loads(event_request.context)
except ValueError:
context = {}

self.type = event_request.type
self.name = EventType.Name(self.type)
self.context = context
self.target = TargetType(id=event_request.target, type=target_model)
46 changes: 31 additions & 15 deletions canvas_sdk/handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,45 @@
import json
from typing import TYPE_CHECKING, Any
import importlib.metadata
from typing import Any

if TYPE_CHECKING:
from canvas_generated.messages.events_pb2 import Event
import deprecation

from canvas_sdk.events import Event

version = importlib.metadata.version("canvas")


class BaseHandler:
"""
The class that all handlers inherit from.
"""
"""The class that all handlers inherit from."""

secrets: dict[str, Any]
target: str
event: Event

def __init__(
self,
event: "Event",
event: Event,
secrets: dict[str, Any] | None = None,
) -> None:
self.event = event
self.secrets = secrets or {}

try:
self.context = json.loads(event.context)
except ValueError:
self.context = {}
@property
@deprecation.deprecated(
deprecated_in="0.11.0",
removed_in="1.0.0",
current_version=version,
details="Use 'event.context' directly instead",
)
def context(self) -> dict[str, Any]:
"""The context of the event."""
return self.event.context

self.target = event.target
self.secrets = secrets or {}
@property
@deprecation.deprecated(
deprecated_in="0.11.0",
removed_in="1.0.0",
current_version=version,
details="Use 'event.target['id']' directly instead",
)
def target(self) -> str:
"""The target id of the event."""
return self.event.target.id
7 changes: 7 additions & 0 deletions canvas_sdk/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# ruff: noqa
# register all models in the app so they can be used with apps.get_model()
from canvas_sdk.v1.data.command import Command
from canvas_sdk.v1.data.condition import Condition
from canvas_sdk.v1.data.note import Note
from canvas_sdk.v1.data.patient import Patient
from canvas_sdk.v1.data.user import CanvasUser
56 changes: 50 additions & 6 deletions plugin_runner/plugin_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import grpc
import statsd

from canvas_generated.messages.effects_pb2 import EffectType
from canvas_generated.messages.plugins_pb2 import (
ReloadPluginsRequest,
ReloadPluginsResponse,
Expand All @@ -24,7 +25,7 @@
add_PluginRunnerServicer_to_server,
)
from canvas_sdk.effects import Effect
from canvas_sdk.events import Event, EventResponse, EventType
from canvas_sdk.events import Event, EventRequest, EventResponse, EventType
from canvas_sdk.protocols import ClinicalQualityMeasure
from canvas_sdk.utils.stats import get_duration_ms, tags_to_line_protocol
from logger import log
Expand Down Expand Up @@ -98,16 +99,17 @@ def __init__(self) -> None:
sandbox: Sandbox

async def HandleEvent(
self, request: Event, context: Any
self, request: EventRequest, context: Any
) -> AsyncGenerator[EventResponse, None]:
"""This is invoked when an event comes in."""
event_start_time = time.time()
event_type = request.type
event_name = EventType.Name(event_type)
event = Event(request)
event_type = event.type
event_name = event.name
relevant_plugins = EVENT_PROTOCOL_MAP[event_name]

if event_type in [EventType.PLUGIN_CREATED, EventType.PLUGIN_UPDATED]:
plugin_name = request.target
plugin_name = event.target.id
# filter only for the plugin(s) that were created/updated
relevant_plugins = [p for p in relevant_plugins if p.startswith(f"{plugin_name}:")]

Expand All @@ -122,7 +124,7 @@ async def HandleEvent(
secrets["graphql_jwt"] = token_for_plugin(plugin_name=plugin_name, audience="home")

try:
protocol = protocol_class(request, secrets)
protocol = protocol_class(event, secrets)
classname = (
protocol.__class__.__name__
if isinstance(protocol, ClinicalQualityMeasure)
Expand All @@ -140,6 +142,11 @@ async def HandleEvent(
)
for effect in _effects
]

effects = validate_effects(effects)

apply_effects_to_context(effects, event=event)

compute_duration = get_duration_ms(compute_start_time)

log.info(f"{plugin_name}.compute() completed ({compute_duration} ms)")
Expand Down Expand Up @@ -182,6 +189,43 @@ async def ReloadPlugins(
yield ReloadPluginsResponse(success=True)


def validate_effects(effects: list[Effect]) -> list[Effect]:
"""Validates the effects based on predefined rules.
Keeps only the first AUTOCOMPLETE_SEARCH_RESULTS effect and preserve all non-search-related effects.
"""
seen_autocomplete = False
validated_effects = []

for effect in effects:
if effect.type == EffectType.AUTOCOMPLETE_SEARCH_RESULTS:
if seen_autocomplete:
log.warning("Discarding additional AUTOCOMPLETE_SEARCH_RESULTS effect.")
continue
seen_autocomplete = True
validated_effects.append(effect)

return validated_effects


def apply_effects_to_context(effects: list[Effect], event: Event) -> Event:
"""Applies AUTOCOMPLETE_SEARCH_RESULTS effects to the event context.
If we are dealing with a search event, we need to update the context with the search results.
"""
event_name = event.name

# Skip if the event is not a search event
if not event_name.endswith("__PRE_SEARCH") and not event_name.endswith("__POST_SEARCH"):
return event

for effect in effects:
if effect.type == EffectType.AUTOCOMPLETE_SEARCH_RESULTS:
event.context["results"] = json.loads(effect.payload)
# Stop processing effects if we've found a AUTOCOMPLETE_SEARCH_RESULTS
break

return event


def handle_hup_cb(_signum: int, _frame: FrameType | None) -> None:
"""handle_hup_cb."""
log.info("Received SIGHUP, reloading plugins...")
Expand Down
6 changes: 3 additions & 3 deletions plugin_runner/tests/test_plugin_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from canvas_generated.messages.effects_pb2 import EffectType
from canvas_generated.messages.plugins_pb2 import ReloadPluginsRequest
from canvas_sdk.events import Event, EventType
from canvas_sdk.events import EventRequest, EventType
from plugin_runner.plugin_runner import (
EVENT_PROTOCOL_MAP,
LOADED_PLUGINS,
Expand Down Expand Up @@ -91,7 +91,7 @@ async def test_load_plugins_with_plugin_that_imports_other_modules_within_plugin

result = [
response
async for response in plugin_runner.HandleEvent(Event(type=EventType.UNKNOWN), None)
async for response in plugin_runner.HandleEvent(EventRequest(type=EventType.UNKNOWN), None)
]

assert len(result) == 1
Expand Down Expand Up @@ -160,7 +160,7 @@ async def test_handle_plugin_event_returns_expected_result(
"""Test that HandleEvent successfully calls the relevant plugins and returns the expected result."""
load_plugins()

event = Event(type=EventType.UNKNOWN)
event = EventRequest(type=EventType.UNKNOWN)

result = []
async for response in plugin_runner.HandleEvent(event, None):
Expand Down
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions protobufs/canvas_generated/messages/events.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,7 @@ message Event {
// }
string target = 2;
string context = 3;
string target_type = 4;
}

message EventResponse {
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ version = "0.10.2"
[tool.poetry.dependencies]
cookiecutter = "*"
cron-converter = "^1.2.1"
deprecation = "^2.1.0"
django = "^5.1.1"
django-stubs = {extras = ["compatible-mypy"], version = "^5.1.1"}
django-timezone-utils = "^0.15.0"
Expand Down

0 comments on commit d90dcd8

Please sign in to comment.