Skip to content

Commit

Permalink
extensions: render default templates with default static_url (#1435)
Browse files Browse the repository at this point in the history
  • Loading branch information
minrk authored Dec 20, 2024
1 parent e74da85 commit 2195971
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 6 deletions.
22 changes: 19 additions & 3 deletions jupyter_server/extension/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from logging import Logger
from typing import TYPE_CHECKING, Any, cast

from jinja2 import Template
from jinja2.exceptions import TemplateNotFound

from jupyter_server.base.handlers import FileFindHandler
Expand All @@ -21,13 +22,14 @@ class ExtensionHandlerJinjaMixin:
template rendering.
"""

def get_template(self, name: str) -> str:
def get_template(self, name: str) -> Template:
"""Return the jinja template object for a given name"""
try:
env = f"{self.name}_jinja2_env" # type:ignore[attr-defined]
return cast(str, self.settings[env].get_template(name)) # type:ignore[attr-defined]
template = cast(Template, self.settings[env].get_template(name)) # type:ignore[attr-defined]
return template
except TemplateNotFound:
return cast(str, super().get_template(name)) # type:ignore[misc]
return cast(Template, super().get_template(name)) # type:ignore[misc]


class ExtensionHandlerMixin:
Expand Down Expand Up @@ -81,6 +83,20 @@ def server_config(self) -> Config:
def base_url(self) -> str:
return cast(str, self.settings.get("base_url", "/"))

def render_template(self, name: str, **ns) -> str:
"""Override render template to handle static_paths
If render_template is called with a template from the base environment
(e.g. default error pages)
make sure our extension-specific static_url is _not_ used.
"""
template = cast(Template, self.get_template(name)) # type:ignore[attr-defined]
ns.update(self.template_namespace) # type:ignore[attr-defined]
if template.environment is self.settings["jinja2_env"]:
# default template environment, use default static_url
ns["static_url"] = super().static_url # type:ignore[misc]
return cast(str, template.render(**ns))

@property
def static_url_prefix(self) -> str:
return self.extensionapp.static_url_prefix
Expand Down
6 changes: 5 additions & 1 deletion tests/extension/mockextensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
to load in various tests.
"""

from .app import MockExtensionApp
from .app import MockExtensionApp, MockExtensionNoTemplateApp


# Function that makes these extensions discoverable
Expand All @@ -13,6 +13,10 @@ def _jupyter_server_extension_points():
"module": "tests.extension.mockextensions.app",
"app": MockExtensionApp,
},
{
"module": "tests.extension.mockextensions.app",
"app": MockExtensionNoTemplateApp,
},
{"module": "tests.extension.mockextensions.mock1"},
{"module": "tests.extension.mockextensions.mock2"},
{"module": "tests.extension.mockextensions.mock3"},
Expand Down
27 changes: 26 additions & 1 deletion tests/extension/mockextensions/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from jupyter_events import EventLogger
from jupyter_events.schema_registry import SchemaRegistryException
from tornado import web
from traitlets import List, Unicode

from jupyter_server.base.handlers import JupyterHandler
Expand Down Expand Up @@ -44,14 +45,24 @@ def get(self):
self.write(self.render_template("index.html"))


class MockExtensionErrorHandler(ExtensionHandlerMixin, JupyterHandler):
def get(self):
raise web.HTTPError(418)


class MockExtensionApp(ExtensionAppJinjaMixin, ExtensionApp):
name = "mockextension"
template_paths: List[str] = List().tag(config=True) # type:ignore[assignment]
static_paths = [STATIC_PATH] # type:ignore[assignment]
mock_trait = Unicode("mock trait", config=True)
loaded = False

serverapp_config = {"jpserver_extensions": {"tests.extension.mockextensions.mock1": True}}
serverapp_config = {
"jpserver_extensions": {
"tests.extension.mockextensions.mock1": True,
"tests.extension.mockextensions.app.mockextension_notemplate": True,
}
}

@staticmethod
def get_extension_package():
Expand All @@ -69,6 +80,20 @@ def initialize_settings(self):
def initialize_handlers(self):
self.handlers.append(("/mock", MockExtensionHandler))
self.handlers.append(("/mock_template", MockExtensionTemplateHandler))
self.handlers.append(("/mock_error_template", MockExtensionErrorHandler))
self.loaded = True


class MockExtensionNoTemplateApp(ExtensionApp):
name = "mockextension_notemplate"
loaded = False

@staticmethod
def get_extension_package():
return "tests.extension.mockextensions"

def initialize_handlers(self):
self.handlers.append(("/mock_error_notemplate", MockExtensionErrorHandler))
self.loaded = True


Expand Down
4 changes: 3 additions & 1 deletion tests/extension/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,14 @@ async def _stop(*args):
"Shutting down 2 extensions",
"jupyter_server_terminals | extension app 'jupyter_server_terminals' stopping",
f"{extension_name} | extension app 'mockextension' stopping",
f"{extension_name} | extension app 'mockextension_notemplate' stopping",
"jupyter_server_terminals | extension app 'jupyter_server_terminals' stopped",
f"{extension_name} | extension app 'mockextension' stopped",
f"{extension_name} | extension app 'mockextension_notemplate' stopped",
}

# check the shutdown method was called twice
assert calls == 2
assert calls == 3


async def test_events(jp_serverapp, jp_fetch):
Expand Down
69 changes: 69 additions & 0 deletions tests/extension/test_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from html.parser import HTMLParser

import pytest
from tornado.httpclient import HTTPClientError


@pytest.fixture
Expand Down Expand Up @@ -118,3 +121,69 @@ async def test_base_url(jp_fetch, jp_server_config, jp_base_url):
assert r.code == 200
body = r.body.decode()
assert "mock static content" in body


class StylesheetFinder(HTMLParser):
"""Minimal HTML parser to find iframe.src attr"""

def __init__(self):
super().__init__()
self.stylesheets = []
self.body_chunks = []
self.in_head = False
self.in_body = False
self.in_script = False

def handle_starttag(self, tag, attrs):
tag = tag.lower()
if tag == "head":
self.in_head = True
elif tag == "body":
self.in_body = True
elif tag == "script":
self.in_script = True
elif self.in_head and tag.lower() == "link":
attr_dict = dict(attrs)
if attr_dict.get("rel", "").lower() == "stylesheet":
self.stylesheets.append(attr_dict["href"])

def handle_endtag(self, tag):
if tag == "head":
self.in_head = False
if tag == "body":
self.in_body = False
if tag == "script":
self.in_script = False

def handle_data(self, data):
if self.in_body and not self.in_script:
data = data.strip()
if data:
self.body_chunks.append(data)


def find_stylesheets_body(html):
"""Find the href= attr of stylesheets
and body text of an HTML document
stylesheets are used to test static_url prefix
"""
finder = StylesheetFinder()
finder.feed(html)
return (finder.stylesheets, "\n".join(finder.body_chunks))


@pytest.mark.parametrize("error_url", ["mock_error_template", "mock_error_notemplate"])
async def test_error_render(jp_fetch, jp_serverapp, jp_base_url, error_url):
with pytest.raises(HTTPClientError) as e:
await jp_fetch(error_url, method="GET")
r = e.value.response
assert r.code == 418
assert r.headers["Content-Type"] == "text/html"
html = r.body.decode("utf8")
stylesheets, body = find_stylesheets_body(html)
static_prefix = f"{jp_base_url}static/"
assert stylesheets
assert all(stylesheet.startswith(static_prefix) for stylesheet in stylesheets)
assert str(r.code) in body

0 comments on commit 2195971

Please sign in to comment.