Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-107862: Add property-based round-trip tests for base64 #119406

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Lib/test/support/_hypothesis_stubs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ def decorator(f):
@functools.wraps(f)
def test_function(self):
for example_args, example_kwargs in examples:
with self.subTest(*example_args, **example_kwargs):
if len(example_args) < 2:
subtest_args = example_args
else:
# subTest takes up to one positional argument.
# When there are more, display them as a tuple
subtest_args = [example_args]
with self.subTest(*subtest_args, **example_kwargs):
f(self, *example_args, **example_kwargs)

else:
Expand Down
154 changes: 154 additions & 0 deletions Lib/test/test_base64.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
import unittest
import base64
import binascii
import string
import os
from array import array
from test.support import os_helper
from test.support import script_helper

from test.support.hypothesis_helper import hypothesis


@hypothesis.strategies.composite
def altchars(draw):
"""Generate 'altchars' for base64 encoding.

Via https://docs.python.org/3/library/base64.html#base64.b64encode :

"Optional *altchars* must be a :term:`bytes-like object` of length 2 which
specifies an alternative alphabet for the ``+`` and ``/`` characters."
"""
reserved_chars = (string.digits + string.ascii_letters + "=").encode()
allowed_chars = hypothesis.strategies.sampled_from(
[n for n in range(256) if n not in reserved_chars])
return bytes(draw(hypothesis.strategies.lists(allowed_chars, min_size=2,
max_size=2, unique=True)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using @composite here, we can simply return a strategy from our function:

Suggested change
return bytes(draw(hypothesis.strategies.lists(allowed_chars, min_size=2,
max_size=2, unique=True)))
return hypothesis.strategies.lists(allowed_chars, min_size=2, max_size=2, unique=True))).map(bytes)

In more complicated scenarios, this can be a decent performance improvement (by amortizing the construction and validation over multiple draws, instead of executing the body each time we draw a value) - particularly if we @functools.cache the function, or happen to draw from it multiple times (e.g. st.lists(altchars())). Mostly though it's just an idiom to remind readers that strategies are just python values, not magic.


class LegacyBase64TestCase(unittest.TestCase):

Expand Down Expand Up @@ -60,6 +78,13 @@ def test_decodebytes(self):
eq(base64.decodebytes(array('B', b'YWJj\n')), b'abc')
self.check_type_errors(base64.decodebytes)

@hypothesis.given(payload=hypothesis.strategies.binary())
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz')
def test_bytes_encode_decode_round_trip(self, payload):
encoded = base64.encodebytes(payload)
decoded = base64.decodebytes(encoded)
self.assertEqual(payload, decoded)

def test_encode(self):
eq = self.assertEqual
from io import BytesIO, StringIO
Expand Down Expand Up @@ -88,6 +113,19 @@ def test_decode(self):
self.assertRaises(TypeError, base64.encode, BytesIO(b'YWJj\n'), StringIO())
self.assertRaises(TypeError, base64.encode, StringIO('YWJj\n'), StringIO())

@hypothesis.given(payload=hypothesis.strategies.binary())
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz')
def test_legacy_encode_decode_round_trip(self, payload):
from io import BytesIO
payload_file_r = BytesIO(payload)
encoded_file_w = BytesIO()
base64.encode(payload_file_r, encoded_file_w)
encoded_file_r = BytesIO(encoded_file_w.getvalue())
decoded_file_w = BytesIO()
base64.decode(encoded_file_r, decoded_file_w)
decoded = decoded_file_w.getvalue()
self.assertEqual(payload, decoded)


class BaseXYTestCase(unittest.TestCase):

Expand Down Expand Up @@ -268,6 +306,35 @@ def test_b64decode_invalid_chars(self):
self.assertEqual(base64.b64decode(b'++[[//]]', b'[]'), res)
self.assertEqual(base64.urlsafe_b64decode(b'++--//__'), res)

@hypothesis.given(
payload=hypothesis.strategies.binary(),
altchars=(
hypothesis.strategies.none()
| hypothesis.strategies.just(b"_-")
| altchars()),
validate=hypothesis.strategies.booleans())
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', b"_-", True)
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', b"_-", False)
def test_b64_encode_decode_round_trip(self, payload, altchars, validate):
encoded = base64.b64encode(payload, altchars=altchars)
decoded = base64.b64decode(encoded, altchars=altchars,
validate=validate)
self.assertEqual(payload, decoded)

@hypothesis.given(payload=hypothesis.strategies.binary())
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz')
def test_standard_b64_encode_decode_round_trip(self, payload):
encoded = base64.standard_b64encode(payload)
decoded = base64.standard_b64decode(encoded)
self.assertEqual(payload, decoded)

