diff --git a/src/nanoarrow/nanoarrow_testing.hpp b/src/nanoarrow/nanoarrow_testing.hpp index d59bea0ac..e1b562b76 100644 --- a/src/nanoarrow/nanoarrow_testing.hpp +++ b/src/nanoarrow/nanoarrow_testing.hpp @@ -2916,6 +2916,140 @@ class TestingJSONComparison { /// @} +namespace dsl { +/// \defgroup nanoarrow_testing-schema Schema factory DSL +/// +/// @{ + +/// \brief An alias to express a sequence of key value pairs. +/// +/// Each pair is a string formatted like "key=value". +using metadata = std::vector; + +/// \brief A wrapper around UniqueSchema with easy constructors +class schema; + +/// \brief Alias schema to readably tag a vector of schemas +using children = std::vector; + +/// \brief A wrapper around schema to readably tag a dictionary +struct dictionary : UniqueSchema { + template + explicit dictionary(Args... args); +}; + +class schema { + public: + friend struct dictionary; + + schema(schema const& other) : unique_schema{std::move(other.unique_schema)} {} + schema(schema&& other) = default; + + schema(UniqueSchema unique_schema) : unique_schema{std::move(unique_schema)} {} + + template + schema(Args... args) { + ArrowSchemaInit(get()); + get()->private_data = new Private; + get()->release = [](struct ArrowSchema* schema) { + delete static_cast(schema->private_data); + }; + set(std::move(args)...); + } + + private: + mutable UniqueSchema unique_schema; + + struct Private { + std::string format; + std::string name; + std::string metadata; + + std::vector schemas; + std::vector children; + UniqueSchema dictionary; + }; + + Private* get_private() { return static_cast(get()->private_data); } + ArrowSchema* get() { return unique_schema.get(); } + + void set() { + // set string pointers in the actual ArrowSchema + // if format was not set, then set it to struct now + get()->format = get_private()->format.empty() ? "+s" : get_private()->format.c_str(); + get()->name = get_private()->name.c_str(); + get()->metadata = get_private()->metadata.c_str(); + } + + template + void set(std::string format_or_name, Args... args) { + (get_private()->format.empty() // assign to format if it's empty + ? get_private()->format + : get_private()->name) = std::move(format_or_name); + } + + template + void set(int64_t flags, Args... args) { + get()->flags = flags; + set(std::move(args)...); + } + + template + void set(dictionary d, Args... args) { + get_private()->dictionary = std::move(d); + set(std::move(args)...); + } + + template + void set(children c, Args... args) { + get()->n_children = c.size(); + get_private()->children.resize(c.size()); + get()->children = get_private()->children.data(); + get_private()->schemas = std::move(c); + + int i = 0; + for (auto& child : get_private()->children) { + child = get_private()->schemas[i++].get(); + } + set(std::move(args)...); + } + + template + void set(metadata m, Args... args) { + std::string metadata_buf; + auto append_int32 = [&](size_t i) { + auto i32 = static_cast(i); + char chars[sizeof(int32_t)]; + memcpy(&chars, &i32, sizeof(int32_t)); + metadata_buf.append(chars, sizeof(int32_t)); + }; + append_int32(m.size()); + for (const std::string& kv : m) { + size_t key_size = kv.find_first_of('='); + if (key_size == std::string::npos) { + continue; + } + append_int32(key_size); + metadata_buf.append(kv.data(), key_size); + + size_t value_size = kv.size() - key_size - 1; + append_int32(value_size); + metadata_buf.append(kv.data() + key_size + 1, value_size); + } + get_private()->metadata = std::move(metadata_buf); + get()->metadata = get_private()->metadata.c_str(); + + set(std::move(args)...); + } +}; + +template +dictionary::dictionary(Args... args) + : UniqueSchema{std::move(schema{std::move(args)...}.unique_schema)} {} + +/// @} +} // namespace dsl + } // namespace testing } // namespace nanoarrow diff --git a/src/nanoarrow/testing/testing_test.cc b/src/nanoarrow/testing/testing_test.cc index 90abef8a1..fe0594300 100644 --- a/src/nanoarrow/testing/testing_test.cc +++ b/src/nanoarrow/testing/testing_test.cc @@ -1908,3 +1908,19 @@ TEST(NanoarrowTestingTest, NanoarrowTestingTestArrayStreamComparison) { )"); } + +TEST(SchemaDsl, Basic) { + using namespace nanoarrow::testing::dsl; + + schema{children{ + {"i", "int32 field name", + metadata{ + "some_key=some_value", + }}, + {"i", dictionary{"u"}, "dictionary field name", + metadata{ + "some_key=some_value", + }, + ARROW_FLAG_NULLABLE}, + }}; +}