Skip to content

Commit

Permalink
Embed default html templates (#87)
Browse files Browse the repository at this point in the history
This embeds test.html and module_test.html into pytest-pyodide wheel. So we don't need to bundle them into pyodide distribution.
  • Loading branch information
ryanking13 authored May 10, 2023
1 parent 54de3f9 commit 8c4a468
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 39 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,11 @@ jobs:
- name: Install pytest-pyodide
shell: bash -l {0}
run: |
${{needs.get_versions.outputs.pythonexec}} -m pip install pytest-cov
if [ -d "pytest_pyodide" ]; then
# Currently we only install the package for dependencies.
# We then uninstall it otherwise tests fails due to pytest hook being
# registered twice.
${{needs.get_versions.outputs.pythonexec}} -m pip install -e .
${{needs.get_versions.outputs.pythonexec}} -m pip install -e ".[test]"
${{needs.get_versions.outputs.pythonexec}} -m pip uninstall -y pytest-pyodide
else
${{needs.get_versions.outputs.pythonexec}} -m pip install pytest-pyodide
Expand Down
Empty file.
26 changes: 26 additions & 0 deletions pytest_pyodide/_templates/module_test.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
<!-- Bootstrap HTML for running the unit tests. -->
<!doctype html>
<html>
<head>
<script type="text/javascript">
window.logs = [];
console.log = function (message) {
window.logs.push(message);
};
console.warn = function (message) {
window.logs.push(message);
};
console.info = function (message) {
window.logs.push(message);
};
console.error = function (message) {
window.logs.push(message);
};
</script>
<script type="module">
import { loadPyodide } from "./pyodide.mjs";
window.loadPyodide = loadPyodide;
</script>
</head>
<body></body>
</html>
27 changes: 27 additions & 0 deletions pytest_pyodide/_templates/test.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
<!-- Bootstrap HTML for running the unit tests. -->
<!doctype html>
<html>
<head>
<title>pyodide</title>
<script type="text/javascript">
window.logs = [];
console.log = function (message) {
window.logs.push(message);
};
console.warn = function (message) {
window.logs.push(message);
};
console.info = function (message) {
window.logs.push(message);
};
console.error = function (message) {
window.logs.push(message);
};
console.debug = function (message) {
window.logs.push(message);
};
</script>
<script src="./pyodide.js"></script>
</head>
<body></body>
</html>
88 changes: 64 additions & 24 deletions pytest_pyodide/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import functools
import http.server
import multiprocessing
import os
Expand All @@ -8,6 +9,68 @@
import socketserver
import sys
import tempfile
from io import BytesIO


@functools.cache
def _default_templates() -> dict[str, bytes]:
templates_dir = pathlib.Path(__file__).parent / "_templates"

templates = {}
for template_file in templates_dir.glob("*.html"):
templates[f"/{template_file.name}"] = template_file.read_bytes()

return templates


class DefaultHandler(http.server.SimpleHTTPRequestHandler):
default_templates = _default_templates()

def __init__(self, *args, **kwargs):
self.extra_headers = kwargs.pop("extra_headers", {})
super().__init__(*args, **kwargs)

def log_message(self, format_, *args):
print(
"[%s] source: %s:%s - %s"
% (
self.log_date_time_string(),
*self.client_address,
format_ % args,
)
)

def get_template(self, path: str) -> bytes | None:
"""
Return the content of the template if it exists, None otherwise
This method is used to serve the default templates, and can be
overridden to serve custom templates.
"""
return self.default_templates.get(path)

def do_GET(self):
body = self.get_template(self.path)
if body:
self.send_response(200)
self.send_header("Content-type", "text/html; charset=utf-8")
self.send_header("Content-Length", str(len(body)))
self.end_headers()

self.copyfile(BytesIO(body), self.wfile)
else:
return super().do_GET()

def end_headers(self):
# Enable Cross-Origin Resource Sharing (CORS)
self.send_header("Access-Control-Allow-Origin", "*")
for k, v in self.extra_headers.items():
self.send_header(k, v)
if len(self.extra_headers) > 0:
joined_headers = ",".join(self.extra_headers.keys())
# if you don't send this, CORS blocks custom headers in javascript
self.send_header("Access-Control-Expose-Headers", joined_headers)
super().end_headers()


@contextlib.contextmanager
Expand Down Expand Up @@ -55,30 +118,7 @@ def run_web_server(q, log_filepath, dist_dir, extra_headers, handler_cls):
sys.stderr = log_fh

if not handler_cls:

class DefaultHandler(http.server.SimpleHTTPRequestHandler):
def log_message(self, format_, *args):
print(
"[%s] source: %s:%s - %s"
% (
self.log_date_time_string(),
*self.client_address,
format_ % args,
)
)

def end_headers(self):
# Enable Cross-Origin Resource Sharing (CORS)
self.send_header("Access-Control-Allow-Origin", "*")
for k, v in extra_headers.items():
self.send_header(k, v)
if len(extra_headers) > 0:
joined_headers = ",".join(extra_headers.keys())
# if you don't send this, CORS blocks custom headers in javascript
self.send_header("Access-Control-Expose-Headers", joined_headers)
super().end_headers()

handler_cls = DefaultHandler
handler_cls = functools.partial(DefaultHandler, extra_headers=extra_headers)

with socketserver.TCPServer(("", 0), handler_cls) as httpd:
host, port = httpd.server_address
Expand Down
15 changes: 10 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,18 @@ install_requires =
playwright
pyodide-tblib # Forked to add https://github.com/ionelmc/python-tblib/pull/66

[options.extras_require]
test =
pytest-cov
build
requests

# This is required to add node driver code to the package.
[options.package_data]
pytest_pyodide = *.js
pytest_pyodide =
*.js
pytest_pyodide._templates =
*.html

# pytest will look up `pytest11` entrypoints to find plugins
# See: https://docs.pytest.org/en/7.1.x/how-to/writing_plugins.html#making-your-plugin-installable-by-others
Expand All @@ -40,13 +49,9 @@ pytest11 =
pytest_pyodide = pytest_pyodide.fixture
pytest_pyodide_hook = pytest_pyodide.hook

[options.packages.find]
where = .

[tool:pytest]
asyncio_mode = strict
addopts =
--tb=short
--doctest-modules
--cov=pytest_pyodide --cov-report xml
testpaths = tests
57 changes: 49 additions & 8 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import http.server
import urllib.request
from http import HTTPStatus

from pytest_pyodide.server import spawn_web_server
import requests

from pytest_pyodide.server import DefaultHandler, _default_templates, spawn_web_server


def test_spawn_web_server_with_params(tmp_path):
Expand All @@ -13,14 +14,54 @@ def test_spawn_web_server_with_params(tmp_path):
port,
log_path,
):
res = urllib.request.urlopen(f"http://{hostname}:{port}/index.txt")
assert res.status == 200
res = requests.get(f"http://{hostname}:{port}/index.txt")
assert res.ok
assert res.headers
assert res.read() == b"a"
assert res.content == b"a"
assert res.headers["Access-Control-Allow-Origin"] == "*"
assert res.headers.get("Custom-Header") == "42"


