From 86493885d514248c1f8ea15f368c01a388ce10f9 Mon Sep 17 00:00:00 2001 From: Hang Zheng Date: Fri, 30 Aug 2024 20:24:33 +0800 Subject: [PATCH] refine according to review comments --- cpp/src/arrow/filesystem/filesystem_test.cc | 14 ++++++++------ cpp/src/arrow/filesystem/s3fs.cc | 8 ++++---- cpp/src/arrow/filesystem/util_internal.cc | 12 +++++------- cpp/src/arrow/filesystem/util_internal.h | 5 ++--- 4 files changed, 19 insertions(+), 20 deletions(-) diff --git a/cpp/src/arrow/filesystem/filesystem_test.cc b/cpp/src/arrow/filesystem/filesystem_test.cc index df21d88e41c6d..afa2d796c75f8 100644 --- a/cpp/src/arrow/filesystem/filesystem_test.cc +++ b/cpp/src/arrow/filesystem/filesystem_test.cc @@ -55,12 +55,14 @@ TEST(FileInfo, BaseName) { } TEST(CalculateSSECKeyMD5, Sanity) { - std::string lResult; - ASSERT_FALSE(CalculateSSECKeyMD5("", lResult)); // invalid base64 - ASSERT_FALSE(CalculateSSECKeyMD5("%^H", lResult)); // invalid base64 - ASSERT_FALSE(CalculateSSECKeyMD5("INVALID", lResult)); // invalid base64 - ASSERT_FALSE(CalculateSSECKeyMD5("MTIzNDU2Nzg5", lResult)); // invalid, the input key size not match - ASSERT_TRUE(CalculateSSECKeyMD5("1WH9aTJ0+Tn0NLbTMHZn9aCW3Li3ViAdBsoIldPCREw=", lResult)); // valid case + ASSERT_FALSE(CalculateSSECKeyMD5("").ok()); // invalid base64 + ASSERT_FALSE(CalculateSSECKeyMD5("%^H").ok()); // invalid base64 + ASSERT_FALSE(CalculateSSECKeyMD5("INVALID").ok()); // invalid base64 + ASSERT_FALSE(CalculateSSECKeyMD5("MTIzNDU2Nzg5").ok()); // invalid, the input key size not match + // valid case + auto result = CalculateSSECKeyMD5("1WH9aTJ0+Tn0NLbTMHZn9aCW3Li3ViAdBsoIldPCREw="); + ASSERT_TRUE(result.ok()); // valid case + ASSERT_STREQ(result->c_str(), "3HYIM58NCLwrIOdPpWnYwQ=="); // valid case } TEST(PathUtil, SplitAbstractPath) { diff --git a/cpp/src/arrow/filesystem/s3fs.cc b/cpp/src/arrow/filesystem/s3fs.cc index 65f64d2dc91f4..036b9c7b3ea3a 100644 --- a/cpp/src/arrow/filesystem/s3fs.cc +++ b/cpp/src/arrow/filesystem/s3fs.cc @@ -454,14 +454,14 @@ Status SetSSECustomerKey(S3RequestType& request, if (sse_customer_key.empty()) { return Status::OK(); // do nothing if the sse_customer_key is not configured } - std::string sse_customer_key_md5; - if (internal::CalculateSSECKeyMD5(sse_customer_key, sse_customer_key_md5)) { - request.SetSSECustomerKeyMD5(sse_customer_key_md5); + auto result = internal::CalculateSSECKeyMD5(sse_customer_key); + if (result.ok()) { + request.SetSSECustomerKeyMD5(*result); request.SetSSECustomerKey(sse_customer_key); request.SetSSECustomerAlgorithm("AES256"); return Status::OK(); } else { - return Status::Invalid("sse_customer_key is not a vaild 256-bit base64-encoded encryption key"); + return result.status(); } } diff --git a/cpp/src/arrow/filesystem/util_internal.cc b/cpp/src/arrow/filesystem/util_internal.cc index 010be0074d97a..043cfba7300d9 100644 --- a/cpp/src/arrow/filesystem/util_internal.cc +++ b/cpp/src/arrow/filesystem/util_internal.cc @@ -263,15 +263,15 @@ Result GlobFiles(const std::shared_ptr& filesystem, return out; } -bool CalculateSSECKeyMD5(const std::string& base64_encoded_key, std::string& md5_result, +Result CalculateSSECKeyMD5(const std::string& base64_encoded_key, int expect_input_key_size) { if (base64_encoded_key.size() < 2) { - return false; + return Status::Invalid("At least 2 bytes needed for the base64 encoded string"); } // Check if the string contains only valid Base64 characters for (char c : base64_encoded_key) { if (!std::isalnum(c) && c != '+' && c != '/' && c != '=') { - return false; + return Status::Invalid("Invalid character found in the base64 encoded string"); } } @@ -282,7 +282,7 @@ bool CalculateSSECKeyMD5(const std::string& base64_encoded_key, std::string& md5 // the key needs to be // 256 bits(32 bytes)according to // https://docs.aws.amazon.com/AmazonS3/latest/userguide/ServerSideEncryptionCustomerKeys.html#specifying-s3-c-encryption if (rawKey.GetLength() != expect_input_key_size) { - return false; + return Status::Invalid("Invalid Length for the key"); } // Convert the raw binary key to an Aws::String @@ -295,9 +295,7 @@ bool CalculateSSECKeyMD5(const std::string& base64_encoded_key, std::string& md5 // Base64-encode the MD5 hash Aws::String awsEncodedHash = Aws::Utils::HashingUtils::Base64Encode(md5Hash); - // Return the Base64-encoded MD5 hash as a std::string - md5_result = std::string(awsEncodedHash.begin(), awsEncodedHash.end()); - return true; + return std::string(awsEncodedHash.begin(), awsEncodedHash.end()); } diff --git a/cpp/src/arrow/filesystem/util_internal.h b/cpp/src/arrow/filesystem/util_internal.h index a9eef51607593..dc1b0668c2644 100644 --- a/cpp/src/arrow/filesystem/util_internal.h +++ b/cpp/src/arrow/filesystem/util_internal.h @@ -101,11 +101,10 @@ Result GlobFiles(const std::shared_ptr& filesystem, /// \brief Decode the Input SSE key,calculate the MD5 /// \param base64_encoded_key is the input base64 encoded sse key -/// \param md5_result, output resut /// \param expect_input_key_size, default 32 -/// \return true if the decode and calculate MD5 success, otherwise return false +/// \return the base64 encoded MD5 for the input key ARROW_EXPORT -bool CalculateSSECKeyMD5(const std::string& base64_encoded_key, std::string& md5_result, +Result CalculateSSECKeyMD5(const std::string& base64_encoded_key, int expect_input_key_size = 32);