diff --git a/crypto_condor/primitives/AES.py b/crypto_condor/primitives/AES.py index 2d1a75b..022a04c 100644 --- a/crypto_condor/primitives/AES.py +++ b/crypto_condor/primitives/AES.py @@ -1912,13 +1912,15 @@ def verify_file(filename: str, mode: Mode, operation: Operation) -> Results: # --------------------------- Lib hook functions -------------------------------------- + + def _test_lib_enc( ffi: cffi.FFI, lib, function: str, mode: Mode, key_length: KeyLength -) -> list[Results]: +) -> ResultsDict: """Tests CC_AES_encrypt. Returns: - A list of results, from the ResultsDict returned by :func:`test`. + The dictionary of results returned by :func:`test`. """ logger.info("Testing harness function %s", function) @@ -1936,17 +1938,16 @@ def _enc(key: bytes, plaintext: bytes, iv: bytes = b"") -> bytes: enc(buf, len(plaintext), _key, len(key), _iv, len(iv)) return bytes(buf) - rd = test(_enc, None, mode, key_length) # type: ignore - return list(rd.values()) + return test(_enc, None, mode, key_length) # type: ignore def _test_lib_dec( ffi: cffi.FFI, lib, function: str, mode: Mode, key_length: KeyLength -) -> list[Results]: +) -> ResultsDict: """Tests CC_AES_decrypt. Returns: - A list of results, from the ResultsDict returned by :func:`test`. + The dictionary of results returned by :func:`test`. """ logger.info("Testing harness function %s", function) @@ -1964,17 +1965,16 @@ def _dec(key: bytes, ciphertext: bytes, iv: bytes = b"") -> bytes: dec(buf, len(ciphertext), _key, len(key), _iv, len(iv)) return bytes(buf) - rd = test(None, _dec, mode, key_length) # type: ignore - return list(rd.values()) + return test(None, _dec, mode, key_length) # type: ignore def _test_lib_enc_aead( ffi: cffi.FFI, lib, function: str, mode: Mode, key_length: KeyLength -) -> list[Results]: +) -> ResultsDict: """Tests CC_AES_AEAD_encrypt. Returns: - A list of results, from the ResultsDict returned by :func:`test`. + The dictionary of results returned by :func:`test`. """ logger.info("Testing harness function %s", function) @@ -2017,17 +2017,16 @@ def _enc( ) return (bytes(buf), bytes(mac_buf)) - rd = test(_enc, None, mode, key_length) # type: ignore[arg-type] - return list(rd.values()) + return test(_enc, None, mode, key_length) # type: ignore[arg-type] def _test_lib_dec_aead( ffi: cffi.FFI, lib, function: str, mode: Mode, key_length: KeyLength -) -> list[Results]: +) -> ResultsDict: """Tests CC_AES_AEAD_decrypt. Returns: - A list of results, from the ResultsDict returned by :func:`test`. + The dictionary of results returned by :func:`test`. """ logger.info("Testing harness function %s", function) @@ -2075,8 +2074,7 @@ def _dec( else: raise ValueError(f"Invalid returned value {rc} (expected 0 or -1)") - rd = test(None, _dec, mode, key_length) # type: ignore[arg-type] - return list(rd.values()) + return test(None, _dec, mode, key_length) # type: ignore[arg-type] def test_lib(ffi: cffi.FFI, lib, functions: list[str]) -> ResultsDict: @@ -2145,21 +2143,13 @@ def test_lib(ffi: cffi.FFI, lib, functions: list[str]) -> ResultsDict: # If the condition is false, it continues searching for a pattern. match (mode, operation): case (mode, Operation.ENCRYPT) if mode in Mode.classic_modes(): - results[f"AES/test_lib_enc/{str(mode)}"] = _test_lib_enc( - ffi, lib, function, mode, key_size - ) + results |= _test_lib_enc(ffi, lib, function, mode, key_size) case (mode, Operation.ENCRYPT): - results[f"AES/test_lib_enc_aead/{str(mode)}"] = _test_lib_enc_aead( - ffi, lib, function, mode, key_size - ) + results |= _test_lib_enc_aead(ffi, lib, function, mode, key_size) case (mode, Operation.DECRYPT) if mode in Mode.classic_modes(): - results[f"AES/test_lib_dec/{str(mode)}"] = _test_lib_dec( - ffi, lib, function, mode, key_size - ) + results |= _test_lib_dec(ffi, lib, function, mode, key_size) case (mode, Operation.DECRYPT): - results[f"AES/test_lib_dec_aead/{str(mode)}"] = _test_lib_dec_aead( - ffi, lib, function, mode, key_size - ) + results |= _test_lib_dec_aead(ffi, lib, function, mode, key_size) return results diff --git a/crypto_condor/primitives/RSASSA.py b/crypto_condor/primitives/RSASSA.py index c0d5bc0..ae62cae 100644 --- a/crypto_condor/primitives/RSASSA.py +++ b/crypto_condor/primitives/RSASSA.py @@ -433,7 +433,7 @@ def _test_verify_pss_wycheproof( "Tests a function that signs with RSASSA-PSS.", {"hash_algorithm": hash_algorithm, "vectors file": filename}, ) - results_dict["Wycheproof/verify/{filename}"] = results + results_dict[f"Wycheproof/verify/{filename}"] = results logger.debug("Using vectors from: %s" % filename) # Add Wycheproof notes to results. results.add_notes(vectors_file.get("notes", {})) diff --git a/crypto_condor/primitives/common.py b/crypto_condor/primitives/common.py index cd0a8de..479838e 100644 --- a/crypto_condor/primitives/common.py +++ b/crypto_condor/primitives/common.py @@ -577,21 +577,49 @@ def check(self, *, empty_as_fail: bool = False) -> bool: return True +# Key duplication checks from: https://stackoverflow.com/a/30242574 class ResultsDict(dict): """A dictionary of Results. This class extends the built-in dictionary to group :class:`Results` as values. The - keys are defined by the calling function. Currently key uniqueness is not enforced, - the caller is responsible for not overwriting previous results. See :meth:`add` for - a suggestion. + uniqueness of the keys is enforced: updating the dictionary or trying to set a key + already present will raise ValueError. To facilitate the creation of unique keys, + see the :meth:`add` method. - It provides the :meth:`check` method to check for failed results over all of its - values. It also defines a string representation with :meth:`__str__`, similar to - that of :class:`Results`. + It provides the :meth:`check` method to check if the results included have failed + tests. It also defines a string representation with :meth:`__str__`, similar to that + of :class:`Results`. """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, other=None, **kwargs): + super().__init__() + self.update(other, **kwargs) + + # TODO: override __ror__ too? + # TODO: how to get rid of the type ignore? + def __ior__(self, other): # type: ignore + """Override method to raise ValueError on duplicate keys.""" + if isinstance(other, ResultsDict): + self.update(other) + return self + else: + raise TypeError(f"Invalid type '{type(other)}' for operator |=") + + def __setitem__(self, k, v): + """Override method to raise ValueError on duplicate keys.""" + if k in self: + raise ValueError(f"Duplicate key '{k}' in ResultsDict") + super().__setitem__(k, v) + + def update(self, other=None, **kwargs): + """Override method to raise ValueError on duplicate keys.""" + if other is not None: + for k, v in ( + other.items() if isinstance(other, collections.abc.Mapping) else other + ): + self[k] = v + for k, v in kwargs.items(): + self[k] = v def __str__(self) -> str: """Returns a summary of the results contained.""" @@ -770,12 +798,9 @@ def process_results( "valid inputs that the implementation should use correctly.\n" ) description += ( - "Invalid tests : " - "invalid inputs that the implementation should reject.\n" - ) - description += ( - "Acceptable tests: " "inputs for legacy cases or weak parameters." + "Invalid tests : invalid inputs that the implementation should reject.\n" ) + description += "Acceptable tests: inputs for legacy cases or weak parameters." self.print(Panel(description, title="Types of tests")) # Show results summary: give some general info like primitives tested and show # the total count of tests. Include CC version as subtitle for reference. diff --git a/tests/harness/test.py b/tests/harness/test_harness.py similarity index 100% rename from tests/harness/test.py rename to tests/harness/test_harness.py diff --git a/tests/primitives/test_common.py b/tests/primitives/test_common.py index 6ce3ff2..c9407c8 100644 --- a/tests/primitives/test_common.py +++ b/tests/primitives/test_common.py @@ -3,8 +3,10 @@ from hashlib import sha256 from pathlib import Path +import pytest + from crypto_condor.primitives import SHA -from crypto_condor.primitives.common import Console +from crypto_condor.primitives.common import Console, Results, ResultsDict class TestConsole: @@ -18,3 +20,23 @@ def test_filename_none(self, tmp_path: Path): rd = SHA.test(lambda msg: sha256(msg).digest(), SHA.Algorithm.SHA_256) console = Console() assert console.process_results(rd, None) + +def test_results_dict(): + """Tests that ResultsDict raises ValueError on duplicate keys.""" + rd1 = ResultsDict() + rd2 = ResultsDict() + + res1 = Results("AES", "test", "description", {"mode": "ECB"}) + res2 = Results("AES", "test", "description", {"mode": "ECB"}) + + rd1.add(res1) + with pytest.raises(ValueError): + rd1.add(res2) + + rd2.add(res1) + with pytest.raises(ValueError): + rd1.update(rd2) + + with pytest.raises(ValueError): + rd1 |= rd2 +