def test_spawn_web_server_default_templates(tmp_path):
default_templates = _default_templates()

with spawn_web_server(tmp_path) as (hostname, port, _):
for path, content in default_templates.items():
res = requests.get(f"http://{hostname}:{port}{path}")
assert res.ok
assert res.headers
assert res.content == content
assert res.headers["Access-Control-Allow-Origin"] == "*"


class CustomTemplateHandler(DefaultHandler):
def get_template(self, path: str) -> bytes | None:
if path == "/index.txt":
return b"hello world"

return super().get_template(path)


def test_spawn_web_server_custom_templates(tmp_path):
default_templates = _default_templates()

with spawn_web_server(tmp_path, handler_cls=CustomTemplateHandler) as (
hostname,
port,
_,
):
for path, content in default_templates.items():
res = requests.get(f"http://{hostname}:{port}{path}")
assert res.ok
assert res.headers
assert res.content == content
assert res.headers["Access-Control-Allow-Origin"] == "*"

res = requests.get(f"http://{hostname}:{port}/index.txt")
assert res.ok
assert res.content == b"hello world"


class HelloWorldHandler(http.server.SimpleHTTPRequestHandler):
def do_GET(self):
self.send_response(HTTPStatus.OK)
Expand All @@ -31,6 +72,6 @@ def do_GET(self):
def test_custom_handler(tmp_path):
with spawn_web_server(tmp_path, handler_cls=HelloWorldHandler) as server:
hostname, port, _ = server
res = urllib.request.urlopen(f"http://{hostname}:{port}/index.txt")
assert res.status == 200
assert res.read() == b"hello world"
res = requests.get(f"http://{hostname}:{port}/index.txt")
assert res.ok
assert res.content == b"hello world"

0 comments on commit 8c4a468

Please sign in to comment.