diff --git a/README.rst b/README.rst index a30568a..1112988 100644 --- a/README.rst +++ b/README.rst @@ -332,6 +332,40 @@ references to other configuration files. aws-crypto @my-encrypt -i $INPUT -o $OUTPUT + +Encoding +-------- +By default, ``aws-crypto`` will always output raw binary data and expect raw binary data +as input. However, there are some cases where you might not want this behavior. + +Sometimes this might be for convenience: + +* Accepting ciphertext through stdin from a human. +* Presenting ciphertext through stdout to a human. + +Sometimes it might be out of necessity: + +* Saving ciphertext output to a shell variable. + + * Most shells apply a system encoding to any data stored in a variable. As a result, binary + data stored in a variable without additional encoding is often corrupted. + +* Piping ciphertext in PowerShell. + + * Similar to the above, all data passed through a PowerShell pipe is encoded using the + system encoding. + +To address these scenarios, we provide two optional arguments: + +* ``--decode`` : Base64-decode input before processing. +* ``--encode`` : Base64-encode output after processing. + +These can be used independently or together, on any valid input or output. + +Be aware, however, that if you target multiple files either through a path expansion or by +targeting a directory, the requested decoding/encoding will be applied to all files. + + Execution ========= @@ -381,6 +415,8 @@ Execution -o OUTPUT, --output OUTPUT Output file or directory for encrypt/decrypt operation, or - for stdout. + --encode Base64-encode output after processing + --decode Base64-decode input before processing -c ENCRYPTION_CONTEXT [ENCRYPTION_CONTEXT ...], --encryption-context ENCRYPTION_CONTEXT [ENCRYPTION_CONTEXT ...] key-value pair encryption context values (encryption only). Must a set of "key=value" pairs. ex: -c diff --git a/doc/index.rst b/doc/index.rst index f0e3c8d..075a1fd 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -11,6 +11,7 @@ Modules aws_encryption_sdk_cli aws_encryption_sdk_cli.internal aws_encryption_sdk_cli.internal.arg_parsing + aws_encryption_sdk_cli.internal.encoding aws_encryption_sdk_cli.internal.identifiers aws_encryption_sdk_cli.internal.io_handling aws_encryption_sdk_cli.internal.master_key_parsing diff --git a/src/aws_encryption_sdk_cli/__init__.py b/src/aws_encryption_sdk_cli/__init__.py index dab2f15..ac96697 100644 --- a/src/aws_encryption_sdk_cli/__init__.py +++ b/src/aws_encryption_sdk_cli/__init__.py @@ -111,7 +111,9 @@ def process_cli_request( recursive, # type: bool interactive, # type: bool no_overwrite, # type: bool - suffix=None # type: Optional[str] + suffix=None, # type: Optional[str] + decode_input=False, # type: Optional[bool] + encode_output=False # type: Optional[bool] ): # type: (...) -> None """Maps the operation request to the appropriate function based on the type of input and output provided.
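For illustration only (this snippet is not part of the change set): the stream-level round trip behind the ``--encode`` and ``--decode`` flags described in the README section above can be sketched with the ``Base64IO`` wrapper that this diff adds in ``src/aws_encryption_sdk_cli/internal/encoding.py``. The sample bytes and variable names are invented::

    import io

    from aws_encryption_sdk_cli.internal.encoding import Base64IO

    raw_ciphertext = b'\x00\x01raw bytes that would corrupt a shell variable\xff'

    # --encode: wrap the destination stream so everything written to it is
    # base64-encoded; closing the wrapper flushes any buffered trailing bytes.
    encoded_destination = io.BytesIO()
    with Base64IO(encoded_destination) as encoder:
        encoder.write(raw_ciphertext)

    # The wrapped stream now holds printable base64 text, safe for shell
    # variables and PowerShell pipes.
    encoded = encoded_destination.getvalue()

    # --decode: wrap the source stream so reads return the original raw bytes.
    with Base64IO(io.BytesIO(encoded)) as decoder:
        assert decoder.read() == raw_ciphertext

This is the same wrapping that ``_encoder`` in ``io_handling.py`` applies to the source and destination streams when the corresponding flag is passed.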
@@ -123,6 +125,8 @@ def process_cli_request( :param bool interactive: Should prompt before overwriting existing files :param bool no_overwrite: Should never overwrite existing files :param str suffix: Suffix to append to output filename (optional) + :param bool decode_input: Should input be base64 decoded before operation (optional) + :param bool encode_output: Should output be base64 encoded after operation (optional) """ _catch_bad_destination_requests(destination) _catch_bad_stdin_stdout_requests(source, destination) @@ -134,7 +138,9 @@ def process_cli_request( source=source, destination=destination, interactive=interactive, - no_overwrite=no_overwrite + no_overwrite=no_overwrite, + decode_input=decode_input, + encode_output=encode_output ) return @@ -154,7 +160,9 @@ def process_cli_request( destination=_destination, interactive=interactive, no_overwrite=no_overwrite, - suffix=suffix + suffix=suffix, + decode_input=decode_input, + encode_output=encode_output ) elif os.path.isfile(_source): @@ -172,7 +180,9 @@ def process_cli_request( source=_source, destination=_destination, interactive=interactive, - no_overwrite=no_overwrite + no_overwrite=no_overwrite, + decode_input=decode_input, + encode_output=encode_output ) @@ -236,7 +246,9 @@ def cli(raw_args=None): recursive=args.recursive, interactive=args.interactive, no_overwrite=args.no_overwrite, - suffix=args.suffix + suffix=args.suffix, + decode_input=args.decode, + encode_output=args.encode ) return None except AWSEncryptionSDKCLIError as error: diff --git a/src/aws_encryption_sdk_cli/internal/arg_parsing.py b/src/aws_encryption_sdk_cli/internal/arg_parsing.py index 4dffaf9..e23a481 100644 --- a/src/aws_encryption_sdk_cli/internal/arg_parsing.py +++ b/src/aws_encryption_sdk_cli/internal/arg_parsing.py @@ -214,6 +214,17 @@ def _build_parser(): help='Output file or directory for encrypt/decrypt operation, or - for stdout.' ) + parser.add_argument( + '--encode', + action='store_true', + help='Base64-encode output after processing' + ) + parser.add_argument( + '--decode', + action='store_true', + help='Base64-decode input before processing' + ) + parser.add_argument( '-c', '--encryption-context', diff --git a/src/aws_encryption_sdk_cli/internal/encoding.py b/src/aws_encryption_sdk_cli/internal/encoding.py new file mode 100644 index 0000000..52c7d09 --- /dev/null +++ b/src/aws_encryption_sdk_cli/internal/encoding.py @@ -0,0 +1,308 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Base64 context manager.""" +from __future__ import division + +import base64 +import io +import logging +import string +from typing import IO, Iterable, List, Optional # noqa pylint: disable=unused-import +from types import TracebackType # noqa pylint: disable=unused-import + +import six + +from aws_encryption_sdk_cli.internal.logging_utils import LOGGER_NAME + +_LOGGER = logging.getLogger(LOGGER_NAME) +__all__ = ('Base64IO',) + + +class Base64IO(io.IOBase): + """Wraps a stream, base64-decoding read results before returning them and base64-encoding + written bytes before writing them to the stream. Unless ``close_wrapped_on_close`` is + set to True, the underlying stream is not closed when this object is closed. Instances + of this class are not reusable in order maintain consistency with the :class:`io.IOBase` + behavior on ``close()``. + + .. note:: + + Provides iterator and context manager interfaces. + + .. warning:: + + Because up to two bytes of data must be buffered to ensure correct base64 encoding + of all data written, this object **must** be closed after you are done writing to + avoid data loss. If used as a context manager, we take care of that for you. + + :param wrapped: Stream to wrap + :param bool close_wrapped_on_close: Should the wrapped stream be closed when this object is closed (default: False) + """ + + closed = False + + def __init__(self, wrapped, close_wrapped_on_close=False): + # type: (Base64IO, IO, Optional[bool]) -> None + """Check for required methods on wrapped stream and set up read buffer. + + :raises TypeError: if ``wrapped`` does not have attributes needed to determine the stream's state + """ + required_attrs = ('read', 'write', 'close', 'closed', 'flush') + if not all(hasattr(wrapped, attr) for attr in required_attrs): + raise TypeError('Base64IO wrapped object must have attributes: {}'.format(repr(sorted(required_attrs)))) + super(Base64IO, self).__init__() + self.__wrapped = wrapped + self.__close_wrapped_on_close = close_wrapped_on_close + self.__read_buffer = b'' + self.__write_buffer = b'' + + def __enter__(self): + # type: () -> Base64IO + """Return self on enter.""" + return self + + def __exit__(self, exc_type, exc_value, traceback): + # type: (type, BaseException, TracebackType) -> None + """Properly close self on exit.""" + self.close() + + def close(self): + # type: () -> None + """Closes this stream, encoding and writing any buffered bytes is present. + + .. note:: + + This does **not** close the wrapped stream unless otherwise specified when this + object was created. + """ + if self.__write_buffer: + self.__wrapped.write(base64.b64encode(self.__write_buffer)) + self.__write_buffer = b'' + self.closed = True + if self.__close_wrapped_on_close: + self.__wrapped.close() + + def _passthrough_interactive_check(self, method_name, mode): + # type: (str, str) -> bool + """Attempt to call the specified method on the wrapped stream and return the result. + If the method is not found on the wrapped stream, returns False. + + .. note:: + + Special Case: If wrapped stream is a Python 2 file, inspect the file mode. 
+ + :param str method_name: Name of method to call + :param str mode: Python 2 mode character + :rtype: bool + """ + try: + method = getattr(self.__wrapped, method_name) + except AttributeError: + if six.PY2 and isinstance(self.__wrapped, file): # noqa pylint: disable=undefined-variable + if mode in self.__wrapped.mode: + return True + return False + else: + return method() + + def writable(self): + # type: () -> bool + """Determine if the stream can be written to. + Delegates to wrapped stream when possible. + Otherwise returns False. + + :rtype: bool + """ + return self._passthrough_interactive_check('writable', 'w') + + def readable(self): + # type: () -> bool + """Determine if the stream can be read from. + Delegates to wrapped stream when possible. + Otherwise returns False. + + :rtype: bool + """ + return self._passthrough_interactive_check('readable', 'r') + + def flush(self): + # type: () -> None + """Flush the write buffer of the wrapped stream.""" + return self.__wrapped.flush() + + def write(self, b): + # type: (bytes) -> int + """Base64-encode the bytes and write them to the wrapped stream, buffering any + bytes that would require padding for the next write call. + + .. warning:: + + Because up to two bytes of data must be buffered to ensure correct base64 encoding + of all data written, this object **must** be closed after you are done writing to + avoid data loss. If used as a context manager, we take care of that for you. + + :param bytes b: Bytes to write to wrapped stream + :raises ValueError: if called on closed Base64IO object + :raises IOError: if underlying stream is not writable + """ + if self.closed: + raise ValueError('I/O operation on closed file.') + + if not self.writable(): + raise IOError('Stream is not writable') + + # Load any stashed bytes and clear the buffer + _bytes_to_write = self.__write_buffer + b + self.__write_buffer = b'' + + # If an even base64 chunk or finalizing the stream, write through. + if len(_bytes_to_write) % 3 == 0: + return self.__wrapped.write(base64.b64encode(_bytes_to_write)) + + # We're not finalizing the stream, so stash the trailing bytes and encode the rest. + trailing_byte_pos = -1 * (len(_bytes_to_write) % 3) + self.__write_buffer = _bytes_to_write[trailing_byte_pos:] + return self.__wrapped.write(base64.b64encode(_bytes_to_write[:trailing_byte_pos])) + + def writelines(self, lines): + # type: (Iterable[bytes]) -> None + """Write a list of lines. + + :param list lines: Lines to write + """ + for line in lines: + self.write(line) + + def _read_additional_data_removing_whitespace(self, data, total_bytes_to_read): + # type: (bytes, int) -> bytes + """Read additional data from wrapped stream, removing any whitespace found, until we + reach the desired number of bytes. + + :param bytes data: Data that has already been read from wrapped stream + :param int total_bytes_to_read: Number of total non-whitespace bytes to read from wrapped stream + :returns: ``total_bytes_to_read`` bytes from wrapped stream with no whitespace + :rtype: bytes + """ + if total_bytes_to_read is None: + # If the requested number of bytes is None, we read the entire message, in which + # case the base64 module happily removes any whitespace. 
+ return data + + _data_buffer = io.BytesIO() + _data_buffer.write(b''.join(data.split())) + _remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() + + while _remaining_bytes_to_read > 0: + _raw_additional_data = self.__wrapped.read(_remaining_bytes_to_read) + if not _raw_additional_data: + # No more data to read from wrapped stream. + break + + _data_buffer.write(b''.join(_raw_additional_data.split())) + _remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() + return _data_buffer.getvalue() + + def read(self, b=None): + # type: (Optional[int]) -> bytes + """Read bytes from wrapped stream, base64-decoding before return, and adjusting read + from wrapped stream to return correct number of bytes. + + :param int b: Number of bytes to read + :returns: Decoded bytes from wrapped stream + :rtype: bytes + """ + if self.closed: + raise ValueError('I/O operation on closed file.') + + if not self.readable(): + raise IOError('Stream is not readable') + + if b is not None and b < 0: + b = None + _bytes_to_read = None + if b is not None: + # Calculate number of encoded bytes that must be read to get b raw bytes. + _bytes_to_read = int((b - len(self.__read_buffer)) * 4 / 3) + _bytes_to_read += (4 - _bytes_to_read % 4) + + # Read encoded bytes from wrapped stream. + data = self.__wrapped.read(_bytes_to_read) + # Remove whitespace from read data and attempt to read more data to get the desired + # number of bytes. + if any([six.b(char) in data for char in string.whitespace]): + data = self._read_additional_data_removing_whitespace(data, _bytes_to_read) + + results = io.BytesIO() + # First, load any stashed bytes + results.write(self.__read_buffer) + # Decode encoded bytes. + results.write(base64.b64decode(data)) + + results.seek(0) + output_data = results.read(b) + # Stash any extra bytes for the next run. + self.__read_buffer = results.read() + + return output_data + + def __iter__(self): # type: ignore + # Until https://github.com/python/typing/issues/11 + # there's no good way to tell mypy about custom + # iterators that subclass io.IOBase. + """Let this class act as an iterator.""" + return self + + def readline(self, limit=-1): + # type: (int) -> bytes + """Read and return one line from the stream. + If limit is specified, at most limit bytes will be read. + + .. note:: + + Because the source that this reads from may not contain any OEL characters, we + read "lines" in chunks of length ``io.DEFAULT_BUFFER_SIZE``. + + :type limit: int + :rtype: bytes + """ + return self.read(limit if limit > 0 else io.DEFAULT_BUFFER_SIZE) + + def readlines(self, hint=-1): + # type: (int) -> List[bytes] + """Read and return a list of lines from the stream. hint can be specified to control + the number of lines read: no more lines will be read if the total size (in bytes/ + characters) of all lines so far exceeds hint. 
+ + :type hint: int + :returns: Lines of data + :rtype: list of bytes + """ + lines = [] + for line in self: # type: ignore + lines.append(line) + if hint > 0 and len(lines) * io.DEFAULT_BUFFER_SIZE > hint: + break + return lines + + def __next__(self): + # type: () -> bytes + """Python 3 iterator hook.""" + line = self.readline() + if line: + return line + raise StopIteration() + + def next(self): + # type: () -> bytes + """Python 2 iterator hook.""" + return self.__next__() diff --git a/src/aws_encryption_sdk_cli/internal/io_handling.py b/src/aws_encryption_sdk_cli/internal/io_handling.py index cb3fbf0..35389e3 100644 --- a/src/aws_encryption_sdk_cli/internal/io_handling.py +++ b/src/aws_encryption_sdk_cli/internal/io_handling.py @@ -11,14 +11,18 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Helper functions for handling all input and output for this CLI.""" +from __future__ import division + +import copy import logging import os import sys -from typing import IO, Type # noqa pylint: disable=unused-import +from typing import cast, IO, Type, Union # noqa pylint: disable=unused-import import aws_encryption_sdk import six +from aws_encryption_sdk_cli.internal.encoding import Base64IO from aws_encryption_sdk_cli.internal.identifiers import OUTPUT_SUFFIX from aws_encryption_sdk_cli.internal.logging_utils import LOGGER_NAME from aws_encryption_sdk_cli.internal.mypy_types import SOURCE, STREAM_KWARGS # noqa pylint: disable=unused-import @@ -78,24 +82,42 @@ def _ensure_dir_exists(filename): _LOGGER.info('Created directory: %s', dest_final_dir) -def _single_io_write(stream_args, source, destination_writer): - # type: (STREAM_KWARGS, SOURCE, IO) -> None +def _encoder(stream, should_base64): + # type: (IO, bool) -> Union[IO, Base64IO] + """Wraps a stream in either a Base64IO transformer or results stream if wrapping is not requested. + + :param stream: Stream to wrap + :type stream: file-like object + :param bool should_base64: Should the stream be wrapped with Base64IO + :returns: wrapped stream + :rtype: io.IOBase + """ + if should_base64: + return Base64IO(stream) + return stream + + +def _single_io_write(stream_args, source, destination_writer, decode_input, encode_output): + # type: (STREAM_KWARGS, IO, IO, bool, bool) -> None """Performs the actual write operations for a single operation. 
:param dict stream_args: kwargs to pass to `aws_encryption_sdk.stream` :param source: source to write - :type source: str, stream, or file-like object + :type source: file-like object :param destination_writer: destination object to which to write - :type source: stream or file-like object + :type destination_writer: file-like object + :param bool decode_input: Should input be base64 decoded before operation + :param bool encode_output: Should output be base64 encoded after operation """ - with aws_encryption_sdk.stream(source=source, **stream_args) as handler: - for chunk in handler: - destination_writer.write(chunk) - destination_writer.flush() + with _encoder(source, decode_input) as _source, _encoder(destination_writer, encode_output) as _destination: + with aws_encryption_sdk.stream(source=_source, **stream_args) as handler: + for chunk in handler: + _destination.write(chunk) + _destination.flush() -def process_single_operation(stream_args, source, destination, interactive, no_overwrite): - # type: (STREAM_KWARGS, SOURCE, str, bool, bool) -> None +def process_single_operation(stream_args, source, destination, interactive, no_overwrite, decode_input, encode_output): + # type: (STREAM_KWARGS, SOURCE, str, bool, bool, bool, bool) -> None """Processes a single encrypt/decrypt operation given a pre-loaded source. :param dict stream_args: kwargs to pass to `aws_encryption_sdk.stream` @@ -104,6 +126,8 @@ def process_single_operation(stream_args, source, destination, interactive, no_o :param str destination: destination identifier :param bool interactive: Should prompt before overwriting existing files :param bool no_overwrite: Should never overwrite existing files + :param bool decode_input: Should input be base64 decoded before operation + :param bool encode_output: Should output be base64 encoded after operation """ if destination == '-': destination_writer = _stdout() @@ -112,16 +136,20 @@ def process_single_operation(stream_args, source, destination, interactive, no_o return _ensure_dir_exists(destination) destination_writer = open(destination, 'wb') + if source == '-': - source = _stdin() - try: + source_reader = _stdin() + else: + source_reader = cast(IO, source) + + with destination_writer: _single_io_write( stream_args=stream_args, - source=source, - destination_writer=destination_writer + source=source_reader, + destination_writer=destination_writer, + decode_input=decode_input, + encode_output=encode_output ) - finally: - destination_writer.close() def _should_write_file(filepath, interactive, no_overwrite): @@ -162,8 +190,8 @@ def _should_write_file(filepath, interactive, no_overwrite): return True -def process_single_file(stream_args, source, destination, interactive, no_overwrite): - # type: (STREAM_KWARGS, str, str, bool, bool) -> None +def process_single_file(stream_args, source, destination, interactive, no_overwrite, decode_input, encode_output): + # type: (STREAM_KWARGS, str, str, bool, bool, bool, bool) -> None """Processes a single encrypt/decrypt operation on a source file. 
:param dict stream_args: kwargs to pass to `aws_encryption_sdk.stream` @@ -171,6 +199,8 @@ def process_single_file(stream_args, source, destination, interactive, no_overwr :param str destination: Full file path to destination file :param bool interactive: Should prompt before overwriting existing files :param bool no_overwrite: Should never overwrite existing files + :param bool decode_input: Should input be base64 decoded before operation + :param bool encode_output: Should output be base64 encoded after operation """ if os.path.realpath(source) == os.path.realpath(destination): # File source, directory destination, empty suffix: @@ -178,13 +208,26 @@ def process_single_file(stream_args, source, destination, interactive, no_overwr return _LOGGER.info('%sing file %s to %s', stream_args['mode'], source, destination) + + _stream_args = copy.copy(stream_args) + # Because we can actually know size for files and Base64IO does not support seeking, + # set the source length manually for files. This allows enables data key caching when + # Base64-decoding a source file. + source_file_size = os.path.getsize(source) + if decode_input and not encode_output: + _stream_args['source_length'] = int(source_file_size * (3 / 4)) + else: + _stream_args['source_length'] = source_file_size + with open(source, 'rb') as source_reader: process_single_operation( - stream_args=stream_args, + stream_args=_stream_args, source=source_reader, destination=destination, interactive=interactive, - no_overwrite=no_overwrite + no_overwrite=no_overwrite, + decode_input=decode_input, + encode_output=encode_output ) @@ -221,8 +264,8 @@ def _output_dir(source_root, destination_root, source_dir): return os.path.join(destination_root, suffix) -def process_dir(stream_args, source, destination, interactive, no_overwrite, suffix): - # type: (STREAM_KWARGS, str, str, bool, bool, str) -> None +def process_dir(stream_args, source, destination, interactive, no_overwrite, suffix, decode_input, encode_output): + # type: (STREAM_KWARGS, str, str, bool, bool, str, bool, bool) -> None """Processes encrypt/decrypt operations on all files in a directory tree. :param dict stream_args: kwargs to pass to `aws_encryption_sdk.stream` @@ -231,6 +274,8 @@ def process_dir(stream_args, source, destination, interactive, no_overwrite, suf :param bool interactive: Should prompt before overwriting existing files :param bool no_overwrite: Should never overwrite existing files :param str suffix: Suffix to append to output filename + :param bool decode_input: Should input be base64 decoded before operation + :param bool encode_output: Should output be base64 encoded after operation """ _LOGGER.debug('%sing directory %s to %s', stream_args['mode'], source, destination) for base_dir, _dirs, files in os.walk(source): @@ -252,5 +297,7 @@ def process_dir(stream_args, source, destination, interactive, no_overwrite, suf source=source_filename, destination=destination_filename, interactive=interactive, - no_overwrite=no_overwrite + no_overwrite=no_overwrite, + decode_input=decode_input, + encode_output=encode_output ) diff --git a/src/pylintrc b/src/pylintrc index 1c5ffc0..ead6e21 100644 --- a/src/pylintrc +++ b/src/pylintrc @@ -1,12 +1,24 @@ +[MESSAGES CONTROL] +# Disabling messages that we either don't care about +# for tests or are necessary to break for tests. 
+# +# R0801 : duplicate-code (we have several functions in io_handling with similar profiles) +disable = R0801 + [BASIC] # Allow function names up to 50 characters function-rgx = [a-z_][a-z0-9_]{2,50}$ +# Whitelist argument names: iv, b +argument-rgx = ([a-z_][a-z0-9_]{2,30}$)|(^b$) +# Whitelist variable names: b, _b +variable-rgx = ([a-z_][a-z0-9_]{2,30}$)|(^b$) [VARIABLES] additional-builtins = raw_input [DESIGN] max-args = 10 +max-attributes = 10 [FORMAT] max-line-length = 120 diff --git a/test/integration/test_i_aws_encryption_sdk_cli.py b/test/integration/test_i_aws_encryption_sdk_cli.py index 96889ab..b1fa86a 100644 --- a/test/integration/test_i_aws_encryption_sdk_cli.py +++ b/test/integration/test_i_aws_encryption_sdk_cli.py @@ -11,10 +11,12 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Integration testing suite for AWS Encryption SDK CLI.""" +import base64 from distutils.spawn import find_executable # distutils confuses pylint: disable=import-error,no-name-in-module import filecmp import os import shlex +import shutil from subprocess import PIPE, Popen import pytest @@ -89,6 +91,54 @@ def test_file_to_file_cycle_target_through_symlink(tmpdir): assert filecmp.cmp(str(plaintext), str(decrypted)) +@pytest.mark.skipif(not _should_run_tests(), reason='Integration tests disabled. See test/integration/README.rst') +@pytest.mark.parametrize('encode, decode', ( + (True, False), + (False, True), + (True, True), + (False, False) +)) +def test_file_to_file_base64(tmpdir, encode, decode): + plaintext = tmpdir.join('source_plaintext') + ciphertext_a = tmpdir.join('ciphertext-a') + ciphertext_b = tmpdir.join('ciphertext-b') + decrypted = tmpdir.join('decrypted') + plaintext_source = os.urandom(10240) # make sure we have more than one chunk + with open(str(plaintext), 'wb') as f: + f.write(plaintext_source) + + encrypt_flag = ' --encode' if encode else '' + decrypt_flag = ' --decode' if decode else '' + + encrypt_args = ENCRYPT_ARGS_TEMPLATE.format( + source=str(plaintext), + target=str(ciphertext_a) + ) + encrypt_flag + decrypt_args = DECRYPT_ARGS_TEMPLATE.format( + source=str(ciphertext_b), + target=str(decrypted) + ) + decrypt_flag + + aws_encryption_sdk_cli.cli(shlex.split(encrypt_args)) + + if encode and not decode: + with open(str(ciphertext_a), 'rb') as ct_a, open(str(ciphertext_b), 'wb') as ct_b: + raw_ct = base64.b64decode(ct_a.read()) + ct_b.write(raw_ct) + elif decode and not encode: + with open(str(ciphertext_a), 'rb') as ct, open(str(ciphertext_b), 'wb') as b64_ct: + b64_ct.write(base64.b64encode(ct.read())) + else: + shutil.copy2(str(ciphertext_a), str(ciphertext_b)) + + aws_encryption_sdk_cli.cli(shlex.split(decrypt_args)) + + with open(str(decrypted), 'rb') as f: + decrypted_plaintext = f.read() + + assert decrypted_plaintext == plaintext_source + + @pytest.mark.skipif(not _should_run_tests(), reason='Integration tests disabled. 
See test/integration/README.rst') def test_file_to_file_cycle_with_caching(tmpdir): plaintext = tmpdir.join('source_plaintext') diff --git a/test/pylintrc b/test/pylintrc index 4e91297..d58b6e4 100644 --- a/test/pylintrc +++ b/test/pylintrc @@ -6,11 +6,12 @@ # C0111 : missing-docstring (we don't write docstrings for tests) # E1101 : no-member (raised on patched objects with mock checks) # R0801 : duplicate-code (unit tests for similar things tend to be similar) +# R0914 : too-many-locals (we prefer clarity over brevity in tests) # R0903 : too-few-public-methods (common for stub classes sometimes needed in tests) # W0212 : protected-access (raised when calling _ methods) # W0621 : redefined-outer-name (raised when using pytest-mock) # W0613 : unused-argument (raised when patches are needed but not called) -disable = C0103, C0111, E1101, R0801, R0903, W0212, W0621, W0613 +disable = C0103, C0111, E1101, R0801, R0903, R0914, W0212, W0621, W0613 [VARIABLES] diff --git a/test/unit/test_aws_encryption_sdk_cli.py b/test/unit/test_aws_encryption_sdk_cli.py index 19880af..8fa710d 100644 --- a/test/unit/test_aws_encryption_sdk_cli.py +++ b/test/unit/test_aws_encryption_sdk_cli.py @@ -198,7 +198,9 @@ def test_process_cli_request_source_dir_destination_dir(tmpdir, patch_for_proces recursive=True, interactive=sentinel.interactive, no_overwrite=sentinel.no_overwrite, - suffix=sentinel.suffix + suffix=sentinel.suffix, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) aws_encryption_sdk_cli.process_dir.assert_called_once_with( @@ -207,7 +209,9 @@ def test_process_cli_request_source_dir_destination_dir(tmpdir, patch_for_proces destination=str(destination), interactive=sentinel.interactive, no_overwrite=sentinel.no_overwrite, - suffix=sentinel.suffix + suffix=sentinel.suffix, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) assert not aws_encryption_sdk_cli.process_single_file.called assert not aws_encryption_sdk_cli.process_single_operation.called @@ -234,7 +238,9 @@ def test_process_cli_request_source_stdin(tmpdir, patch_for_process_cli_request) destination=str(destination), recursive=False, interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) assert not aws_encryption_sdk_cli.process_dir.called assert not aws_encryption_sdk_cli.process_single_file.called @@ -243,7 +249,9 @@ def test_process_cli_request_source_stdin(tmpdir, patch_for_process_cli_request) source='-', destination=str(destination), interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) @@ -258,7 +266,9 @@ def test_process_cli_request_source_file_destination_dir(tmpdir, patch_for_proce recursive=False, interactive=sentinel.interactive, no_overwrite=sentinel.no_overwrite, - suffix='CUSTOM_SUFFIX' + suffix='CUSTOM_SUFFIX', + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) assert not aws_encryption_sdk_cli.process_dir.called assert not aws_encryption_sdk_cli.process_single_operation.called @@ -267,7 +277,9 @@ def test_process_cli_request_source_file_destination_dir(tmpdir, patch_for_proce source=str(source), destination=str(destination.join('sourceCUSTOM_SUFFIX')), interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + 
decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) @@ -282,7 +294,9 @@ def test_process_cli_request_source_file_destination_file(tmpdir, patch_for_proc destination=str(destination), recursive=False, interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) assert not aws_encryption_sdk_cli.process_dir.called assert not aws_encryption_sdk_cli.process_single_operation.called @@ -291,7 +305,9 @@ def test_process_cli_request_source_file_destination_file(tmpdir, patch_for_proc source=str(source), destination=str(destination), interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) @@ -365,7 +381,9 @@ def test_process_cli_request_source_contains_directory_nonrecursive( source=str(source_file), destination=ANY, interactive=False, - no_overwrite=False + no_overwrite=False, + decode_input=False, + encode_output=False ) for source_file in (test_file_a, test_file_c) ], @@ -507,7 +525,9 @@ def patch_for_cli(mocker): recursive=sentinel.recursive, interactive=sentinel.interactive, no_overwrite=sentinel.no_overwrite, - suffix=sentinel.suffix + suffix=sentinel.suffix, + decode=sentinel.decode_input, + encode=sentinel.encode_output ) mocker.patch.object(aws_encryption_sdk_cli, 'setup_logger') mocker.patch.object(aws_encryption_sdk_cli, 'build_crypto_materials_manager_from_args') @@ -540,7 +560,9 @@ def test_cli(patch_for_cli): recursive=sentinel.recursive, interactive=sentinel.interactive, no_overwrite=sentinel.no_overwrite, - suffix=sentinel.suffix + suffix=sentinel.suffix, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) assert test is None diff --git a/test/unit/test_encoding.py b/test/unit/test_encoding.py new file mode 100644 index 0000000..30de420 --- /dev/null +++ b/test/unit/test_encoding.py @@ -0,0 +1,373 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
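For illustration only (not part of the change set): the encoding tests that follow exercise ``Base64IO``'s chunked reads and writes. The buffering rests on two facts about base64: every 3 raw bytes map to 4 encoded characters, so the writer only hands ``base64.b64encode`` multiples of 3 bytes (stashing up to 2 trailing bytes until ``close()``), and the reader requests roughly ``ceil(n / 3) * 4`` encoded bytes in order to return ``n`` decoded bytes. A standard-library sketch of why mid-stream padding must be avoided::

    import base64

    raw = b'0123456789'                    # 10 bytes
    encoded = base64.b64encode(raw)

    # 10 raw bytes -> ceil(10 / 3) * 4 = 16 encoded characters.
    assert len(encoded) == 16

    # Encoding chunks whose lengths are not multiples of 3 inserts padding
    # mid-stream, so naive chunk-by-chunk encoding does not round-trip:
    assert base64.b64encode(raw[:4]) + base64.b64encode(raw[4:]) != encoded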
+"""Unit test suite for ``aws_encryption_sdk_cli.internal.encoding``.""" +import base64 +import functools +import io +import os + +from mock import MagicMock, sentinel +import pytest + +from aws_encryption_sdk_cli.internal.encoding import Base64IO + + +def test_base64io_bad_wrap(): + with pytest.raises(TypeError) as excinfo: + Base64IO(7) + + excinfo.match(r'Base64IO wrapped object must have attributes: *') + + +def test_base64io_write_after_closed(): + with Base64IO(io.BytesIO()) as test: + with pytest.raises(ValueError) as excinfo: + test.close() + test.write(b'aksdhjf') + + excinfo.match(r'I/O operation on closed file.') + + +def test_base64io_read_after_closed(): + with Base64IO(io.BytesIO()) as test: + with pytest.raises(ValueError) as excinfo: + test.close() + test.read() + + excinfo.match(r'I/O operation on closed file.') + + +@pytest.mark.parametrize('method_name', ('isatty', 'seekable')) +def test_base64io_always_false_methods(method_name): + test = Base64IO(io.BytesIO()) + + assert not getattr(test, method_name)() + + +@pytest.mark.parametrize('method_name', ('fileno', 'seek', 'tell', 'truncate')) +def test_unsupported_methods(method_name): + test = Base64IO(io.BytesIO()) + + with pytest.raises(IOError): + getattr(test, method_name)() + + +@pytest.mark.parametrize('method_name', ('flush', 'writable', 'readable')) +def test_passthrough_methods_present(monkeypatch, method_name): + wrapped = io.BytesIO() + monkeypatch.setattr(wrapped, method_name, lambda: sentinel.passthrough) + wrapper = Base64IO(wrapped) + + assert getattr(wrapper, method_name)() is sentinel.passthrough + + +@pytest.mark.parametrize('method_name', ('writable', 'readable')) +def test_passthrough_methods_not_present(monkeypatch, method_name): + wrapped = MagicMock() + monkeypatch.delattr(wrapped, method_name, False) + wrapper = Base64IO(wrapped) + + assert not getattr(wrapper, method_name)() + + +@pytest.mark.parametrize('mode, method_name, expected', ( + ('wb', 'writable', True), + ('rb', 'readable', True), + ('rb', 'writable', False), + ('wb', 'readable', False) +)) +def test_passthrough_methods_file(tmpdir, method_name, mode, expected): + source = tmpdir.join('source') + source.write('some data') + + with open(str(source), mode) as reader: + with Base64IO(reader) as b64: + test = getattr(b64, method_name)() + + if expected: + assert test + else: + assert not test + + +@pytest.mark.parametrize('patch_method, call_method, call_arg', ( + ('writable', 'write', b''), + ('readable', 'read', 0) +)) +def test_non_interactive_error(monkeypatch, patch_method, call_method, call_arg): + wrapped = io.BytesIO() + monkeypatch.setattr(wrapped, patch_method, lambda: False) + + with Base64IO(wrapped) as wrapper: + with pytest.raises(IOError) as excinfo: + getattr(wrapper, call_method)(call_arg) + + excinfo.match(r'Stream is not ' + patch_method) + + +def build_test_cases(): + """Build test cases for read/write encoding checks. 
+ + :returns: (bytes_to_generate, bytes_per_round, number_of_rounds, total_bytes_to_expect) + """ + test_cases = [] + + # exact single-shot, varying multiples + for size in (1, 2, 3, 4, 5, 6, 7, 222, 1024): + test_cases.append((size, size, 1, size)) + + test_cases.append((1024, None, 1, 1024)) # single-shot + test_cases.append((1024, -1, 1, 1024)) # single-shot + + # Odd multiples with operation smaller, equal to, and larger than total + for rounds in (1, 3, 5): + for read_size in (1, 2, 3, 4, 5, 1024, 1500): + test_cases.append((1024, read_size, rounds, min(read_size * rounds, 1024))) + + return test_cases + + +@pytest.mark.parametrize( + 'bytes_to_generate, bytes_per_round, number_of_rounds, total_bytes_to_expect', + build_test_cases() +) +def test_base64io_decode(bytes_to_generate, bytes_per_round, number_of_rounds, total_bytes_to_expect): + plaintext_source = os.urandom(bytes_to_generate) + plaintext_b64 = io.BytesIO(base64.b64encode(plaintext_source)) + plaintext_wrapped = Base64IO(plaintext_b64) + + test = b'' + for _round in range(number_of_rounds): + test += plaintext_wrapped.read(bytes_per_round) + + assert len(test) == total_bytes_to_expect + assert test == plaintext_source[:total_bytes_to_expect] + + +@pytest.mark.parametrize('source_bytes', [case[0] for case in build_test_cases()]) +def test_base64io_encode_context_manager(source_bytes): + plaintext_source = os.urandom(source_bytes) + plaintext_b64 = base64.b64encode(plaintext_source) + plaintext_stream = io.BytesIO() + + with Base64IO(plaintext_stream) as plaintext_wrapped: + plaintext_wrapped.write(plaintext_source) + + assert plaintext_stream.getvalue() == plaintext_b64 + + +def test_base64io_encode_context_manager_reuse(): + plaintext_source = os.urandom(10) + plaintext_stream = io.BytesIO() + + stream = Base64IO(plaintext_stream) + + with stream as plaintext_wrapped: + plaintext_wrapped.write(plaintext_source) + + with pytest.raises(ValueError) as excinfo: + with stream as plaintext_wrapped: + plaintext_wrapped.read() + + excinfo.match(r'I/O operation on closed file.') + + +def test_base64io_encode_use_after_context_manager_exit(): + plaintext_source = os.urandom(10) + plaintext_stream = io.BytesIO() + + stream = Base64IO(plaintext_stream) + + with stream as plaintext_wrapped: + plaintext_wrapped.write(plaintext_source) + + assert stream.closed + + with pytest.raises(ValueError) as excinfo: + stream.read() + + excinfo.match(r'I/O operation on closed file.') + + +@pytest.mark.parametrize('source_bytes', [case[0] for case in build_test_cases()]) +def test_base64io_encode(source_bytes): + plaintext_source = os.urandom(source_bytes) + plaintext_b64 = base64.b64encode(plaintext_source) + plaintext_stream = io.BytesIO() + + plaintext_wrapped = Base64IO(plaintext_stream) + try: + plaintext_wrapped.write(plaintext_source) + finally: + plaintext_wrapped.close() + + assert plaintext_stream.getvalue() == plaintext_b64 + + +@pytest.mark.parametrize('bytes_to_read, expected_bytes_read', ( + (-1, io.DEFAULT_BUFFER_SIZE), + (0, io.DEFAULT_BUFFER_SIZE), + (1, 1), + (10, 10) +)) +def test_base64io_decode_readline(bytes_to_read, expected_bytes_read): + source_plaintext = os.urandom(io.DEFAULT_BUFFER_SIZE * 2) + source_stream = io.BytesIO(base64.b64encode(source_plaintext)) + + with Base64IO(source_stream) as decoder: + test = decoder.readline(bytes_to_read) + + assert test == source_plaintext[:expected_bytes_read] + + +def build_b64_with_whitespace(source_bytes, line_length): + plaintext_source = os.urandom(source_bytes) + b64_plaintext = 
io.BytesIO(base64.b64encode(plaintext_source)) + b64_plaintext_with_whitespace = b'\n'.join([ + line for line + in iter(functools.partial(b64_plaintext.read, line_length), b'') + ]) + return plaintext_source, b64_plaintext_with_whitespace + + +def build_whitespace_testcases(): + scenarios = [] + for test_case in build_test_cases(): + scenarios.append(build_b64_with_whitespace(test_case[0], 3) + (test_case[-1],)) + + # first read is mostly whitespace + plaintext, b64_plaintext = build_b64_with_whitespace(100, 20) + b64_plaintext = (b' ' * 80) + b64_plaintext + scenarios.append((plaintext, b64_plaintext, 100)) + + # first several reads are entirely whitespace + plaintext, b64_plaintext = build_b64_with_whitespace(100, 20) + b64_plaintext = (b' ' * 500) + b64_plaintext + scenarios.append((plaintext, b64_plaintext, 100)) + + return scenarios + + +@pytest.mark.parametrize('plaintext_source, b64_plaintext_with_whitespace, read_bytes', build_whitespace_testcases()) +def test_base64io_decode_with_whitespace(plaintext_source, b64_plaintext_with_whitespace, read_bytes): + with Base64IO(io.BytesIO(b64_plaintext_with_whitespace)) as decoder: + test = decoder.read(read_bytes) + + assert test == plaintext_source[:read_bytes] + + +def test_base64io_decode_read_only_from_buffer(): + plaintext_source = b'12345' + plaintext_b64 = io.BytesIO(base64.b64encode(plaintext_source)) + plaintext_wrapped = Base64IO(plaintext_b64) + + test_1 = plaintext_wrapped.read(1) + test_2 = plaintext_wrapped.read(1) + test_3 = plaintext_wrapped.read() + + assert test_1 == b'1' + assert test_2 == b'2' + assert test_3 == b'345' + + +def test_base64io_decode_context_manager(): + source_plaintext = os.urandom(102400) + source_stream = io.BytesIO(base64.b64encode(source_plaintext)) + + test = io.BytesIO() + with Base64IO(source_stream) as stream: + for chunk in stream: + test.write(chunk) + + assert test.getvalue() == source_plaintext + assert not source_stream.closed + + +def test_base64io_decode_context_manager_close_source(): + source_plaintext = os.urandom(102400) + source_stream = io.BytesIO(base64.b64encode(source_plaintext)) + + test = io.BytesIO() + with Base64IO(source_stream, close_wrapped_on_close=True) as stream: + for chunk in stream: + test.write(chunk) + + assert test.getvalue() == source_plaintext + assert source_stream.closed + + +@pytest.mark.parametrize('hint_bytes, expected_bytes_read', ( + (-1, 102400), + (0, 102400), + (1, io.DEFAULT_BUFFER_SIZE), + (io.DEFAULT_BUFFER_SIZE + 99, io.DEFAULT_BUFFER_SIZE * 2) +)) +def test_base64io_decode_readlines(hint_bytes, expected_bytes_read): + source_plaintext = os.urandom(102400) + source_stream = io.BytesIO(base64.b64encode(source_plaintext)) + + test = io.BytesIO() + with Base64IO(source_stream) as stream: + for chunk in stream.readlines(hint_bytes): + test.write(chunk) + + assert len(test.getvalue()) == expected_bytes_read + assert test.getvalue() == source_plaintext[:expected_bytes_read] + + +def test_base64io_encode_writelines(): + source_plaintext = [os.urandom(1024) for _ in range(100)] + b64_plaintext = base64.b64encode(b''.join(source_plaintext)) + + test = io.BytesIO() + with Base64IO(test) as encoder: + encoder.writelines(source_plaintext) + + assert test.getvalue() == b64_plaintext + + +def test_base64io_decode_file(tmpdir): + source_plaintext = os.urandom(1024 * 1024) + b64_plaintext = tmpdir.join('base64_plaintext') + b64_plaintext.write(base64.b64encode(source_plaintext)) + decoded_plaintext = tmpdir.join('decoded_plaintext') + + with 
open(str(b64_plaintext), 'rb') as source, open(str(decoded_plaintext), 'wb') as raw: + with Base64IO(source) as decoder: + for chunk in decoder: + raw.write(chunk) + + with open(str(decoded_plaintext), 'rb') as raw: + decoded = raw.read() + + assert decoded == source_plaintext + + +def test_base64io_encode_file(tmpdir): + source_plaintext = os.urandom(1024 * 1024) + plaintext_b64 = base64.b64encode(source_plaintext) + plaintext = tmpdir.join('plaintext') + b64_plaintext = tmpdir.join('base64_plaintext') + + with open(str(plaintext), 'wb') as file: + file.write(source_plaintext) + + with open(str(plaintext), 'rb') as source, open(str(b64_plaintext), 'wb') as target: + with Base64IO(target) as encoder: + for chunk in source: + encoder.write(chunk) + + with open(str(b64_plaintext), 'rb') as file2: + encoded = file2.read() + + assert encoded == plaintext_b64 diff --git a/test/unit/test_io_handling.py b/test/unit/test_io_handling.py index 64c7c81..73e249b 100644 --- a/test/unit/test_io_handling.py +++ b/test/unit/test_io_handling.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Unit test suite for ``aws_encryption_sdk_cli.internal.io_handling``.""" +import base64 import io import os import sys @@ -116,42 +117,79 @@ def test_ensure_dir_exists_current_directory(patch_makedirs): assert not patch_makedirs.called +@pytest.mark.parametrize('should_base64', (True, False)) +def test_encoder(mocker, should_base64): + mocker.patch.object(io_handling, 'Base64IO') + + test = io_handling._encoder(sentinel.stream, should_base64) + + if should_base64: + assert test is io_handling.Base64IO.return_value + else: + assert test is sentinel.stream + + def test_single_io_write_stream(tmpdir, patch_aws_encryption_sdk_stream): patch_aws_encryption_sdk_stream.return_value = io.BytesIO(DATA) target_file = tmpdir.join('target') + mock_source = MagicMock() with open(str(target_file), 'wb') as destination_writer: io_handling._single_io_write( stream_args={ 'a': sentinel.a, 'b': sentinel.b }, - source=sentinel.source, - destination_writer=destination_writer + source=mock_source, + destination_writer=destination_writer, + decode_input=False, + encode_output=False ) patch_aws_encryption_sdk_stream.assert_called_once_with( - source=sentinel.source, + source=mock_source.__enter__.return_value, a=sentinel.a, b=sentinel.b ) assert target_file.read('rb') == DATA +def test_single_io_write_stream_encode_output(tmpdir, patch_aws_encryption_sdk_stream): + patch_aws_encryption_sdk_stream.return_value = io.BytesIO(DATA) + target_file = tmpdir.join('target') + mock_source = MagicMock() + with open(str(target_file), 'wb') as destination_writer: + io_handling._single_io_write( + stream_args={ + 'a': sentinel.a, + 'b': sentinel.b + }, + source=mock_source, + destination_writer=destination_writer, + decode_input=False, + encode_output=True + ) + + assert target_file.read('rb') == base64.b64encode(DATA) + + def test_process_single_operation_stdout(patch_for_process_single_operation, patch_should_write_file): io_handling.process_single_operation( stream_args=sentinel.stream_args, source=sentinel.source, destination='-', interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) io_handling._single_io_write.assert_called_once_with( stream_args=sentinel.stream_args, source=sentinel.source, - 
destination_writer=io_handling._stdout.return_value + destination_writer=io_handling._stdout.return_value, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) assert not patch_should_write_file.called - io_handling._stdout.return_value.close.assert_called_once_with() def test_process_single_operation_stdin_stdout(patch_for_process_single_operation, patch_should_write_file): @@ -160,12 +198,16 @@ def test_process_single_operation_stdin_stdout(patch_for_process_single_operatio source='-', destination='-', interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) io_handling._single_io_write.assert_called_once_with( stream_args=sentinel.stream_args, source=io_handling._stdin.return_value, - destination_writer=io_handling._stdout.return_value + destination_writer=io_handling._stdout.return_value, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) @@ -176,7 +218,9 @@ def test_process_single_operation_file(patch_for_process_single_operation, patch source=sentinel.source, destination=sentinel.destination_file, interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) io_handling._ensure_dir_exists.assert_called_once_with(sentinel.destination_file) patch_should_write_file.assert_called_once_with( @@ -188,9 +232,10 @@ def test_process_single_operation_file(patch_for_process_single_operation, patch io_handling._single_io_write.assert_called_once_with( stream_args=sentinel.stream_args, source=sentinel.source, - destination_writer=mock_open.return_value + destination_writer=mock_open.return_value, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) - mock_open.return_value.close.assert_called_once_with() def test_process_single_operation_file_should_not_write(patch_for_process_single_operation, patch_should_write_file): @@ -201,7 +246,9 @@ def test_process_single_operation_file_should_not_write(patch_for_process_single source=sentinel.source, destination=sentinel.destination_file, interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) assert not io_handling._ensure_dir_exists.called assert not mock_open.called @@ -251,25 +298,58 @@ def test_should_write_file_does_exist(tmpdir, patch_input, interactive, no_overw assert not should_write -def test_process_single_file(tmpdir, patch_process_single_operation): - source = tmpdir.join('source') +@pytest.mark.parametrize('mode, decode_input, encode_output, expected_multiplier', ( + ('encrypt', False, False, 1.0), + ('encrypt', True, False, 0.75), + ('encrypt', False, True, 1.0), + ('encrypt', True, True, 1.0), + ('decrypt', False, False, 1.0), + ('decrypt', True, False, 0.75), + ('decrypt', False, True, 1.0), + ('decrypt', True, True, 1.0) +)) +def test_process_single_file( + tmpdir, + patch_process_single_operation, + mode, + decode_input, + encode_output, + expected_multiplier +): + source = tmpdir.join('source_file') source.write('some data') destination = tmpdir.join('destination') + initial_kwargs = dict( + mode=mode, + a=sentinel.a, + b=sentinel.b + ) + expected_length = int(os.path.getsize(str(source)) * expected_multiplier) + updated_kwargs = dict( + mode=mode, + a=sentinel.a, + 
b=sentinel.b, + source_length=expected_length + ) with patch('aws_encryption_sdk_cli.internal.io_handling.open', create=True) as mock_open: io_handling.process_single_file( - stream_args={'mode': 'encrypt'}, + stream_args=initial_kwargs, source=str(source), destination=str(destination), interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + decode_input=decode_input, + encode_output=encode_output ) mock_open.assert_called_once_with(str(source), 'rb') patch_process_single_operation.assert_called_once_with( - stream_args={'mode': 'encrypt'}, + stream_args=updated_kwargs, source=mock_open.return_value.__enter__.return_value, destination=str(destination), interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + decode_input=decode_input, + encode_output=encode_output ) @@ -283,7 +363,9 @@ def test_process_single_file_source_is_destination(tmpdir, patch_process_single_ source=str(source), destination=str(source), interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) assert not mock_open.called @@ -302,7 +384,9 @@ def test_process_single_file_destination_is_symlink_to_source(tmpdir, patch_proc source=str(source), destination=destination, interactive=sentinel.interactive, - no_overwrite=sentinel.no_overwrite + no_overwrite=sentinel.no_overwrite, + decode_input=sentinel.decode_input, + encode_output=sentinel.encode_output ) assert not mock_open.called @@ -409,7 +493,9 @@ def test_process_dir(tmpdir, patch_aws_encryption_sdk_stream): destination=str(target), interactive=False, no_overwrite=False, - suffix=None + suffix=None, + decode_input=False, + encode_output=False ) for filename, suffix in (