Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sapphire] Update tests #403

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions sapphire/conftest.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
"""
Sapphire unit test fixtures
"""
import hashlib
import logging
import random
import re
import socket
import sys
import threading
from hashlib import sha1
from http.client import BadStatusLine
from logging import getLogger
from random import shuffle
from re import match
from sys import exc_info
from threading import Event, Thread
from urllib.error import HTTPError, URLError
from urllib.request import urlopen

import pytest

LOG = logging.getLogger(__name__)
LOG = getLogger(__name__)


@pytest.fixture
Expand All @@ -32,9 +32,9 @@ def __init__(self, rx_size=0x10000):
self.rx_size = rx_size
# use this event to add delays instead of sleep
# this will help avoid shutdown hangs when there are test failures
self._closed = threading.Event()
self._closed = Event()
self._closed.set()
self._idle = threading.Event()
self._idle = Event()
self._idle.set()

def close(self):
Expand All @@ -60,7 +60,7 @@ def launch(
assert self.thread is None
self._closed.clear()
self._idle.clear()
self.thread = threading.Thread(
self.thread = Thread(
target=self._handle_request,
args=(addr, port, files_to_serve),
kwargs={
Expand Down Expand Up @@ -90,16 +90,16 @@ def _handle_request(
indexes = list(range(len(files_to_request)))
if not in_order:
# request files in random order
random.shuffle(indexes)
shuffle(indexes)
for index in indexes:
t_file = files_to_request[index]
with t_file.lock:
# check if the file has been served
if skip_served and t_file.code is not None:
continue
# if t_file.md5_org is set to anything but None the test client
# will calculate the md5 hash
data_hash = hashlib.md5() if t_file.md5_org is not None else None
# if t_file.hash_org is set to anything but None the test client
# will calculate the hash
data_hash = sha1() if t_file.hash_org is not None else None
try:
if t_file.custom_request is None:
with urlopen(
Expand Down Expand Up @@ -144,7 +144,7 @@ def _handle_request(
data_length = len(data)
try:
resp_code = int(
re.match(
match(
r"HTTP/1\.\d\s(?P<code>\d+)\s", data.decode("ascii")
).group("code")
)
Expand All @@ -162,15 +162,15 @@ def _handle_request(
if resp_code == 200:
t_file.content_type = content_type
t_file.len_srv = data_length
t_file.md5_srv = data_hash
t_file.hash_srv = data_hash

except HTTPError as http_err:
with t_file.lock:
t_file.requested += 1
if not skip_served or t_file.code is None:
t_file.code = http_err.code
except (BadStatusLine, OSError, URLError):
exc_type, exc_obj, exc_tb = sys.exc_info()
exc_type, exc_obj, exc_tb = exc_info()
# set code to zero to help testing
with t_file.lock:
LOG.debug(
Expand Down
30 changes: 15 additions & 15 deletions sapphire/test_sapphire.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
# pylint: disable=protected-access

import socket
from hashlib import md5
from hashlib import sha1
from itertools import repeat
from os import urandom
from pathlib import Path
from platform import system
from random import choices, getrandbits
from threading import Lock
from urllib.parse import quote, urlparse
from urllib.parse import quote, urlsplit

from pytest import mark, raises

Expand All @@ -35,10 +35,10 @@ def __init__(self, url, url_prefix=None):
self.len_org = 0 # original file length
self.len_srv = 0 # served file length
self.lock = Lock()
self.md5_org = None
self.md5_srv = None
self.hash_org = None
self.hash_srv = None
self.requested = 0 # number of time file was requested
url = urlparse(self.file.replace("\\", "/"))
url = urlsplit(self.file.replace("\\", "/"))
self.url = (
"?".join((quote(url.path), url.query)) if url.query else quote(url.path)
)
Expand All @@ -51,7 +51,7 @@ def _create_test(fname, path, data=b"Test!", calc_hash=False, url_prefix=None):
test.len_org = out_fp.tell()
if calc_hash:
out_fp.seek(0)
test.md5_org = md5(out_fp.read()).hexdigest()
test.hash_org = sha1(out_fp.read()).hexdigest()
return test


Expand Down Expand Up @@ -184,7 +184,7 @@ def test_sapphire_06(client, tmp_path):
test["file"] = _TestFile(test["name"])
t_data = "".join(choices("ABCD1234", k=test["size"])).encode("ascii")
(tmp_path / test["file"].file).write_bytes(t_data)
test["file"].md5_org = md5(t_data).hexdigest()
test["file"].hash_org = sha1(t_data).hexdigest()
required = [test["file"].file for test in tests]
with Sapphire(timeout=10) as serv:
client.launch("127.0.0.1", serv.port, [test["file"] for test in tests])
Expand All @@ -195,27 +195,27 @@ def test_sapphire_06(client, tmp_path):
for test in tests:
assert test["file"].code == 200
assert test["file"].len_srv == test["size"]
assert test["file"].md5_srv == test["file"].md5_org
assert test["file"].hash_srv == test["file"].hash_org


def test_sapphire_07(client, tmp_path):
"""test serving a large (100MB) file"""
t_file = _TestFile("test_case.html")
data_hash = md5()
data_hash = sha1()
with (tmp_path / t_file.file).open("wb") as test_fp:
# write 100MB of 'A'
data = b"A" * (100 * 1024) # 100KB of 'A'
for _ in range(1024):
test_fp.write(data)
data_hash.update(data)
t_file.md5_org = data_hash.hexdigest()
t_file.hash_org = data_hash.hexdigest()
with Sapphire(timeout=10) as serv:
client.launch("127.0.0.1", serv.port, [t_file])
assert serv.serve_path(tmp_path, required_files=[t_file.file])[0] == Served.ALL
assert client.wait(timeout=10)
assert t_file.code == 200
assert t_file.len_srv == (100 * 1024 * 1024)
assert t_file.md5_srv == t_file.md5_org
assert t_file.hash_srv == t_file.hash_org


def test_sapphire_08(client, tmp_path):
Expand All @@ -227,7 +227,7 @@ def test_sapphire_08(client, tmp_path):
assert client.wait(timeout=10)
assert t_file.code == 200
assert t_file.len_srv == t_file.len_org
assert t_file.md5_srv == t_file.md5_org
assert t_file.hash_srv == t_file.hash_org


def test_sapphire_09():
Expand Down Expand Up @@ -431,7 +431,7 @@ def dr_callback(data):
# create files
test_dr = _TestFile(request)
test_dr.len_org = len(_data)
test_dr.md5_org = md5(_data).hexdigest()
test_dr.hash_org = sha1(_data).hexdigest()
test = _create_test("test_case.html", tmp_path)
if required:
req_files = []
Expand All @@ -452,7 +452,7 @@ def dr_callback(data):
assert test.len_srv == test.len_org
assert test_dr.code == 200
assert test_dr.len_srv == test_dr.len_org
assert test_dr.md5_srv == test_dr.md5_org
assert test_dr.hash_srv == test_dr.hash_org


def test_sapphire_16(client_factory, tmp_path):
Expand Down Expand Up @@ -608,7 +608,7 @@ def test_sapphire_22(client, tmp_path):
assert client.wait(timeout=10)
assert t_file.code == 200
assert t_file.len_srv == t_file.len_org
assert t_file.md5_srv == t_file.md5_org
assert t_file.hash_srv == t_file.hash_org


def test_sapphire_23(client, tmp_path):
Expand Down
36 changes: 18 additions & 18 deletions sapphire/test_server_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import pytest
from pytest import raises

from .server_map import InvalidURLError, MapCollisionError, Resource, ServerMap

Expand All @@ -28,23 +28,23 @@ def test_servermap_02(tmp_path):
assert len(srv_map.dynamic) == 2
assert not srv_map.include
assert not srv_map.redirect
with pytest.raises(TypeError, match="callback must be callable"):
with raises(TypeError, match="callback must be callable"):
srv_map.set_dynamic_response("x", None)
with pytest.raises(TypeError, match="callback requires 1 argument"):
with raises(TypeError, match="callback requires 1 argument"):
srv_map.set_dynamic_response("x", lambda: 0)
with pytest.raises(TypeError, match="mime_type must be of type 'str'"):
with raises(TypeError, match="mime_type must be of type 'str'"):
srv_map.set_dynamic_response("x", lambda _: 0, None)
# test detecting collisions
with pytest.raises(MapCollisionError):
with raises(MapCollisionError):
srv_map.set_include("url_01", str(tmp_path))
with pytest.raises(MapCollisionError):
with raises(MapCollisionError):
srv_map.set_redirect("url_01", "test_file")


def test_servermap_03(tmp_path):
"""test ServerMap includes"""
srv_map = ServerMap()
with pytest.raises(IOError, match="Include path not found: no_dir"):
with raises(IOError, match="Include path not found: no_dir"):
srv_map.set_include("test_url", "no_dir")
assert not srv_map.include
srv_map.set_include("url_01", str(tmp_path))
Expand All @@ -64,16 +64,16 @@ def test_servermap_03(tmp_path):
assert not srv_map.dynamic
assert not srv_map.redirect
# test detecting collisions
with pytest.raises(MapCollisionError, match="URL collision on 'url_01'"):
with raises(MapCollisionError, match="URL collision on 'url_01'"):
srv_map.set_redirect("url_01", "test_file")
with pytest.raises(MapCollisionError):
with raises(MapCollisionError):
srv_map.set_dynamic_response("url_01", lambda _: 0, mime_type="test/type")
# test overlapping includes
with pytest.raises(MapCollisionError, match=r"'url_01' and '\w+' include"):
with raises(MapCollisionError, match=r"'url_01' and '\w+' include"):
srv_map.set_include("url_01", str(tmp_path))
inc3 = tmp_path / "includes" / "b" / "c"
inc3.mkdir()
with pytest.raises(MapCollisionError, match=r"'url_01' and '\w+' include"):
with raises(MapCollisionError, match=r"'url_01' and '\w+' include"):
srv_map.set_include("url_01", str(inc3))


Expand All @@ -90,14 +90,14 @@ def test_servermap_04(tmp_path):
assert not srv_map.redirect["url_02"].required
assert not srv_map.dynamic
assert not srv_map.include
with pytest.raises(TypeError, match="target must not be an empty string"):
with raises(TypeError, match="target must not be an empty string"):
srv_map.set_redirect("x", "")
with pytest.raises(TypeError, match="target must be of type 'str'"):
with raises(TypeError, match="target must be of type 'str'"):
srv_map.set_redirect("x", None)
# test detecting collisions
with pytest.raises(MapCollisionError):
with raises(MapCollisionError):
srv_map.set_include("url_01", str(tmp_path))
with pytest.raises(MapCollisionError):
with raises(MapCollisionError):
srv_map.set_dynamic_response("url_01", lambda _: 0, mime_type="test/type")


Expand All @@ -107,11 +107,11 @@ def test_servermap_05():
assert ServerMap._check_url("test") == "test"
assert ServerMap._check_url("") == ""
# only alphanumeric is allowed
with pytest.raises(InvalidURLError):
with raises(InvalidURLError):
ServerMap._check_url("asd!@#")
# '..' should not be accepted
with pytest.raises(InvalidURLError):
with raises(InvalidURLError):
ServerMap._check_url("/..")
# cannot map more than one '/' deep
with pytest.raises(InvalidURLError):
with raises(InvalidURLError):
ServerMap._check_url("/test/test")
Loading