diff --git a/jupyter_server/extension/handler.py b/jupyter_server/extension/handler.py index 31377cc367..4285c415b0 100644 --- a/jupyter_server/extension/handler.py +++ b/jupyter_server/extension/handler.py @@ -5,6 +5,7 @@ from logging import Logger from typing import TYPE_CHECKING, Any, cast +from jinja2 import Template from jinja2.exceptions import TemplateNotFound from jupyter_server.base.handlers import FileFindHandler @@ -21,13 +22,14 @@ class ExtensionHandlerJinjaMixin: template rendering. """ - def get_template(self, name: str) -> str: + def get_template(self, name: str) -> Template: """Return the jinja template object for a given name""" try: env = f"{self.name}_jinja2_env" # type:ignore[attr-defined] - return cast(str, self.settings[env].get_template(name)) # type:ignore[attr-defined] + template = cast(Template, self.settings[env].get_template(name)) # type:ignore[attr-defined] + return template except TemplateNotFound: - return cast(str, super().get_template(name)) # type:ignore[misc] + return cast(Template, super().get_template(name)) # type:ignore[misc] class ExtensionHandlerMixin: @@ -81,6 +83,20 @@ def server_config(self) -> Config: def base_url(self) -> str: return cast(str, self.settings.get("base_url", "/")) + def render_template(self, name: str, **ns) -> str: + """Override render template to handle static_paths + + If render_template is called with a template from the base environment + (e.g. default error pages) + make sure our extension-specific static_url is _not_ used. + """ + template = cast(Template, self.get_template(name)) # type:ignore[attr-defined] + ns.update(self.template_namespace) # type:ignore[attr-defined] + if template.environment is self.settings["jinja2_env"]: + # default template environment, use default static_url + ns["static_url"] = super().static_url # type:ignore[misc] + return cast(str, template.render(**ns)) + @property def static_url_prefix(self) -> str: return self.extensionapp.static_url_prefix diff --git a/tests/extension/mockextensions/__init__.py b/tests/extension/mockextensions/__init__.py index a25d9fc670..2c01cfe266 100644 --- a/tests/extension/mockextensions/__init__.py +++ b/tests/extension/mockextensions/__init__.py @@ -2,7 +2,7 @@ to load in various tests. """ -from .app import MockExtensionApp +from .app import MockExtensionApp, MockExtensionNoTemplateApp # Function that makes these extensions discoverable @@ -13,6 +13,10 @@ def _jupyter_server_extension_points(): "module": "tests.extension.mockextensions.app", "app": MockExtensionApp, }, + { + "module": "tests.extension.mockextensions.app", + "app": MockExtensionNoTemplateApp, + }, {"module": "tests.extension.mockextensions.mock1"}, {"module": "tests.extension.mockextensions.mock2"}, {"module": "tests.extension.mockextensions.mock3"}, diff --git a/tests/extension/mockextensions/app.py b/tests/extension/mockextensions/app.py index c4f7af099c..5546195593 100644 --- a/tests/extension/mockextensions/app.py +++ b/tests/extension/mockextensions/app.py @@ -4,6 +4,7 @@ from jupyter_events import EventLogger from jupyter_events.schema_registry import SchemaRegistryException +from tornado import web from traitlets import List, Unicode from jupyter_server.base.handlers import JupyterHandler @@ -44,6 +45,11 @@ def get(self): self.write(self.render_template("index.html")) +class MockExtensionErrorHandler(ExtensionHandlerMixin, JupyterHandler): + def get(self): + raise web.HTTPError(418) + + class MockExtensionApp(ExtensionAppJinjaMixin, ExtensionApp): name = "mockextension" template_paths: List[str] = List().tag(config=True) # type:ignore[assignment] @@ -51,7 +57,12 @@ class MockExtensionApp(ExtensionAppJinjaMixin, ExtensionApp): mock_trait = Unicode("mock trait", config=True) loaded = False - serverapp_config = {"jpserver_extensions": {"tests.extension.mockextensions.mock1": True}} + serverapp_config = { + "jpserver_extensions": { + "tests.extension.mockextensions.mock1": True, + "tests.extension.mockextensions.app.mockextension_notemplate": True, + } + } @staticmethod def get_extension_package(): @@ -69,6 +80,20 @@ def initialize_settings(self): def initialize_handlers(self): self.handlers.append(("/mock", MockExtensionHandler)) self.handlers.append(("/mock_template", MockExtensionTemplateHandler)) + self.handlers.append(("/mock_error_template", MockExtensionErrorHandler)) + self.loaded = True + + +class MockExtensionNoTemplateApp(ExtensionApp): + name = "mockextension_notemplate" + loaded = False + + @staticmethod + def get_extension_package(): + return "tests.extension.mockextensions" + + def initialize_handlers(self): + self.handlers.append(("/mock_error_notemplate", MockExtensionErrorHandler)) self.loaded = True diff --git a/tests/extension/test_app.py b/tests/extension/test_app.py index 965fe2ca16..ae324756ec 100644 --- a/tests/extension/test_app.py +++ b/tests/extension/test_app.py @@ -171,12 +171,14 @@ async def _stop(*args): "Shutting down 2 extensions", "jupyter_server_terminals | extension app 'jupyter_server_terminals' stopping", f"{extension_name} | extension app 'mockextension' stopping", + f"{extension_name} | extension app 'mockextension_notemplate' stopping", "jupyter_server_terminals | extension app 'jupyter_server_terminals' stopped", f"{extension_name} | extension app 'mockextension' stopped", + f"{extension_name} | extension app 'mockextension_notemplate' stopped", } # check the shutdown method was called twice - assert calls == 2 + assert calls == 3 async def test_events(jp_serverapp, jp_fetch): diff --git a/tests/extension/test_handler.py b/tests/extension/test_handler.py index af4d0568ec..050f734eb4 100644 --- a/tests/extension/test_handler.py +++ b/tests/extension/test_handler.py @@ -1,4 +1,7 @@ +from html.parser import HTMLParser + import pytest +from tornado.httpclient import HTTPClientError @pytest.fixture @@ -118,3 +121,69 @@ async def test_base_url(jp_fetch, jp_server_config, jp_base_url): assert r.code == 200 body = r.body.decode() assert "mock static content" in body + + +class StylesheetFinder(HTMLParser): + """Minimal HTML parser to find iframe.src attr""" + + def __init__(self): + super().__init__() + self.stylesheets = [] + self.body_chunks = [] + self.in_head = False + self.in_body = False + self.in_script = False + + def handle_starttag(self, tag, attrs): + tag = tag.lower() + if tag == "head": + self.in_head = True + elif tag == "body": + self.in_body = True + elif tag == "script": + self.in_script = True + elif self.in_head and tag.lower() == "link": + attr_dict = dict(attrs) + if attr_dict.get("rel", "").lower() == "stylesheet": + self.stylesheets.append(attr_dict["href"]) + + def handle_endtag(self, tag): + if tag == "head": + self.in_head = False + if tag == "body": + self.in_body = False + if tag == "script": + self.in_script = False + + def handle_data(self, data): + if self.in_body and not self.in_script: + data = data.strip() + if data: + self.body_chunks.append(data) + + +def find_stylesheets_body(html): + """Find the href= attr of stylesheets + + and body text of an HTML document + + stylesheets are used to test static_url prefix + """ + finder = StylesheetFinder() + finder.feed(html) + return (finder.stylesheets, "\n".join(finder.body_chunks)) + + +@pytest.mark.parametrize("error_url", ["mock_error_template", "mock_error_notemplate"]) +async def test_error_render(jp_fetch, jp_serverapp, jp_base_url, error_url): + with pytest.raises(HTTPClientError) as e: + await jp_fetch(error_url, method="GET") + r = e.value.response + assert r.code == 418 + assert r.headers["Content-Type"] == "text/html" + html = r.body.decode("utf8") + stylesheets, body = find_stylesheets_body(html) + static_prefix = f"{jp_base_url}static/" + assert stylesheets + assert all(stylesheet.startswith(static_prefix) for stylesheet in stylesheets) + assert str(r.code) in body