diff --git a/securesystemslib/util.py b/securesystemslib/util.py index bae7e76a..7af85c94 100755 --- a/securesystemslib/util.py +++ b/securesystemslib/util.py @@ -35,13 +35,15 @@ import securesystemslib.settings import securesystemslib.hash import securesystemslib.formats +import securesystemslib.storage import six logger = logging.getLogger(__name__) -def get_file_details(filepath, hash_algorithms=['sha256']): +def get_file_details(filepath, hash_algorithms=['sha256'], + storage_backend=None): """ To get file's length and hash information. The hash is computed using the @@ -53,6 +55,13 @@ def get_file_details(filepath, hash_algorithms=['sha256']): Absolute file path of a file. hash_algorithms: + A list of hash algorithms with which the file's hash should be computed. + Defaults to ['sha256'] + + storage_backend: + An object which implements + securesystemslib.storage.StorageBackendInterface. When no object is + passed a FilesystemBackend will be instantiated and used. securesystemslib.exceptions.FormatError: If hash of the file does not match @@ -69,23 +78,22 @@ def get_file_details(filepath, hash_algorithms=['sha256']): securesystemslib.formats.PATH_SCHEMA.check_match(filepath) securesystemslib.formats.HASHALGORITHMS_SCHEMA.check_match(hash_algorithms) + if storage_backend is None: + storage_backend = securesystemslib.storage.FilesystemBackend() + # The returned file hashes of 'filepath'. file_hashes = {} - # Does the path exists? - if not os.path.exists(filepath): - raise securesystemslib.exceptions.Error('Path ' + repr(filepath) + ' doest' - ' not exist.') - filepath = os.path.abspath(filepath) # Obtaining length of the file. - file_length = os.path.getsize(filepath) + file_length = storage_backend.getsize(filepath) - # Obtaining hash of the file. - for algorithm in hash_algorithms: - digest_object = securesystemslib.hash.digest_filename(filepath, algorithm) - file_hashes.update({algorithm: digest_object.hexdigest()}) + with storage_backend.get(filepath) as fileobj: + # Obtaining hash of the file. + for algorithm in hash_algorithms: + digest_object = securesystemslib.hash.digest_fileobject(fileobj, algorithm) + file_hashes.update({algorithm: digest_object.hexdigest()}) # Performing a format check to ensure 'file_hash' corresponds HASHDICT_SCHEMA. # Raise 'securesystemslib.exceptions.FormatError' if there is a mismatch. @@ -94,11 +102,12 @@ def get_file_details(filepath, hash_algorithms=['sha256']): return file_length, file_hashes -def persist_temp_file(temp_file, persist_path): +def persist_temp_file(temp_file, persist_path, storage_backend=None, + should_close=True): """ Copies 'temp_file' (a file like object) to a newly created non-temp file at - 'persist_path' and closes 'temp_file' so that it is removed. + 'persist_path'. temp_file: @@ -108,6 +117,15 @@ def persist_temp_file(temp_file, persist_path): persist_path: File path to create the persistent file in. + storage_backend: + An object which implements + securesystemslib.storage.StorageBackendInterface. When no object is + passed a FilesystemBackend will be instantiated and used. + + should_close: + A boolean indicating whether the file should be closed after it has been + persisted. Default is True, the file is closed. + None. @@ -115,19 +133,16 @@ def persist_temp_file(temp_file, persist_path): None. """ - temp_file.flush() - temp_file.seek(0) + if storage_backend is None: + storage_backend = securesystemslib.storage.FilesystemBackend() + + storage_backend.put(temp_file, persist_path) - with open(persist_path, 'wb') as destination_file: - shutil.copyfileobj(temp_file, destination_file) - # Force the destination file to be written to disk from Python's internal - # and the operation system's buffers. os.fsync() should follow flush(). - destination_file.flush() - os.fsync(destination_file.fileno()) + if should_close: + temp_file.close() - temp_file.close() -def ensure_parent_dir(filename): +def ensure_parent_dir(filename, storage_backend=None): """ To ensure existence of the parent directory of 'filename'. If the parent @@ -140,6 +155,11 @@ def ensure_parent_dir(filename): filename: A path string. + storage_backend: + An object which implements + securesystemslib.storage.StorageBackendInterface. When no object is + passed a FilesystemBackend will be instantiated and used. + securesystemslib.exceptions.FormatError: If 'filename' is improperly formatted. @@ -156,12 +176,13 @@ def ensure_parent_dir(filename): # Raise 'securesystemslib.exceptions.FormatError' on a mismatch. securesystemslib.formats.PATH_SCHEMA.check_match(filename) + if storage_backend is None: + storage_backend = securesystemslib.storage.FilesystemBackend() + # Split 'filename' into head and tail, check if head exists. directory = os.path.split(filename)[0] - if directory and not os.path.exists(directory): - # mode = 'rwx------'. 448 (decimal) is 700 in octal. - os.makedirs(directory, 448) + storage_backend.create_folder(directory) def file_in_confined_directories(filepath, confined_directories): @@ -296,7 +317,7 @@ def load_json_string(data): return deserialized_object -def load_json_file(filepath): +def load_json_file(filepath, storage_backend=None): """ Deserialize a JSON object from a file containing the object. @@ -305,6 +326,11 @@ def load_json_file(filepath): filepath: Absolute path of JSON file. + storage_backend: + An object which implements + securesystemslib.storage.StorageBackendInterface. When no object is + passed a FilesystemBackend will be instantiated and used. + securesystemslib.exceptions.FormatError: If 'filepath' is improperly formatted. @@ -325,21 +351,22 @@ def load_json_file(filepath): # securesystemslib.exceptions.FormatError is raised on incorrect format. securesystemslib.formats.PATH_SCHEMA.check_match(filepath) - deserialized_object = None - fileobject = open(filepath) + if storage_backend is None: + storage_backend = securesystemslib.storage.FilesystemBackend() - try: - deserialized_object = json.load(fileobject) + deserialized_object = None + with storage_backend.get(filepath) as file_obj: + raw_data = file_obj.read().decode('utf-8') - except (ValueError, TypeError) as e: - raise securesystemslib.exceptions.Error('Cannot deserialize to a' - ' Python object: ' + repr(filepath)) + try: + deserialized_object = json.loads(raw_data) - else: - return deserialized_object + except (ValueError, TypeError) as e: + raise securesystemslib.exceptions.Error('Cannot deserialize to a' + ' Python object: ' + filepath) - finally: - fileobject.close() + else: + return deserialized_object def digests_are_equal(digest1, digest2): diff --git a/tests/test_interface.py b/tests/test_interface.py index b439413d..1a37213d 100755 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -42,6 +42,7 @@ else: import mock +import securesystemslib.exceptions import securesystemslib.formats import securesystemslib.hash import securesystemslib.interface as interface @@ -373,8 +374,8 @@ def test_import_ed25519_publickey_from_file(self): # Non-existent key file. nonexistent_keypath = os.path.join(temporary_directory, 'nonexistent_keypath') - self.assertRaises(IOError, interface.import_ed25519_publickey_from_file, - nonexistent_keypath) + self.assertRaises(securesystemslib.exceptions.StorageError, + interface.import_ed25519_publickey_from_file, nonexistent_keypath) # Invalid key file argument. invalid_keyfile = os.path.join(temporary_directory, 'invalid_keyfile') @@ -525,8 +526,8 @@ def test_import_ecdsa_publickey_from_file(self): # Non-existent key file. nonexistent_keypath = os.path.join(temporary_directory, 'nonexistent_keypath') - self.assertRaises(IOError, interface.import_ecdsa_publickey_from_file, - nonexistent_keypath) + self.assertRaises(securesystemslib.exceptions.StorageError, + interface.import_ecdsa_publickey_from_file, nonexistent_keypath) # Invalid key file argument. invalid_keyfile = os.path.join(temporary_directory, 'invalid_keyfile') diff --git a/tests/test_util.py b/tests/test_util.py index 6a108493..1a1ea60a 100755 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -171,7 +171,7 @@ def test_B6_load_json_file(self): securesystemslib.util.load_json_file, bogus_arg) # Non-existent path. - self.assertRaises(IOError, + self.assertRaises(securesystemslib.exceptions.StorageError, securesystemslib.util.load_json_file, 'non-existent.json') # Invalid JSON content. @@ -188,11 +188,23 @@ def test_B6_load_json_file(self): def test_B7_persist_temp_file(self): # Destination directory to save the temporary file in. dest_temp_dir = self.make_temp_directory() + + # Test the default of persisting the file and closing the tmpfile dest_path = os.path.join(dest_temp_dir, self.random_string()) tmpfile = tempfile.TemporaryFile() tmpfile.write(self.random_string().encode('utf-8')) securesystemslib.util.persist_temp_file(tmpfile, dest_path) self.assertTrue(dest_path) + self.assertTrue(tmpfile.closed) + + # Test persisting a file without automatically closing the tmpfile + dest_path2 = os.path.join(dest_temp_dir, self.random_string()) + tmpfile = tempfile.TemporaryFile() + tmpfile.write(self.random_string().encode('utf-8')) + securesystemslib.util.persist_temp_file(tmpfile, dest_path2, + should_close=False) + self.assertFalse(tmpfile.closed) + tmpfile.close()