From 305a795d5fe30c95ed308b3846fe4aacfc1c9a59 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Thu, 15 Feb 2024 14:41:24 -0500 Subject: [PATCH 1/2] feat: read_span_flatbuffer() memory use and correctness fixes (#4684) * read_span_flatbuffer: support cleaning up spare examples correctly Consumers of VW as a library can provide their own event pools, etc. Previous parsers were always able to predict when an event would be needed ahead of time, so would only allocate when necessary. This was done by relying on a single incoming event preallocation to let the external host deallocate in the case of nothing to be parsed. This does not work for the FB parser due to how it handles re-entrancy, and we do not want to spend the time re-architecting it to avoid this. The fix, in this case, is to expand the API to include a callback to return spare events back to the host's event pool. * support header for using VW and RL api_status.h side-by-side * Add additional error reporting to read_span_flatbuffer * Reset parser when re-entering after bad parse When re-using the flatbuffer parser across multiple invocations, the parser state could become invalid (retain references to deleted objects). * Add more tests for bad inputs * Add comments about what is going on in read_span_flatbuffer * Fix a place where the parser was returning the semantically incorrect error code * Remove dead code --- .../core/include/vw/core/error_data.h | 3 + .../vw/fb_parser/parse_example_flatbuffer.h | 24 ++- .../fb_parser/src/parse_example_flatbuffer.cc | 130 ++++++++------ vowpalwabbit/fb_parser/src/parse_label.cc | 1 + .../fb_parser/tests/example_data_generator.h | 88 ++++++++++ .../tests/flatbuffer_parser_tests.cc | 3 +- .../fb_parser/tests/prototype_typemappings.h | 6 + .../fb_parser/tests/read_span_tests.cc | 165 +++++++++++++++++- 8 files changed, 361 insertions(+), 59 deletions(-) diff --git a/vowpalwabbit/core/include/vw/core/error_data.h b/vowpalwabbit/core/include/vw/core/error_data.h index 
aea7e246b32..18ec7fb92bc 100644 --- a/vowpalwabbit/core/include/vw/core/error_data.h +++ b/vowpalwabbit/core/include/vw/core/error_data.h @@ -25,6 +25,9 @@ ERROR_CODE_DEFINITION( ERROR_CODE_DEFINITION( 13, fb_parser_size_mismatch_ft_names_ft_values, "Size of feature names and feature values do not match. ") ERROR_CODE_DEFINITION(14, unknown_label_type, "Label type in Flatbuffer not understood. ") +ERROR_CODE_DEFINITION(15, fb_parser_span_misaligned, "Input Flatbuffer span is not aligned to an 8-byte boundary. ") +ERROR_CODE_DEFINITION( + 16, fb_parser_span_length_mismatch, "Input Flatbuffer span does not match flatbuffer size prefix. ") // TODO: This is temporary until we switch to the new error handling mechanism. ERROR_CODE_DEFINITION(10000, vw_exception, "vw_exception: ") diff --git a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h index fa181c1ea46..7c5be6cd480 100644 --- a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h +++ b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h @@ -4,7 +4,6 @@ #pragma once -#include "vw/core/api_status.h" #include "vw/core/example.h" #include "vw/core/multi_ex.h" #include "vw/core/shared_data.h" @@ -14,15 +13,21 @@ namespace VW { +namespace experimental +{ class api_status; +} + +using example_sink_f = std::function; namespace parsers { namespace flatbuffer { int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples); -bool read_span_flatbuffer( - VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples); + +int read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, + VW::multi_ex& examples, example_sink_f example_sink = nullptr, VW::experimental::api_status* status = nullptr); class parser { @@ -57,6 +62,19 @@ class parser VW::experimental::api_status* 
status = nullptr); int get_namespace_index(const Namespace* ns, namespace_index& ni, VW::experimental::api_status* status = nullptr); + inline void reset_active_multi_ex() + { + _multi_ex_index = 0; + _active_multi_ex = false; + _multi_example_object = nullptr; + } + + inline void reset_active_collection() + { + _example_index = 0; + _active_collection = false; + } + void parse_simple_label(shared_data* sd, polylabel* l, reduction_features* red_features, const SimpleLabel* label); void parse_cb_label(polylabel* l, const CBLabel* label); void parse_ccb_label(polylabel* l, const CCBLabel* label); diff --git a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc index f70e61f6a93..966377505f9 100644 --- a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc +++ b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc @@ -5,12 +5,14 @@ #include "vw/fb_parser/parse_example_flatbuffer.h" #include "vw/core/action_score.h" +#include "vw/core/api_status.h" #include "vw/core/best_constant.h" #include "vw/core/cb.h" #include "vw/core/constant.h" #include "vw/core/error_constants.h" #include "vw/core/global_data.h" #include "vw/core/parser.h" +#include "vw/core/scope_exit.h" #include "vw/core/vw.h" #include @@ -43,8 +45,8 @@ int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl return static_cast(status.get_error_code() == VW::experimental::error_code::success); } -bool read_span_flatbuffer( - VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples) +int read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, + VW::multi_ex& examples, example_sink_f example_sink, VW::experimental::api_status* status) { // we expect context to contain a size_prefixed flatbuffer (technically a binary string) // which means: @@ -59,7 +61,6 @@ bool read_span_flatbuffer( // thus context.size() = 
sizeof(length) + length io_buf unused; - // TODO: How do we report errors out of here? (This is a general API problem with the parsers) size_t address = reinterpret_cast(span); if (address % 8 != 0) { @@ -67,8 +68,8 @@ bool read_span_flatbuffer( sstream << "fb_parser error: flatbuffer data not aligned to 8 bytes" << std::endl; sstream << " span => @" << std::hex << address << std::dec << " % " << 8 << " = " << address % 8 << " (vs desired = " << 0 << ")"; - THROW(sstream.str()); - return false; + + RETURN_ERROR_LS(status, fb_parser_span_misaligned) << sstream.str(); } flatbuffers::uoffset_t flatbuffer_object_size = @@ -79,42 +80,80 @@ bool read_span_flatbuffer( sstream << "fb_parser error: flatbuffer size prefix does not match actual size" << std::endl; sstream << " span => @" << std::hex << address << std::dec << " size_prefix = " << flatbuffer_object_size << " length = " << length; - THROW(sstream.str()); - return false; + + RETURN_ERROR_LS(status, fb_parser_span_length_mismatch) << sstream.str(); } VW::multi_ex temp_ex; - temp_ex.push_back(&example_factory()); + + // Use scope_exit because the parser reports errors by throwing exceptions (the code path in the vw driver + // uses the return value to signal completion, not errors). + auto scope_guard = VW::scope_exit( + [&temp_ex, &all, &example_sink]() + { + if (example_sink == nullptr) { VW::finish_example(*all, temp_ex); } + else { example_sink(std::move(temp_ex)); } + }); + + // There is a bit of unhappiness with the interface of the read_XYZ_() functions, because they often + // expect the input multi_ex to have a single "empty" example there. This contributes, in part, to the large + // proliferation of entry points into the JSON parser(s). We want to avoid exposing that insofar as possible, + // so we will check whether we already received a perfectly good example and use that, or create a new one if + // needed. 
+ if (examples.size() > 0) + { + assert(examples.size() == 1); + temp_ex.push_back(examples[0]); + examples.pop_back(); + } + else { temp_ex.push_back(&example_factory()); } bool has_more = true; - VW::experimental::api_status status; do { - switch (all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, &status)) + switch (int result = all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, status)) { case VW::experimental::error_code::success: has_more = true; break; + // Because nothing_to_parse is not an error we have to filter it out here, otherwise + // we could simply do RETURN_IF_FAIL(result) and let the macro handle it. case VW::experimental::error_code::nothing_to_parse: has_more = false; break; default: - std::stringstream sstream; - sstream << "Error parsing examples: " << std::endl; - THROW(sstream.str()); - return false; + RETURN_IF_FAIL(result); } + // The underlying parser will emit a newline example when terminating the parsing + // of a multi_ex block. Since we are collecting it into a multi_ex, we want to + // swallow it here, but should the parser not have followed its contract w.r.t. + // the return value, we should use the presence of the newline example to override + // has_more. has_more &= !temp_ex[0]->is_newline; + // If this is a real example, we need to move it to the output multi_ex; we also + // need to create a new example to replace it for the next run through the parser. if (!temp_ex[0]->is_newline) { - examples.push_back(&example_factory()); - std::swap(examples[examples.size() - 1], temp_ex[0]); + // We avoid doing moves or copy construction here because multi_ex contains + // example pointers. The compile-time code here is meant to call attention + // to here if the underlying type changes. 
+ using temp_ex_element_t = std::remove_reference::type; + using examples_element_t = std::remove_reference::type; + + static_assert(std::is_same::value && + std::is_same::value, + "temp_ex and example must be vector-like over VW::example*"); + + examples.push_back(temp_ex[0]); + + // Since we are using a vector of pointers, we can simply reassign the slot to + // the pointer of the newly created destination example for the parser. + temp_ex[0] = &example_factory(); } } while (has_more); - VW::finish_example(*all, temp_ex); - return true; + return VW::experimental::error_code::success; } const VW::parsers::flatbuffer::ExampleRoot* parser::data() { return _data; } @@ -198,16 +237,17 @@ int parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples, { _active_multi_ex = true; _multi_example_object = _data->example_obj_as_ExampleCollection()->multi_examples()->Get(_example_index); + + // read from active multi_ex RETURN_IF_FAIL(parse_multi_example(all, examples[0], _multi_example_object, status)); - // read from active collection + // if we are done with the multi example, move to the next one, or finish the collection if (!_active_multi_ex) { _example_index++; if (_example_index == _data->example_obj_as_ExampleCollection()->multi_examples()->size()) { - _example_index = 0; - _active_collection = false; + reset_active_collection(); } } } @@ -216,11 +256,7 @@ int parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples, const auto ex = _data->example_obj_as_ExampleCollection()->examples()->Get(_example_index); RETURN_IF_FAIL(parse_example(all, examples[0], ex, status)); _example_index++; - if (_example_index == _data->example_obj_as_ExampleCollection()->examples()->size()) - { - _example_index = 0; - _active_collection = false; - } + if (_example_index == _data->example_obj_as_ExampleCollection()->examples()->size()) { reset_active_collection(); } } return VW::experimental::error_code::success; } @@ -231,6 +267,20 @@ int 
parser::parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl #define RETURN_SUCCESS_FINISHED() \ return buffer_pointer ? VW::experimental::error_code::nothing_to_parse : VW::experimental::error_code::success; + // If we are re-using a single parser instance across multiple invocations, we need to reset + // the state when we get a new buffer_pointer. Otherwise we may be in the middle of a multi_ex + // or example_collection, and the following parse will attempt to reuse the object references + // from the previous buffer, which may have been deallocated. + // TODO: Rewrite the parser to avoid this convoluted, re-entrant logic. + if (buffer_pointer && _flatbuffer_pointer != buffer_pointer) + { + reset_active_multi_ex(); + reset_active_collection(); + } + + // The ExampleCollection processing code owns dispatching to parse_multi_example to handle + // iteration through the outer collection correctly, thus it must have the first chance to + // incoming parse request. if (_active_collection) { RETURN_IF_FAIL(process_collection_item(all, examples, status)); @@ -307,9 +357,7 @@ int parser::parse_multi_example( { // done with multi example, send a newline example and reset ae->is_newline = true; - _multi_ex_index = 0; - _active_multi_ex = false; - _multi_example_object = nullptr; + reset_active_multi_ex(); return VW::experimental::error_code::success; } @@ -325,30 +373,11 @@ int parser::get_namespace_index(const Namespace* ns, namespace_index& ni, VW::ex ni = static_cast(ns->name()->c_str()[0]); return VW::experimental::error_code::success; } - else if (flatbuffers::IsFieldPresent(ns, Namespace::VT_HASH)) + else { ni = ns->hash(); return VW::experimental::error_code::success; } - - if (_active_collection && _active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_name_hash_missing) - << "Either name or hash field must be specified to get the namespace index in collection item with example " - "index " - << _example_index << "and multi example index " << 
_multi_ex_index; - } - else if (_active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_name_hash_missing) - << "Either name or hash field must be specified to get the namespace index in multi example index " - << _multi_ex_index; - } - else - { - RETURN_ERROR_LS(status, fb_parser_name_hash_missing) - << "Either name or hash field must be specified to get the namespace index"; - } } bool get_namespace_hash(VW::workspace* all, const Namespace* ns, uint64_t& hash) @@ -462,7 +491,7 @@ int parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* n } else { - if (!has_hashes) { RETURN_NS_PARSER_ERROR(status, fb_parser_name_hash_missing) } + if (!has_hashes) { RETURN_NS_PARSER_ERROR(status, fb_parser_feature_hashes_names_missing) } if (ns->feature_hashes()->size() != ns->feature_values()->size()) { @@ -541,6 +570,7 @@ int parser::parse_flat_label( break; } case Label_NONE: + case Label_no_label: break; default: if (_active_collection && _active_multi_ex) diff --git a/vowpalwabbit/fb_parser/src/parse_label.cc b/vowpalwabbit/fb_parser/src/parse_label.cc index c236747569f..663d54241f1 100644 --- a/vowpalwabbit/fb_parser/src/parse_label.cc +++ b/vowpalwabbit/fb_parser/src/parse_label.cc @@ -3,6 +3,7 @@ // license as described in the file LICENSE. 
#include "vw/core/action_score.h" +#include "vw/core/api_status.h" #include "vw/core/best_constant.h" #include "vw/core/cb.h" #include "vw/core/constant.h" diff --git a/vowpalwabbit/fb_parser/tests/example_data_generator.h b/vowpalwabbit/fb_parser/tests/example_data_generator.h index b474d3b0c44..6b12f9636fe 100644 --- a/vowpalwabbit/fb_parser/tests/example_data_generator.h +++ b/vowpalwabbit/fb_parser/tests/example_data_generator.h @@ -9,13 +9,19 @@ #include "prototype_example_root.h" #include "prototype_label.h" #include "prototype_namespace.h" +#include "vw/common/future_compat.h" #include "vw/common/hash.h" #include "vw/common/random.h" +#include "vw/core/error_constants.h" +#include "vw/fb_parser/generated/example_generated.h" #include USE_PROTOTYPE_MNEMONICS_EX +using namespace flatbuffers; +namespace fb = VW::parsers::flatbuffer; + namespace vwtest { @@ -26,6 +32,10 @@ class example_data_generator static VW::rand_state create_test_random_state(); + inline bool random_bool() { return rng.get_and_update_random() >= 0.5; } + + inline int random_int(int min, int max) { return static_cast(rng.get_and_update_random() * (max - min) + min); } + prototype_namespace_t create_namespace(std::string name, uint8_t numeric_features, uint8_t string_features); prototype_example_t create_simple_example(uint8_t numeric_features, uint8_t string_features); @@ -40,8 +50,86 @@ class example_data_generator prototype_example_collection_t create_simple_log( uint8_t num_examples, uint8_t numeric_features, uint8_t string_features); +public: + enum NamespaceErrors + { + BAD_NAMESPACE_NO_ERROR = 0, + BAD_NAMESPACE_NAME_HASH_MISSING = 1, // not actually possible, due to how fb works + BAD_NAMESPACE_FEATURE_VALUES_MISSING = 2, + BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH = 4, + BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH = 8, + BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING = 16, + }; + + template + Offset create_bad_namespace(FlatBufferBuilder& builder, VW::workspace& w); + private: 
VW::rand_state rng; }; +template +Offset example_data_generator::create_bad_namespace(FlatBufferBuilder& builder, VW::workspace& w) +{ + prototype_namespace_t ns = create_namespace("BadNamespace", 1, 1); + if VW_STD17_CONSTEXPR (errors == NamespaceErrors::BAD_NAMESPACE_NO_ERROR) return ns.create_flatbuffer(builder, w); + + constexpr bool include_ns_name_hash = !(errors & NamespaceErrors::BAD_NAMESPACE_NAME_HASH_MISSING); + constexpr bool include_feature_values = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_MISSING); + + constexpr bool include_feature_hashes = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING || + // If we want to check for name/value mismatch, then we need to avoid + // including the feature hashes, as they will be used as a backup + errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH); + constexpr bool skip_a_feature_hash = (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH); + static_assert(!skip_a_feature_hash || include_feature_hashes, "Cannot skip a feature hash if they are not included"); + + constexpr bool include_feature_names = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING); + constexpr bool skip_a_feature_name = (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH); + static_assert(!skip_a_feature_name || include_feature_names, "Cannot skip a feature name if they are not included"); + + std::vector> feature_names; + std::vector feature_values; + std::vector feature_hashes; + + for (size_t i = 0; i < ns.features.size(); i++) + { + const auto& f = ns.features[i]; + + if (include_feature_names && (!skip_a_feature_name || i == 1)) + { + feature_names.push_back(builder.CreateString(f.name)); + } + + if VW_STD17_CONSTEXPR (include_feature_values) feature_values.push_back(f.value); + + if (include_feature_hashes && (!skip_a_feature_hash || i == 0)) { feature_hashes.push_back(f.hash); } + } + + Offset name_offset = Offset(); + if 
(include_ns_name_hash) { name_offset = builder.CreateString(ns.name); } + + // This function attempts to, insofar as possible, generate a layout that looks like it could have + // been created using the normal serialization code: In this case, that means that the strings for + // the feature names are serialized into the builder before a call to CreateNamespaceDirect is made, + // which is where the feature_names vector is allocated. + Offset>> feature_names_offset = + include_feature_names ? builder.CreateVector(feature_names) : Offset>>(); + Offset> feature_values_offset = + include_feature_values ? builder.CreateVector(feature_values) : Offset>(); + Offset> feature_hashes_offset = + include_feature_hashes ? builder.CreateVector(feature_hashes) : Offset>(); + + fb::NamespaceBuilder ns_builder(builder); + + if VW_STD17_CONSTEXPR (include_ns_name_hash) ns_builder.add_full_hash(VW::hash_space(w, ns.name)); + if VW_STD17_CONSTEXPR (include_feature_hashes) ns_builder.add_feature_hashes(feature_hashes_offset); + if VW_STD17_CONSTEXPR (include_feature_values) ns_builder.add_feature_values(feature_values_offset); + if VW_STD17_CONSTEXPR (include_feature_names) ns_builder.add_feature_names(feature_names_offset); + if VW_STD17_CONSTEXPR (include_ns_name_hash) ns_builder.add_name(name_offset); + + ns_builder.add_hash(ns.feature_group); + return ns_builder.Finish(); +} + } // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc index 35547b0f43e..4170330d5fd 100644 --- a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc +++ b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc @@ -9,6 +9,7 @@ #include "prototype_namespace.h" #include "vw/common/future_compat.h" #include "vw/common/string_view.h" +#include "vw/core/api_status.h" #include "vw/core/constant.h" #include "vw/core/error_constants.h" #include "vw/core/example.h" @@ -253,7 +254,7 @@ 
TEST(FlatbufferParser, SingleExample_MissingFeatureIndices) examples.push_back(&VW::get_unused_example(all.get())); VW::io_buf unused_buffer; EXPECT_EQ(all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf), - VW::experimental::error_code::fb_parser_name_hash_missing); + VW::experimental::error_code::fb_parser_feature_hashes_names_missing); EXPECT_EQ(all->parser_runtime.example_parser->reader(all.get(), unused_buffer, examples), 0); auto example = all->parser_runtime.flat_converter->data()->example_obj_as_Example(); diff --git a/vowpalwabbit/fb_parser/tests/prototype_typemappings.h b/vowpalwabbit/fb_parser/tests/prototype_typemappings.h index 455eae702ff..1a5ca09e21b 100644 --- a/vowpalwabbit/fb_parser/tests/prototype_typemappings.h +++ b/vowpalwabbit/fb_parser/tests/prototype_typemappings.h @@ -20,18 +20,24 @@ template <> struct fb_type { using type = VW::parsers::flatbuffer::Example; + + constexpr static fb::ExampleType root_type = fb::ExampleType::ExampleType_Example; }; template <> struct fb_type { using type = VW::parsers::flatbuffer::MultiExample; + + constexpr static fb::ExampleType root_type = fb::ExampleType::ExampleType_MultiExample; }; template <> struct fb_type { using type = VW::parsers::flatbuffer::ExampleCollection; + + constexpr static fb::ExampleType root_type = fb::ExampleType::ExampleType_ExampleCollection; }; using union_t = void; diff --git a/vowpalwabbit/fb_parser/tests/read_span_tests.cc b/vowpalwabbit/fb_parser/tests/read_span_tests.cc index acbee2f529d..66a68960a27 100644 --- a/vowpalwabbit/fb_parser/tests/read_span_tests.cc +++ b/vowpalwabbit/fb_parser/tests/read_span_tests.cc @@ -9,6 +9,7 @@ #include "vw/common/string_view.h" #include "vw/core/constant.h" #include "vw/core/error_constants.h" +#include "vw/core/scope_exit.h" #include "vw/core/vw.h" #include "vw/fb_parser/parse_example_flatbuffer.h" #include "vw/test_common/test_common.h" @@ -66,7 +67,7 @@ inline void verify_multi_ex( } // namespace 
vwtest template ::type> -void create_flatbuffer_span_and_validate(VW::workspace& w, const T& prototype) +void create_flatbuffer_span_and_validate(VW::workspace& w, vwtest::example_data_generator& data_gen, const T& prototype) { // This is what we expect to see when we use read_span_flatbuffer, since this is intended // to be used for inference, and we would prefer not to force consumers of the API to have @@ -84,6 +85,8 @@ void create_flatbuffer_span_and_validate(VW::workspace& w, const T& prototype) flatbuffers::uoffset_t size = builder.GetSize(); VW::multi_ex parsed_examples; + if (data_gen.random_bool()) { parsed_examples.push_back(&ex_fac()); } + VW::parsers::flatbuffer::read_span_flatbuffer(&w, buffer, size, ex_fac, parsed_examples); verify_multi_ex(w, prototype, parsed_examples); @@ -99,7 +102,7 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_SingleExample) vwtest::prototype_example_t prototype = { {data_gen.create_namespace("A", 3, 4), data_gen.create_namespace("B", 2, 5)}, vwtest::simple_label(1.0f)}; - create_flatbuffer_span_and_validate(*all, prototype); + create_flatbuffer_span_and_validate(*all, data_gen, prototype); } TEST(FlatbufferParser, ReadSpanFlatbuffer_MultiExample) @@ -109,7 +112,7 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_MultiExample) vwtest::example_data_generator data_gen; vwtest::prototype_multiexample_t prototype = data_gen.create_cb_adf_example(3, 1, "tag"); - create_flatbuffer_span_and_validate(*all, prototype); + create_flatbuffer_span_and_validate(*all, data_gen, prototype); } TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionSinglelines) @@ -119,7 +122,7 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionSinglelines) vwtest::example_data_generator data_gen; vwtest::prototype_example_collection_t prototype = data_gen.create_simple_log(3, 3, 4); - create_flatbuffer_span_and_validate(*all, prototype); + create_flatbuffer_span_and_validate(*all, data_gen, prototype); } TEST(FlatbufferParser, 
ReadSpanFlatbuffer_ExampleCollectionMultiline) @@ -129,5 +132,157 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionMultiline) vwtest::example_data_generator data_gen; vwtest::prototype_example_collection_t prototype = data_gen.create_cb_adf_log(1, 3, 4); - create_flatbuffer_span_and_validate(*all, prototype); + create_flatbuffer_span_and_validate(*all, data_gen, prototype); +} + +template +void finish_flatbuffer_and_expect_error(FlatBufferBuilder& builder, Offset root, VW::workspace& w) +{ + VW::example_factory_t ex_fac = [&w]() -> VW::example& { return VW::get_unused_example(&w); }; + VW::example_sink_f ex_sink = [&w](VW::multi_ex&& ex) { VW::finish_example(w, ex); }; + if (vwtest::example_data_generator{}.random_bool()) + { + // This is only valid because ex_fac is grabbing an example from the VW example pool + ex_sink = nullptr; + } + + builder.FinishSizePrefixed(root); + + const uint8_t* buffer = builder.GetBufferPointer(); + flatbuffers::uoffset_t size = builder.GetSize(); + + std::vector buffer_copy(buffer, buffer + size); + + VW::multi_ex parsed_examples; + EXPECT_EQ(VW::parsers::flatbuffer::read_span_flatbuffer( + &w, buffer_copy.data(), buffer_copy.size(), ex_fac, parsed_examples, ex_sink), + error_code); +} + +using namespace_factory_f = std::function(FlatBufferBuilder&, VW::workspace&)>; + +Offset create_bad_ns_root_example(FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac) +{ + std::vector> namespaces = {ns_fac(builder, w)}; + + Offset label_offset = fb::Createno_label(builder).Union(); + return fb::CreateExample(builder, builder.CreateVector(namespaces), fb::Label_no_label, label_offset); +} + +Offset create_bad_ns_root_multiex( + FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac) +{ + std::vector> examples = {create_bad_ns_root_example(builder, w, ns_fac)}; + + return fb::CreateMultiExample(builder, builder.CreateVector(examples)); +} + +template ::type> +using builder_f = Offset 
(*)(FlatBufferBuilder&, VW::workspace&, namespace_factory_f); + +template +Offset create_bad_ns_root_collection( + FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac) +{ + if VW_STD17_CONSTEXPR (multiline) + { + // using "auto" here breaks the code coverage build due to template substitution failure + std::vector> inner_examples = {create_bad_ns_root_multiex(builder, w, ns_fac)}; + return fb::CreateExampleCollection(builder, builder.CreateVector(std::vector>()), + builder.CreateVector(inner_examples), multiline); + } + else + { + // using "auto" here breaks the code coverage build due to template substitution failure + std::vector> inner_examples = {create_bad_ns_root_example(builder, w, ns_fac)}; + return fb::CreateExampleCollection(builder, builder.CreateVector(inner_examples), + builder.CreateVector(std::vector>()), multiline); + } +} + +template +void create_flatbuffer_span_and_expect_error(VW::workspace& w, namespace_factory_f ns_fac, builder_f root_builder) +{ + FlatBufferBuilder builder; + Offset data_obj = root_builder(builder, w, ns_fac).Union(); + + Offset root_obj = fb::CreateExampleRoot(builder, root_type, data_obj); + + finish_flatbuffer_and_expect_error(builder, root_obj, w); +} + +using NamespaceErrors = vwtest::example_data_generator::NamespaceErrors; +template +void run_bad_namespace_test(VW::workspace& w) +{ + vwtest::example_data_generator data_gen; + + static_assert(errors != NamespaceErrors::BAD_NAMESPACE_NO_ERROR, "This test is intended to test bad namespaces"); + namespace_factory_f ns_fac = [&data_gen](FlatBufferBuilder& builder, VW::workspace& w) -> Offset + { return data_gen.create_bad_namespace(builder, w); }; + + create_flatbuffer_span_and_expect_error( + w, ns_fac, &create_bad_ns_root_example); + create_flatbuffer_span_and_expect_error( + w, ns_fac, &create_bad_ns_root_multiex); + + create_flatbuffer_span_and_expect_error( + w, ns_fac, &create_bad_ns_root_collection); + + create_flatbuffer_span_and_expect_error( + 
w, ns_fac, &create_bad_ns_root_collection); } + +template +void run_bad_namespace_test() +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + run_bad_namespace_test(*all); +} + +TEST(FlatbufferParser, BadNamespace_FeatureValuesMissing) +{ + namespace err = VW::experimental::error_code; + constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_MISSING; + constexpr int expected_error_code = err::fb_parser_feature_values_missing; + + run_bad_namespace_test(); +} + +TEST(FlatbufferParser, BadNamespace_FeatureHashesNamesMissing) +{ + namespace err = VW::experimental::error_code; + constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING; + constexpr int expected_error_code = err::fb_parser_feature_hashes_names_missing; + + run_bad_namespace_test(); +} + +TEST(FlatbufferParser, BadNamespace_FeatureValuesHashMismatch) +{ + namespace err = VW::experimental::error_code; + constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH; + constexpr int expected_error_code = err::fb_parser_size_mismatch_ft_hashes_ft_values; + + run_bad_namespace_test(); +} + +TEST(FlatbufferParser, BadNamespace_FeatureValuesNamesMismatch) +{ + namespace err = VW::experimental::error_code; + constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH; + constexpr int expected_error_code = err::fb_parser_size_mismatch_ft_names_ft_values; + + run_bad_namespace_test(); +} + +// This test is disabled because it is not possible to create a flatbuffer with a missing namespace name hash. 
+// TEST(FlatbufferParser, BadNamespace_NameHashMissing) +// { +// namespace err = VW::experimental::error_code; +// constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_NAME_HASH_MISSING; +// constexpr int expected_error_code = err::success; + +// run_bad_namespace_test(); +// } From 1675c6b8f8ed0fd6b75ca77a7d23f92ba886c6ed Mon Sep 17 00:00:00 2001 From: Geoffrey Thomas Date: Fri, 16 Feb 2024 08:58:45 -0500 Subject: [PATCH 2/2] build: Fix building against RapidJSON 1.1.0 as sys dep (#4682) The last release of RapidJSON, 1.1.0 from 2016 (see Tencent/rapidjson#1006 I guess), spells the variable in all caps: https://github.com/Tencent/rapidjson/blob/v1.1.0/RapidJSONConfig.cmake.in Use both spellings to accommodate both people on the last release and people who have picked a newer git commit. Co-authored-by: Jack Gerrits --- ext_libs/ext_libs.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext_libs/ext_libs.cmake b/ext_libs/ext_libs.cmake index ad2d0c27482..e03de8690e4 100644 --- a/ext_libs/ext_libs.cmake +++ b/ext_libs/ext_libs.cmake @@ -38,7 +38,7 @@ if(RAPIDJSON_SYS_DEP) # Since EXACT is not specified, any version compatible with 1.1.0 is accepted (>= 1.1.0) find_package(RapidJSON 1.1.0 CONFIG REQUIRED) add_library(RapidJSON INTERFACE) - target_include_directories(RapidJSON INTERFACE ${RapidJSON_INCLUDE_DIRS}) + target_include_directories(RapidJSON INTERFACE ${RapidJSON_INCLUDE_DIRS} ${RAPIDJSON_INCLUDE_DIRS}) else() add_library(RapidJSON INTERFACE) target_include_directories(RapidJSON SYSTEM INTERFACE "${CMAKE_CURRENT_LIST_DIR}/rapidjson/include") @@ -127,4 +127,4 @@ if(VW_FEAT_CB_GRAPH_FEEDBACK) target_include_directories(mlpack_ensmallen SYSTEM INTERFACE ${CMAKE_CURRENT_LIST_DIR}/armadillo-code/include) target_include_directories(mlpack_ensmallen SYSTEM INTERFACE ${CMAKE_CURRENT_LIST_DIR}/ensmallen/include) -endif() \ No newline at end of file +endif()