diff --git a/cpp/src/parquet/column/CMakeLists.txt b/cpp/src/parquet/column/CMakeLists.txt index 423f54498edc8..32ec11c386eb7 100644 --- a/cpp/src/parquet/column/CMakeLists.txt +++ b/cpp/src/parquet/column/CMakeLists.txt @@ -26,3 +26,4 @@ install(FILES ADD_PARQUET_TEST(column-reader-test) ADD_PARQUET_TEST(levels-test) +ADD_PARQUET_TEST(serialized-page-test) diff --git a/cpp/src/parquet/column/serialized-page-test.cc b/cpp/src/parquet/column/serialized-page-test.cc new file mode 100644 index 0000000000000..5c49021842058 --- /dev/null +++ b/cpp/src/parquet/column/serialized-page-test.cc @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include + +#include + +#include "parquet/types.h" +#include "parquet/thrift/parquet_types.h" +#include "parquet/thrift/util.h" +#include "parquet/column/serialized-page.h" +#include "parquet/column/page.h" +#include "parquet/column/reader.h" +#include "parquet/column/test-util.h" + + +namespace parquet_cpp { + +class TestSerializedPage : public ::testing::Test { + public: + void InitSerializedPageReader(const uint8_t* buffer, size_t header_size, + parquet::CompressionCodec::type codec) { + std::unique_ptr stream; + stream.reset(new InMemoryInputStream(buffer, header_size)); + page_reader_.reset(new SerializedPageReader(std::move(stream), codec)); + } + + protected: + std::unique_ptr page_reader_; +}; + +TEST_F(TestSerializedPage, TestLargePageHeaders) { + parquet::PageHeader in_page_header; + parquet::DataPageHeader data_page_header; + parquet::PageHeader out_page_header; + parquet::Statistics stats; + int expected_header_size = 512 * 1024; //512 KB + int stats_size = 256 * 1024; // 256 KB + std::string serialized_buffer; + int num_values = 4141; + + InitStats(stats_size, stats); + InitDataPage(stats, data_page_header, num_values); + InitPageHeader(data_page_header, in_page_header); + + // Serialize the Page header + ASSERT_NO_THROW(serialized_buffer = SerializeThriftMsg(&in_page_header, + expected_header_size)); + // check header size is between 256 KB to 16 MB + ASSERT_LE(stats_size, serialized_buffer.length()); + ASSERT_GE(DEFAULT_MAX_PAGE_HEADER_SIZE, serialized_buffer.length()); + + InitSerializedPageReader(reinterpret_cast(serialized_buffer.c_str()), + serialized_buffer.length(), parquet::CompressionCodec::UNCOMPRESSED); + + std::shared_ptr current_page = page_reader_->NextPage(); + ASSERT_EQ(parquet::PageType::DATA_PAGE, current_page->type()); + const DataPage* page = static_cast(current_page.get()); + ASSERT_EQ(num_values, page->num_values()); +} + +TEST_F(TestSerializedPage, TestFailLargePageHeaders) { + parquet::PageHeader in_page_header; + parquet::DataPageHeader data_page_header; + parquet::PageHeader out_page_header; + parquet::Statistics stats; + int expected_header_size = 512 * 1024; // 512 KB + int stats_size = 256 * 1024; // 256 KB + int max_header_size = 128 * 1024; // 128 KB + int num_values = 4141; + std::string serialized_buffer; + + InitStats(stats_size, stats); + InitDataPage(stats, data_page_header, num_values); + InitPageHeader(data_page_header, in_page_header); + + // Serialize the Page header + ASSERT_NO_THROW(serialized_buffer = SerializeThriftMsg(&in_page_header, + expected_header_size)); + // check header size is between 256 KB to 16 MB + ASSERT_LE(stats_size, serialized_buffer.length()); + ASSERT_GE(DEFAULT_MAX_PAGE_HEADER_SIZE, serialized_buffer.length()); + + InitSerializedPageReader(reinterpret_cast(serialized_buffer.c_str()), + serialized_buffer.length(), parquet::CompressionCodec::UNCOMPRESSED); + + // Set the max page header size to 128 KB, which is less than the current header size + page_reader_->set_max_page_header_size(max_header_size); + + ASSERT_THROW(page_reader_->NextPage(), ParquetException); +} +} // namespace parquet_cpp diff --git a/cpp/src/parquet/column/serialized-page.cc b/cpp/src/parquet/column/serialized-page.cc index b9d470c07c147..56b73a70b86b8 100644 --- a/cpp/src/parquet/column/serialized-page.cc +++ b/cpp/src/parquet/column/serialized-page.cc @@ -33,6 +33,7 @@ namespace parquet_cpp { SerializedPageReader::SerializedPageReader(std::unique_ptr stream, parquet::CompressionCodec::type codec) : stream_(std::move(stream)) { + max_page_header_size_ = DEFAULT_MAX_PAGE_HEADER_SIZE; switch (codec) { case parquet::CompressionCodec::UNCOMPRESSED: break; @@ -44,23 +45,42 @@ SerializedPageReader::SerializedPageReader(std::unique_ptr stream, } } -// TODO(wesm): this may differ from file to file -static constexpr int DATA_PAGE_SIZE = 64 * 1024; std::shared_ptr SerializedPageReader::NextPage() { // Loop here because there may be unhandled page types that we skip until // finding a page that we do know what to do with while (true) { int64_t bytes_read = 0; - const uint8_t* buffer = stream_->Peek(DATA_PAGE_SIZE, &bytes_read); - if (bytes_read == 0) { - return std::shared_ptr(nullptr); - } - - // This gets used, then set by DeserializeThriftMsg - uint32_t header_size = bytes_read; - DeserializeThriftMsg(buffer, &header_size, ¤t_page_header_); + int64_t bytes_available = 0; + uint32_t header_size = 0; + const uint8_t* buffer; + uint32_t allowed_page_size = DEFAULT_PAGE_HEADER_SIZE; + std::stringstream ss; + + // Page headers can be very large because of page statistics + // We try to deserialize a larger buffer progressively + // until a maximum allowed header limit + while (true) { + buffer = stream_->Peek(allowed_page_size, &bytes_available); + if (bytes_available == 0) { + return std::shared_ptr(nullptr); + } + // This gets used, then set by DeserializeThriftMsg + header_size = bytes_available; + try { + DeserializeThriftMsg(buffer, &header_size, ¤t_page_header_); + break; + } catch (std::exception& e) { + // Failed to deserialize. Double the allowed page header size and try again + ss << e.what(); + allowed_page_size *= 2; + if (allowed_page_size > max_page_header_size_) { + ss << "Deserializing page header failed.\n"; + throw ParquetException(ss.str()); + } + } + } // Advance the stream offset stream_->Read(header_size, &bytes_read); diff --git a/cpp/src/parquet/column/serialized-page.h b/cpp/src/parquet/column/serialized-page.h index c02152ffcc335..62bf66df9a697 100644 --- a/cpp/src/parquet/column/serialized-page.h +++ b/cpp/src/parquet/column/serialized-page.h @@ -32,6 +32,10 @@ namespace parquet_cpp { +// 16 MB is the default maximum page header size +static constexpr uint32_t DEFAULT_MAX_PAGE_HEADER_SIZE = 16 * 1024 * 1024; +// 16 KB is the default expected page header size +static constexpr uint32_t DEFAULT_PAGE_HEADER_SIZE = 16 * 1024; // This subclass delimits pages appearing in a serialized stream, each preceded // by a serialized Thrift parquet::PageHeader indicating the type of each page // and the page metadata. @@ -45,6 +49,10 @@ class SerializedPageReader : public PageReader { // Implement the PageReader interface virtual std::shared_ptr NextPage(); + void set_max_page_header_size(uint32_t size) { + max_page_header_size_ = size; + } + private: std::unique_ptr stream_; @@ -54,6 +62,8 @@ class SerializedPageReader : public PageReader { // Compression codec to use. std::unique_ptr decompressor_; std::vector decompression_buffer_; + // Maximum allowed page size + uint32_t max_page_header_size_; }; } // namespace parquet_cpp diff --git a/cpp/src/parquet/column/test-util.h b/cpp/src/parquet/column/test-util.h index 1cbcf8c9bb62c..90dde3bfc8425 100644 --- a/cpp/src/parquet/column/test-util.h +++ b/cpp/src/parquet/column/test-util.h @@ -25,9 +25,9 @@ #include #include #include +#include #include "parquet/column/page.h" - namespace parquet_cpp { namespace test { @@ -174,9 +174,33 @@ static std::shared_ptr MakeDataPage(const std::vector& values, return std::make_shared(&(*out_buffer)[0], out_buffer->size(), page_header); } +} // namespace test +static inline void InitDataPage(const parquet::Statistics& stat, + parquet::DataPageHeader& data_page, int nvalues) { + data_page.encoding = parquet::Encoding::PLAIN; + data_page.definition_level_encoding = parquet::Encoding::RLE; + data_page.repetition_level_encoding = parquet::Encoding::RLE; + data_page.num_values = nvalues; + data_page.__set_statistics(stat); +} -} // namespace test +static inline void InitStats(size_t stat_size, parquet::Statistics& stat) { + std::vector stat_buffer; + stat_buffer.resize(stat_size); + for (int i = 0; i < stat_size; i++) { + (reinterpret_cast(stat_buffer.data()))[i] = i % 255; + } + stat.__set_max(std::string(stat_buffer.data(), stat_size)); +} + +static inline void InitPageHeader(const parquet::DataPageHeader &data_page, + parquet::PageHeader& page_header) { + page_header.__set_data_page_header(data_page); + page_header.uncompressed_page_size = 0; + page_header.compressed_page_size = 0; + page_header.type = parquet::PageType::DATA_PAGE; +} } // namespace parquet_cpp diff --git a/cpp/src/parquet/thrift/CMakeLists.txt b/cpp/src/parquet/thrift/CMakeLists.txt index 384bc19295196..30150ca18fc6f 100644 --- a/cpp/src/parquet/thrift/CMakeLists.txt +++ b/cpp/src/parquet/thrift/CMakeLists.txt @@ -30,3 +30,5 @@ install(FILES parquet_constants.h util.h DESTINATION include/parquet/thrift) + +ADD_PARQUET_TEST(serializer-test) diff --git a/cpp/src/parquet/thrift/serializer-test.cc b/cpp/src/parquet/thrift/serializer-test.cc new file mode 100644 index 0000000000000..e89b1080db0e7 --- /dev/null +++ b/cpp/src/parquet/thrift/serializer-test.cc @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include + +#include + +#include "parquet/thrift/parquet_types.h" +#include "parquet/thrift/util.h" +#include "parquet/column/page.h" +#include "parquet/column/reader.h" +#include "parquet/column/test-util.h" + +using std::string; + +namespace parquet_cpp { + +class TestThrift : public ::testing::Test { + +}; + +TEST_F(TestThrift, TestSerializerDeserializer) { + parquet::PageHeader in_page_header; + parquet::DataPageHeader data_page_header; + parquet::PageHeader out_page_header; + parquet::Statistics stats; + uint32_t max_header_len = 1024; + uint32_t expected_header_size = 1024; + uint32_t stats_size = 512; + std::string serialized_buffer; + int num_values = 4444; + + InitStats(stats_size, stats); + InitDataPage(stats, data_page_header, num_values); + InitPageHeader(data_page_header, in_page_header); + + // Serialize the Page header + ASSERT_NO_THROW(serialized_buffer = SerializeThriftMsg(&in_page_header, expected_header_size)); + ASSERT_LE(stats_size, serialized_buffer.length()); + ASSERT_GE(max_header_len, serialized_buffer.length()); + + uint32_t header_size = 1024; + // Deserialize the serialized page buffer + ASSERT_NO_THROW(DeserializeThriftMsg(reinterpret_cast(serialized_buffer.c_str()), + &header_size, &out_page_header)); + ASSERT_LE(stats_size, header_size); + ASSERT_GE(max_header_len, header_size); + + ASSERT_EQ(parquet::Encoding::PLAIN, out_page_header.data_page_header.encoding); + ASSERT_EQ(parquet::Encoding::RLE, out_page_header.data_page_header.definition_level_encoding); + ASSERT_EQ(parquet::Encoding::RLE, out_page_header.data_page_header.repetition_level_encoding); + for(int i = 0; i < stats_size; i++){ + EXPECT_EQ(i % 255, (reinterpret_cast + (out_page_header.data_page_header.statistics.max.c_str()))[i]); + } + ASSERT_EQ(parquet::PageType::DATA_PAGE, out_page_header.type); + ASSERT_EQ(num_values, out_page_header.data_page_header.num_values); + +} + +} // namespace parquet_cpp diff --git a/cpp/src/parquet/thrift/util.h b/cpp/src/parquet/thrift/util.h index ecf24c65cd00f..a472dc27342e2 100644 --- a/cpp/src/parquet/thrift/util.h +++ b/cpp/src/parquet/thrift/util.h @@ -15,7 +15,9 @@ #include #include +#include +#include "parquet/util/logging.h" #include "parquet/exception.h" namespace parquet_cpp { @@ -34,13 +36,37 @@ inline void DeserializeThriftMsg(const uint8_t* buf, uint32_t* len, T* deseriali tproto_factory.getProtocol(tmem_transport); try { deserialized_msg->read(tproto.get()); - } catch (apache::thrift::protocol::TProtocolException& e) { - throw ParquetException("Couldn't deserialize thrift.", e); + } catch (std::exception& e) { + std::stringstream ss; + ss << "Couldn't deserialize thrift: " << e.what() << "\n"; + throw ParquetException(ss.str()); } uint32_t bytes_left = tmem_transport->available_read(); *len = *len - bytes_left; } +// Serialize obj into a buffer. The result is returned as a string. +// The arguments are the object to be serialized and +// the expected size of the serialized object +template +inline std::string SerializeThriftMsg(T* obj, uint32_t len) { + boost::shared_ptr mem_buffer( + new apache::thrift::transport::TMemoryBuffer(len)); + apache::thrift::protocol::TCompactProtocolFactoryT< + apache::thrift::transport::TMemoryBuffer> tproto_factory; + boost::shared_ptr tproto = + tproto_factory.getProtocol(mem_buffer); + try { + mem_buffer->resetBuffer(); + obj->write(tproto.get()); + } catch (std::exception& e) { + std::stringstream ss; + ss << "Couldn't serialize thrift: " << e.what() << "\n"; + throw ParquetException(ss.str()); + } + return mem_buffer->getBufferAsString(); +} + } // namespace parquet_cpp #endif // PARQUET_THRIFT_UTIL_H