@hypothesis.given(payload=hypothesis.strategies.binary())
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz')
def test_urlsafe_b64_encode_decode_round_trip(self, payload):
encoded = base64.urlsafe_b64encode(payload)
decoded = base64.urlsafe_b64decode(encoded)
self.assertEqual(payload, decoded)

def test_b32encode(self):
eq = self.assertEqual
eq(base64.b32encode(b''), b'')
Expand Down Expand Up @@ -355,6 +422,19 @@ def test_b32decode_error(self):
with self.assertRaises(binascii.Error):
base64.b32decode(data.decode('ascii'))

@hypothesis.given(
payload=hypothesis.strategies.binary(),
casefold=hypothesis.strategies.booleans(),
map01=(
hypothesis.strategies.none()
| hypothesis.strategies.binary(min_size=1, max_size=1)))
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True, None)
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False, None)
def test_b32_encode_decode_round_trip(self, payload, casefold, map01):
encoded = base64.b32encode(payload)
decoded = base64.b32decode(encoded, casefold=casefold, map01=map01)
self.assertEqual(payload, decoded)

def test_b32hexencode(self):
test_cases = [
# to_encode, expected
Expand Down Expand Up @@ -424,6 +504,15 @@ def test_b32hexdecode_error(self):
with self.assertRaises(binascii.Error):
base64.b32hexdecode(data.decode('ascii'))

@hypothesis.given(
payload=hypothesis.strategies.binary(),
casefold=hypothesis.strategies.booleans())
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True)
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False)
def test_b32_hexencode_decode_round_trip(self, payload, casefold):
encoded = base64.b32hexencode(payload)
decoded = base64.b32hexdecode(encoded, casefold=casefold)
self.assertEqual(payload, decoded)

def test_b16encode(self):
eq = self.assertEqual
Expand Down Expand Up @@ -461,6 +550,16 @@ def test_b16decode(self):
# Incorrect "padding"
self.assertRaises(binascii.Error, base64.b16decode, '010')

@hypothesis.given(
payload=hypothesis.strategies.binary(),
casefold=hypothesis.strategies.booleans())
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True)
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False)
def test_b16_encode_decode_round_trip(self, payload, casefold):
endoded = base64.b16encode(payload)
decoded = base64.b16decode(endoded, casefold=casefold)
self.assertEqual(payload, decoded)

def test_a85encode(self):
eq = self.assertEqual

Expand Down Expand Up @@ -791,6 +890,61 @@ def test_z85decode_errors(self):
self.assertRaises(ValueError, base64.z85decode, b'%nSc')
self.assertRaises(ValueError, base64.z85decode, b'%nSc1')

def add_padding(self, payload):
"""Add the expected padding for test_?85_encode_decode_round_trip."""
if len(payload) % 4 != 0:
padding = b"\0" * ((-len(payload)) % 4)
payload = payload + padding
return payload

@hypothesis.given(
payload=hypothesis.strategies.binary(),
foldspaces=hypothesis.strategies.booleans(),
wrapcol=(
hypothesis.strategies.just(0)
| hypothesis.strategies.integers(1, 1000)),
pad=hypothesis.strategies.booleans(),
adobe=hypothesis.strategies.booleans(),
)
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False, 0, False, False)
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False, 20, True, True)
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True, 0, False, True)
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True, 20, True, False)
def test_a85_encode_decode_round_trip(
self, payload, foldspaces, wrapcol, pad, adobe
):
encoded = base64.a85encode(
payload, foldspaces=foldspaces, wrapcol=wrapcol,
pad=pad, adobe=adobe,
)
if wrapcol:
if adobe and wrapcol == 1:
# "adobe" needs wrapcol to be at least 2.
# a85decode quietly uses 2 if 1 is given; it's not worth
# loudly deprecating this behavior.
wrapcol = 2
for line in encoded.splitlines(keepends=False):
self.assertLessEqual(len(line), wrapcol)
if adobe:
self.assertTrue(encoded.startswith(b'<~'))
self.assertTrue(encoded.endswith(b'~>'))
decoded = base64.a85decode(encoded, foldspaces=foldspaces, adobe=adobe)
if pad:
payload = self.add_padding(payload)
self.assertEqual(payload, decoded)

@hypothesis.given(
payload=hypothesis.strategies.binary(),
pad=hypothesis.strategies.booleans())
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True)
@hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False)
def test_b85_encode_decode_round_trip(self, payload, pad):
encoded = base64.b85encode(payload, pad=pad)
if pad:
payload = self.add_padding(payload)
decoded = base64.b85decode(encoded)
self.assertEqual(payload, decoded)

def test_decode_nonascii_str(self):
decode_funcs = (base64.b64decode,
base64.standard_b64decode,
Expand Down
Loading