Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds GraphQL data client #81

Merged
merged 20 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions canvas_sdk/data/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Any, cast

from gql import Client, gql
from gql.transport.aiohttp import AIOHTTPTransport

from settings import GRAPHQL_ENDPOINT


class _CanvasGQLClient:
"""
This is a GraphQL client that can be used to query home-app in order to fetch data for use in plugins.

Usage Examples:

A query with no parameters:

TEST_QUERY_NO_PARAMS = '''
query PatientsAll {
patients {
edges {
node {
firstName
lastName
birthDate
}
}
}
}
'''

client = _CanvasGQLClient()
result = client.query(TEST_QUERY_NO_PARAMS)
print(result) # returns dictionary

A query with parameters:

TEST_QUERY_WITH_PARAMS = '''
query PatientGet($patientKey: String!) {
patient(patientKey: $patientKey) {
firstName
lastName
birthDate
}
}
'''

client = _CanvasGQLClient()
result = client.query(TEST_QUERY_NO_PARAMS)
print(result)

For use in plugins, it is included in the instantiation of Protocol class. This means
it can simply be referred to as self.client in plugin code.
"""

def __init__(self) -> None:
self.client = Client(
transport=AIOHTTPTransport(url=cast(str, GRAPHQL_ENDPOINT)),
fetch_schema_from_transport=True,
)

def query(
self,
gql_query: str,
variables: dict[str, Any] | None = None,
extra_args: dict[str, Any] | None = None
) -> dict[str, Any]:
if variables is None:
query_variables = {}
else:
query_variables = variables

return self.client.execute(
gql(gql_query),
variable_values=query_variables,
extra_args=extra_args,
)


GQL_CLIENT = _CanvasGQLClient()
beaugunderson marked this conversation as resolved.
Show resolved Hide resolved
15 changes: 14 additions & 1 deletion canvas_sdk/handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
import json
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from canvas_generated.messages.events_pb2 import Event


class BaseHandler:
"""
The class that all handlers inherit from.
"""

def __init__(self, event, secrets=None) -> None:
secrets: dict[str, Any]
target: str

def __init__(
self,
event: "Event",
secrets: dict[str, Any] | None = None,
) -> None:
self.event = event

try:
self.context = json.loads(event.context)
except ValueError:
self.context = {}

self.target = event.target
self.secrets = secrets or {}
10 changes: 10 additions & 0 deletions canvas_sdk/protocols/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from typing import Any

from canvas_sdk.data.client import GQL_CLIENT
from canvas_sdk.handlers.base import BaseHandler


class BaseProtocol(BaseHandler):
"""
The class that protocols inherit from.
"""

def run_gql_query(self, query: str, variables: dict | None = None) -> dict[str, Any]:
beaugunderson marked this conversation as resolved.
Show resolved Hide resolved
return GQL_CLIENT.query(query, variables=variables, extra_args={
'headers': {
'Authorization': f'Bearer {self.secrets['graphql_jwt']}',
},
})
2 changes: 2 additions & 0 deletions env.example
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ INTEGRATION_TEST_CLIENT_ID=
INTEGRATION_TEST_CLIENT_SECRET=

PLUGIN_RUNNER_DAL_TARGET=localhost:50052

GRAPHQL_ENDPOINT=http://localhost:8000/plugins/internal-graphql
37 changes: 37 additions & 0 deletions plugin_runner/authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
from typing import cast

import arrow
from jwt import encode

ONE_DAY_IN_MINUTES = 60 * 24


def token_for_plugin(
plugin_name: str,
audience: str,
issuer: str = "plugin-runner",
jwt_signing_key: str = cast(str, os.getenv('PLUGIN_RUNNER_SIGNING_KEY')),
expiration_minutes: int = ONE_DAY_IN_MINUTES,
extra_kwargs: dict | None = None,
) -> str:
"""
Generate a JWT for the given plugin and audience.
"""
if not extra_kwargs:
extra_kwargs = {}

token = encode(
{
"plugin_name": plugin_name,
"customer_identifier": os.getenv('CUSTOMER_IDENTIFIER'),
"exp": arrow.utcnow().shift(minutes=expiration_minutes).datetime,
"aud": audience,
"iss": issuer,
**extra_kwargs,
},
jwt_signing_key,
algorithm="HS512",
)

return token
11 changes: 10 additions & 1 deletion plugin_runner/plugin_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import grpc
import statsd
from authentication import token_for_plugin
from plugin_synchronizer import publish_message
from sandbox import Sandbox

Expand Down Expand Up @@ -74,15 +75,20 @@ async def HandleEvent(self, request: Event, context: Any) -> EventResponse:
protocol_class = plugin["class"]
base_plugin_name = plugin_name.split(":")[0]

secrets = plugin.get('secrets', {})
secrets['graphql_jwt'] = token_for_plugin(plugin_name=plugin_name, audience='home')

try:
protocol = protocol_class(request, plugin.get("secrets", {}))
protocol = protocol_class(request, secrets)

compute_start_time = time.time()
_effects = await asyncio.get_running_loop().run_in_executor(None, protocol.compute)
effects = [
Effect(type=effect.type, payload=effect.payload, plugin_name=base_plugin_name)
for effect in _effects
]
compute_duration = get_duration_ms(compute_start_time)

log.info(f"{plugin_name}.compute() completed ({compute_duration} ms)")
statsd_tags = tags_to_line_protocol({"plugin": plugin_name})
self.statsd_client.timing(
Expand All @@ -92,9 +98,11 @@ async def HandleEvent(self, request: Event, context: Any) -> EventResponse:
except Exception as e:
log.error(traceback.format_exception(e))
continue

effect_list += effects

event_duration = get_duration_ms(event_start_time)

# Don't log anything if a protocol didn't actually run.
if relevant_plugins:
log.info(f"Responded to Event {event_name} ({event_duration} ms)")
Expand All @@ -103,6 +111,7 @@ async def HandleEvent(self, request: Event, context: Any) -> EventResponse:
f"plugins.event_duration_ms,{statsd_tags}",
delta=event_duration,
)

yield EventResponse(success=True, effects=effect_list)

async def ReloadPlugins(
Expand Down
Loading
Loading