From ca77f152929ec396a47ddde9b736b43a5b875e63 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Tue, 13 Feb 2024 11:51:22 -0500 Subject: [PATCH] fix: Fix Flatbuffer Parsing in various corner cases * Fix logic issues with detecting presence/absence of feature names and hashes in example data, leading to an access violation crash * Fix issues with properly parsing ExampleCollection buffers, around reporting doneness and advancing when parsing inner MultiExample objects. * Fix tests to include coverage for incoming flatbuffers with/without names/hashes * Fixes validation logic to avoid crashing when feature names/hashes are missing * Add read_span_flatbuffer tests --- vowpalwabbit/fb_parser/CMakeLists.txt | 4 +- .../vw/fb_parser/parse_example_flatbuffer.h | 5 +- .../fb_parser/src/parse_example_flatbuffer.cc | 160 ++++++----- .../tests/affordance_validation_tests.cc | 190 +++++++++++++ .../fb_parser/tests/example_data_generator.cc | 20 ++ .../fb_parser/tests/example_data_generator.h | 4 + ...ser_test.cc => flatbuffer_parser_tests.cc} | 252 ++---------------- .../fb_parser/tests/prototype_example.h | 54 ++-- .../fb_parser/tests/prototype_example_root.h | 42 +-- .../fb_parser/tests/prototype_label.h | 6 +- .../fb_parser/tests/prototype_namespace.h | 68 +++-- .../fb_parser/tests/prototype_typemappings.h | 44 +++ .../fb_parser/tests/read_span_tests.cc | 131 +++++++++ 13 files changed, 595 insertions(+), 385 deletions(-) create mode 100644 vowpalwabbit/fb_parser/tests/affordance_validation_tests.cc rename vowpalwabbit/fb_parser/tests/{flatbuffer_parser_test.cc => flatbuffer_parser_tests.cc} (71%) create mode 100644 vowpalwabbit/fb_parser/tests/prototype_typemappings.h create mode 100644 vowpalwabbit/fb_parser/tests/read_span_tests.cc diff --git a/vowpalwabbit/fb_parser/CMakeLists.txt b/vowpalwabbit/fb_parser/CMakeLists.txt index 19fdde7f7dc..1874a86131c 100644 --- a/vowpalwabbit/fb_parser/CMakeLists.txt +++ b/vowpalwabbit/fb_parser/CMakeLists.txt @@ -40,7 +40,9 @@ set(vw_fb_parser_test_sources tests/prototype_label.h tests/prototype_namespace.h - tests/flatbuffer_parser_test.cc + tests/affordance_validation_tests.cc + tests/read_span_tests.cc + tests/flatbuffer_parser_tests.cc ) message(STATUS "vw_fb_parser_test_sources: ${vw_fb_parser_test_sources}") diff --git a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h index bfd42192682..fa181c1ea46 100644 --- a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h +++ b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h @@ -5,11 +5,11 @@ #pragma once #include "vw/core/api_status.h" +#include "vw/core/example.h" #include "vw/core/multi_ex.h" #include "vw/core/shared_data.h" #include "vw/core/vw_fwd.h" #include "vw/fb_parser/generated/example_generated.h" -#include "vw/core/example.h" namespace VW { @@ -21,7 +21,8 @@ namespace parsers namespace flatbuffer { int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples); -bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples); +bool read_span_flatbuffer( + VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples); class parser { diff --git a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc index bf009ff9ad9..6b9b1305d36 100644 --- a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc +++ b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc @@ -42,7 +42,8 @@ int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl return static_cast(status.get_error_code() == VW::experimental::error_code::success); } -bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples) +bool read_span_flatbuffer( + VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples) { // we expect context to contain a size_prefixed flatbuffer (technically a binary string) // which means: @@ -69,8 +70,9 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length return false; } - uint32_t flatbuffer_object_size = *reinterpret_cast(span); - if (length != flatbuffer_object_size + sizeof(uint32_t)) + flatbuffers::uoffset_t flatbuffer_object_size = + flatbuffers::ReadScalar(span); //*reinterpret_cast(span); + if (length != flatbuffer_object_size + sizeof(flatbuffers::uoffset_t)) { std::stringstream sstream; sstream << "fb_parser error: flatbuffer size prefix does not match actual size" << std::endl; @@ -85,8 +87,7 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length bool has_more = true; VW::experimental::api_status status; - do - { + do { switch (all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, &status)) { case VW::experimental::error_code::success: @@ -102,7 +103,9 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length return false; } - if (has_more &= !temp_ex[0]->is_newline) + has_more &= !temp_ex[0]->is_newline; + + if (!temp_ex[0]->is_newline) { examples.push_back(&example_factory()); std::swap(examples[examples.size() - 1], temp_ex[0]); @@ -196,11 +199,15 @@ int parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples, _multi_example_object = _data->example_obj_as_ExampleCollection()->multi_examples()->Get(_example_index); RETURN_IF_FAIL(parse_multi_example(all, examples[0], _multi_example_object, status)); // read from active collection - _example_index++; - if (_example_index == _data->example_obj_as_ExampleCollection()->multi_examples()->size()) + + if (!_active_multi_ex) { - _example_index = 0; - _active_collection = false; + _example_index++; + if (_example_index == _data->example_obj_as_ExampleCollection()->multi_examples()->size()) + { + _example_index = 0; + _active_collection = false; + } } } else @@ -220,8 +227,19 @@ int parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples, int parser::parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples, const uint8_t* buffer_pointer, VW::experimental::api_status* status) { - if (_active_multi_ex) { RETURN_IF_FAIL(parse_multi_example(all, examples[0], _multi_example_object, status)); } - else if (_active_collection) { RETURN_IF_FAIL(process_collection_item(all, examples, status)); } +#define RETURN_SUCCESS_FINISHED() \ + return buffer_pointer ? VW::experimental::error_code::nothing_to_parse : VW::experimental::error_code::success; + + if (_active_collection) + { + RETURN_IF_FAIL(process_collection_item(all, examples, status)); + if (!_active_collection) RETURN_SUCCESS_FINISHED(); + } + else if (_active_multi_ex) + { + RETURN_IF_FAIL(parse_multi_example(all, examples[0], _multi_example_object, status)); + if (!_active_multi_ex) RETURN_SUCCESS_FINISHED(); + } else { // new object to be read from file @@ -233,26 +251,33 @@ int parser::parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl { const auto example = _data->example_obj_as_Example(); RETURN_IF_FAIL(parse_example(all, examples[0], example, status)); + RETURN_SUCCESS_FINISHED(); } break; case VW::parsers::flatbuffer::ExampleType_MultiExample: { _multi_example_object = _data->example_obj_as_MultiExample(); _active_multi_ex = true; + RETURN_IF_FAIL(parse_multi_example(all, examples[0], _multi_example_object, status)); + if (!_active_multi_ex) RETURN_SUCCESS_FINISHED(); } break; case VW::parsers::flatbuffer::ExampleType_ExampleCollection: { _active_collection = true; + RETURN_IF_FAIL(process_collection_item(all, examples, status)); + if (!_active_collection) RETURN_SUCCESS_FINISHED(); } break; default: + RETURN_ERROR_LS(status, fb_parser_unknown_example_type) << "Unknown example type"; break; } } + return VW::experimental::error_code::success; } @@ -342,8 +367,36 @@ bool get_namespace_hash(VW::workspace* all, const Namespace* ns, uint64_t& hash) return false; } +bool features_have_names(const Namespace& ns) +{ + return flatbuffers::IsFieldPresent(&ns, Namespace::VT_FEATURE_NAMES) && (ns.feature_names()->size() != 0); + // TODO: It is not clear what the right answer is when feature_values->size is 0 +} + +bool features_have_hashes(const Namespace& ns) +{ + return flatbuffers::IsFieldPresent(&ns, Namespace::VT_FEATURE_HASHES) && (ns.feature_hashes()->size() != 0); +} + +bool features_have_values(const Namespace& ns) +{ + return flatbuffers::IsFieldPresent(&ns, Namespace::VT_FEATURE_VALUES) && (ns.feature_values()->size() != 0); +} + int parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* ns, VW::experimental::api_status* status) { +#define RETURN_NS_PARSER_ERROR(status, error_code) \ + if (_active_collection && _active_multi_ex) \ + { \ + RETURN_ERROR_LS(status, error_code) << "Unable to parse namespace in collection item with example index " \ + << _example_index << "and multi example index " << _multi_ex_index; \ + } \ + else if (_active_multi_ex) \ + { \ + RETURN_ERROR_LS(status, error_code) << "Unable to parse namespace in multi example index " << _multi_ex_index; \ + } \ + else { RETURN_ERROR_LS(status, error_code) << "Unable to parse namespace "; } + namespace_index index; RETURN_IF_FAIL(parser::get_namespace_index(ns, index, status)); uint64_t hash = 0; @@ -355,46 +408,24 @@ int parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* n if (hash_found) { fs.start_ns_extent(hash); } - if (!flatbuffers::IsFieldPresent(ns, Namespace::VT_FEATURE_VALUES)) - { - if (_active_collection && _active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_feature_values_missing) - << "Unable to parse namespace in collection item with example index " << _example_index - << "and multi example index " << _multi_ex_index; - } - else if (_active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_feature_values_missing) - << "Unable to parse namespace in multi example index " << _multi_ex_index; - } - else { RETURN_ERROR_LS(status, fb_parser_feature_values_missing) << "Unable to parse namespace "; } - } + if (!features_have_values(*ns)) { RETURN_NS_PARSER_ERROR(status, fb_parser_feature_values_missing) } auto feature_value_iter = (ns->feature_values())->begin(); const auto feature_value_iter_end = (ns->feature_values())->end(); + bool has_hashes = features_have_hashes(*ns); + bool has_names = features_have_names(*ns); + // assuming the usecase that if feature names would exist, they would exist for all features in the namespace - if (flatbuffers::IsFieldPresent(ns, Namespace::VT_FEATURE_NAMES)) + if (has_names) { const auto ns_name = ns->name(); auto feature_name_iter = (ns->feature_names())->begin(); - if (flatbuffers::IsFieldPresent(ns, Namespace::VT_FEATURE_HASHES)) + if (has_hashes) { if (ns->feature_hashes()->size() != ns->feature_values()->size()) { - if (_active_collection && _active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_size_mismatch_ft_hashes_ft_values) - << "Unable to parse namespace in collection item with example index " << _example_index - << "and multi example index " << _multi_ex_index; - } - else if (_active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_size_mismatch_ft_hashes_ft_values) - << "Unable to parse namespace in multi example index " << _multi_ex_index; - } - else { RETURN_ERROR_LS(status, fb_parser_size_mismatch_ft_hashes_ft_values) << "Unable to parse namespace "; } + RETURN_NS_PARSER_ERROR(status, fb_parser_size_mismatch_ft_hashes_ft_values) } auto feature_hash_iter = (ns->feature_hashes())->begin(); @@ -413,23 +444,13 @@ int parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* n // assuming the usecase that if feature names would exist, they would exist for all features in the namespace if (ns->feature_names()->size() != ns->feature_values()->size()) { - if (_active_collection && _active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_size_mismatch_ft_names_ft_values) - << "Unable to parse namespace in collection item with example index " << _example_index - << "and multi example index " << _multi_ex_index; - } - else if (_active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_size_mismatch_ft_names_ft_values) - << "Unable to parse namespace in multi example index " << _multi_ex_index; - } - else { RETURN_ERROR_LS(status, fb_parser_size_mismatch_ft_names_ft_values) << "Unable to parse namespace "; } + RETURN_NS_PARSER_ERROR(status, fb_parser_size_mismatch_ft_names_ft_values) } for (; feature_value_iter != feature_value_iter_end; ++feature_value_iter, ++feature_name_iter) { const uint64_t word_hash = - all->parser_runtime.example_parser->hasher(feature_name_iter->c_str(), feature_name_iter->size(), _c_hash); + all->parser_runtime.example_parser->hasher(feature_name_iter->c_str(), feature_name_iter->size(), _c_hash) & + all->runtime_state.parse_mask; fs.push_back(*feature_value_iter, word_hash); if (ns_name != nullptr) { @@ -440,36 +461,13 @@ int parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* n } else { - if (!flatbuffers::IsFieldPresent(ns, Namespace::VT_FEATURE_HASHES)) - { - if (_active_collection && _active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_feature_hashes_names_missing) - << "Unable to parse namespace in collection item with example index " << _example_index - << "and multi example index " << _multi_ex_index; - } - else if (_active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_feature_hashes_names_missing) - << "Unable to parse namespace in multi example index " << _multi_ex_index; - } - else { RETURN_ERROR_LS(status, fb_parser_feature_hashes_names_missing) << "Unable to parse namespace "; } - } + if (!has_hashes) { RETURN_NS_PARSER_ERROR(status, fb_parser_name_hash_missing) } + if (ns->feature_hashes()->size() != ns->feature_values()->size()) { - if (_active_collection && _active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_size_mismatch_ft_hashes_ft_values) - << "Unable to parse namespace in collection item with example index " << _example_index - << "and multi example index " << _multi_ex_index; - } - else if (_active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_size_mismatch_ft_hashes_ft_values) - << "Unable to parse namespace in multi example index " << _multi_ex_index; - } - else { RETURN_ERROR_LS(status, fb_parser_size_mismatch_ft_hashes_ft_values) << "Unable to parse namespace "; } + RETURN_NS_PARSER_ERROR(status, fb_parser_size_mismatch_ft_hashes_ft_values) } + auto feature_hash_iter = (ns->feature_hashes())->begin(); for (; feature_value_iter != feature_value_iter_end; ++feature_value_iter, ++feature_hash_iter) { diff --git a/vowpalwabbit/fb_parser/tests/affordance_validation_tests.cc b/vowpalwabbit/fb_parser/tests/affordance_validation_tests.cc new file mode 100644 index 00000000000..11563515ec6 --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/affordance_validation_tests.cc @@ -0,0 +1,190 @@ + +#include "example_data_generator.h" +#include "prototype_example.h" +#include "prototype_example_root.h" +#include "prototype_label.h" +#include "prototype_namespace.h" +#include "prototype_typemappings.h" +#include "vw/common/future_compat.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" +#include "vw/test_common/test_common.h" + +template ::type> +void create_flatbuffer_and_validate(VW::workspace& w, const T& prototype) +{ + flatbuffers::FlatBufferBuilder builder; + + Offset buffer_offset = prototype.create_flatbuffer(builder, w); + builder.Finish(buffer_offset); + + const FB_t* fb_obj = GetRoot(builder.GetBufferPointer()); + + prototype.verify(w, fb_obj); +} + +template <> +void create_flatbuffer_and_validate( + VW::workspace& w, const vwtest::prototype_label_t& prototype) +{ + if (prototype.label_type == fb::Label_NONE) { return; } // there is no flatbuffer to create + + flatbuffers::FlatBufferBuilder builder; + + Offset buffer_offset = prototype.create_flatbuffer(builder, w); + builder.Finish(buffer_offset); + + switch (prototype.label_type) + { + case fb::Label_SimpleLabel: + case fb::Label_CBLabel: + case fb::Label_ContinuousLabel: + case fb::Label_Slates_Label: + { + prototype.verify(w, prototype.label_type, builder.GetBufferPointer()); + break; + } + case fb::Label_NONE: + { + break; + } + default: + { + THROW("Label type not currently supported for create_flatbuffer_and_validate"); + break; + } + } +} + +TEST(FlatBufferParser, ValidateTestAffordances_NoLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_label_t label_prototype = vwtest::no_label(); + create_flatbuffer_and_validate(*all, label_prototype); +} + +TEST(FlatBufferParser, ValidateTestAffordances_SimpleLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + create_flatbuffer_and_validate(*all, vwtest::simple_label(0.5, 1.0)); +} + +TEST(FlatBufferParser, ValidateTestAffordances_CBLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + create_flatbuffer_and_validate(*all, vwtest::cb_label({1.5, 2, 0.25f})); +} + +TEST(FlatBufferParser, ValidateTestAffordances_ContinuousLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + std::vector probabilities = {{1, 0.5f, 0.25}}; + + create_flatbuffer_and_validate(*all, vwtest::continuous_label(probabilities)); +} + +TEST(FlatBufferParser, ValidateTestAffordances_Slates) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--slates")); + + std::vector probabilities = {{1, 0.5f}, {2, 0.25f}}; + + VW::slates::example_type types[] = { + VW::slates::example_type::UNSET, + VW::slates::example_type::ACTION, + VW::slates::example_type::SHARED, + VW::slates::example_type::SLOT, + }; + + for (VW::slates::example_type type : types) + { + create_flatbuffer_and_validate(*all, vwtest::slates_label_raw(type, 0.5, true, 0.3, 1, probabilities)); + } +} + +TEST(FlatbufferParser, ValidateTestAffordances_Namespace) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_namespace_t ns_prototype = {"U_a", {{"a", 1.f}, {"b", 2.f}}}; + create_flatbuffer_and_validate(*all, ns_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_Simple) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_example_t ex_prototype = {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }, + vwtest::simple_label(0.5, 1.0)}; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_Unlabeled) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_example_t ex_prototype = {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }}; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_CBShared) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::prototype_example_t ex_prototype = {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }, + vwtest::cb_label_shared(), "tag1"}; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_CB) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::prototype_example_t ex_prototype = {{ + {"T_a", {{"a", 5.f}, {"b", 6.f}}}, + {"T_b", {{"a", 7.f}, {"b", 8.f}}}, + }, + vwtest::cb_label({1, 1, 0.5f}), "tag1"}; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_MultiExample) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_multiexample_t multiex_prototype = {{ + {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }, + vwtest::cb_label_shared(), "tag1"}, + { + { + {"T_a", {{"a", 5.f}, {"b", 6.f}}}, + {"T_b", {{"a", 7.f}, {"b", 8.f}}}, + }, + vwtest::cb_label({{1, 1, 0.5f}}), + }, + }}; + create_flatbuffer_and_validate(*all, multiex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_ExampleCollectionMultiline) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_example_collection_t prototype = data_gen.create_cb_adf_log(2, 2, 0.5f); + + create_flatbuffer_and_validate(*all, prototype); +} diff --git a/vowpalwabbit/fb_parser/tests/example_data_generator.cc b/vowpalwabbit/fb_parser/tests/example_data_generator.cc index ed0375e88cb..95a911a1bb2 100644 --- a/vowpalwabbit/fb_parser/tests/example_data_generator.cc +++ b/vowpalwabbit/fb_parser/tests/example_data_generator.cc @@ -28,6 +28,14 @@ prototype_namespace_t example_data_generator::create_namespace( return {name.c_str(), features}; } +prototype_example_t example_data_generator::create_simple_example(uint8_t numeric_features, uint8_t string_features) +{ + return {{ + create_namespace("Simple", numeric_features, string_features), + }, + simple_label(rng.get_and_update_random())}; +} + prototype_example_t example_data_generator::create_cb_action( uint8_t action, float probability, bool rewarded, const char* tag) { @@ -82,4 +90,16 @@ prototype_example_collection_t example_data_generator::create_cb_adf_log( return {{}, examples, true}; } +prototype_example_collection_t example_data_generator::create_simple_log( + uint8_t num_examples, uint8_t numeric_features, uint8_t string_features) +{ + std::vector examples; + for (uint8_t i = 0; i < num_examples; i++) + { + examples.push_back(create_simple_example(numeric_features, string_features)); + } + + return {examples, {}, false}; +} + } // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/example_data_generator.h b/vowpalwabbit/fb_parser/tests/example_data_generator.h index f41ba8cfefb..b474d3b0c44 100644 --- a/vowpalwabbit/fb_parser/tests/example_data_generator.h +++ b/vowpalwabbit/fb_parser/tests/example_data_generator.h @@ -27,6 +27,8 @@ class example_data_generator static VW::rand_state create_test_random_state(); prototype_namespace_t create_namespace(std::string name, uint8_t numeric_features, uint8_t string_features); + + prototype_example_t create_simple_example(uint8_t numeric_features, uint8_t string_features); prototype_example_t create_cb_action( uint8_t action, float probability = 0.0, bool rewarded = false, const char* tag = nullptr); prototype_example_t create_shared_context( @@ -35,6 +37,8 @@ class example_data_generator prototype_multiexample_t create_cb_adf_example( uint8_t num_actions, uint8_t reward_action_id, const char* tag = nullptr); prototype_example_collection_t create_cb_adf_log(uint8_t num_examples, uint8_t num_actions, float reward_p); + prototype_example_collection_t create_simple_log( + uint8_t num_examples, uint8_t numeric_features, uint8_t string_features); private: VW::rand_state rng; diff --git a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc similarity index 71% rename from vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc rename to vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc index c247d3492e5..35547b0f43e 100644 --- a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc +++ b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc @@ -252,7 +252,8 @@ TEST(FlatbufferParser, SingleExample_MissingFeatureIndices) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(all.get())); VW::io_buf unused_buffer; - EXPECT_EQ(all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf), 8); + EXPECT_EQ(all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf), + VW::experimental::error_code::fb_parser_name_hash_missing); EXPECT_EQ(all->parser_runtime.example_parser->reader(all.get(), unused_buffer, examples), 0); auto example = all->parser_runtime.flat_converter->data()->example_obj_as_Example(); @@ -274,11 +275,23 @@ TEST(FlatbufferParser, SingleExample_MissingFeatureIndices) VW::finish_example(*all, *examples[0]); } +namespace vwtest +{ +template +constexpr FeatureSerialization get_feature_serialization() +{ + return test_audit_strings ? FeatureSerialization::ExcludeFeatureHash : FeatureSerialization::ExcludeFeatureNames; +} +} // namespace vwtest + template void run_parse_and_verify_test(VW::workspace& w, const root_prototype_t& root_obj) { + constexpr FeatureSerialization feature_serialization = vwtest::get_feature_serialization(); + flatbuffers::FlatBufferBuilder builder; - auto root = vwtest::create_example_root(builder, w, root_obj); + + auto root = vwtest::create_example_root(builder, w, root_obj); builder.FinishSizePrefixed(root); VW::io_buf buf; @@ -339,8 +352,8 @@ void run_parse_and_verify_test(VW::workspace& w, const root_prototype_t& root_ob examples.clear(); } - vwtest::verify_example_root(w, w.parser_runtime.flat_converter->data(), root_obj); - vwtest::verify_example_root(w, (std::vector)wrapped, root_obj); + vwtest::verify_example_root(w, w.parser_runtime.flat_converter->data(), root_obj); + vwtest::verify_example_root(w, (std::vector)wrapped, root_obj); for (size_t i = 0; i < wrapped.size(); i++) { @@ -351,18 +364,18 @@ void run_parse_and_verify_test(VW::workspace& w, const root_prototype_t& root_ob TEST(FlatbufferParser, ExampleCollection_Multiline) { - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit", "--cb_explore_adf")); example_data_generator data_gen; - auto prototype = data_gen.create_cb_adf_log(3, 4, 0.4f); + auto prototype = data_gen.create_cb_adf_log(2, 1, 0.4f); run_parse_and_verify_test(*all, prototype); } TEST(FlatbufferParser, MultiExample_Multiline) { - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit", "--cb_explore_adf")); flatbuffers::FlatBufferBuilder builder; @@ -389,7 +402,7 @@ TEST(FlatBufferParser, LabelSmokeTest_ContinuousLabel) using namespace vwtest; using example = vwtest::example; - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit")); example_data_generator datagen; example ex = {{datagen.create_namespace("U_a", 1, 1)}, @@ -403,7 +416,7 @@ TEST(FlatBufferParser, LabelSmokeTest_Slates) { using namespace vwtest; - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--slates")); + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit", "--slates")); example_data_generator datagen; // this is not the best way to describe it as it is technically labelled in the strictest sense @@ -437,224 +450,5 @@ TEST(FlatBufferParser, LabelSmokeTest_Slates) slates::slot(0, {{1, 0.6}, {0, 0.4}})}}}; - run_parse_and_verify_test(*all, labeled_example); -} - -namespace vwtest -{ -template -struct fb_type -{ -}; - -template <> -struct fb_type -{ - using type = VW::parsers::flatbuffer::Namespace; -}; - -template <> -struct fb_type -{ - using type = VW::parsers::flatbuffer::Example; -}; - -template <> -struct fb_type -{ - using type = VW::parsers::flatbuffer::MultiExample; -}; - -template <> -struct fb_type -{ - using type = VW::parsers::flatbuffer::ExampleCollection; -}; - -using union_t = void; - -template <> -struct fb_type -{ - using type = union_t; -}; -} // namespace vwtest - -template ::type> -void create_flatbuffer_and_validate(VW::workspace& w, const T& prototype) -{ - flatbuffers::FlatBufferBuilder builder; - - Offset buffer_offset = prototype.create_flatbuffer(builder, w); - builder.Finish(buffer_offset); - - const FB_t* fb_obj = GetRoot(builder.GetBufferPointer()); - - prototype.verify(w, fb_obj); -} - -template <> -void create_flatbuffer_and_validate(VW::workspace& w, const prototype_label_t& prototype) -{ - if (prototype.label_type == fb::Label_NONE) { return; } // there is no flatbuffer to create - - flatbuffers::FlatBufferBuilder builder; - - Offset buffer_offset = prototype.create_flatbuffer(builder, w); - builder.Finish(buffer_offset); - - switch (prototype.label_type) - { - case fb::Label_SimpleLabel: - case fb::Label_CBLabel: - case fb::Label_ContinuousLabel: - case fb::Label_Slates_Label: - { - prototype.verify(w, prototype.label_type, builder.GetBufferPointer()); - break; - } - case fb::Label_NONE: - { - break; - } - default: - { - THROW("Label type not currently supported for create_flatbuffer_and_validate"); - break; - } - } -} - -TEST(FlatBufferParser, ValidateTestAffordances_NoLabel) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); - - prototype_label_t label_prototype = vwtest::no_label(); - create_flatbuffer_and_validate(*all, label_prototype); -} - -TEST(FlatBufferParser, ValidateTestAffordances_SimpleLabel) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); - create_flatbuffer_and_validate(*all, simple_label(0.5, 1.0)); -} - -TEST(FlatBufferParser, ValidateTestAffordances_CBLabel) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); - create_flatbuffer_and_validate(*all, cb_label({1.5, 2, 0.25f})); -} - -TEST(FlatBufferParser, ValidateTestAffordances_ContinuousLabel) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); - - std::vector probabilities = {{1, 0.5f, 0.25}}; - - create_flatbuffer_and_validate(*all, continuous_label(probabilities)); -} - -TEST(FlatBufferParser, ValidateTestAffordances_Slates) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--slates")); - - std::vector probabilities = {{1, 0.5f}, {2, 0.25f}}; - - VW::slates::example_type types[] = { - VW::slates::example_type::UNSET, - VW::slates::example_type::ACTION, - VW::slates::example_type::SHARED, - VW::slates::example_type::SLOT, - }; - - for (VW::slates::example_type type : types) - { - create_flatbuffer_and_validate(*all, slates_label_raw(type, 0.5, true, 0.3, 1, probabilities)); - } -} - -TEST(FlatbufferParser, ValidateTestAffordances_Namespace) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); - - prototype_namespace_t ns_prototype = {"U_a", {{"a", 1.f}, {"b", 2.f}}}; - create_flatbuffer_and_validate(*all, ns_prototype); -} - -TEST(FlatbufferParser, ValidateTestAffordances_Example_Simple) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); - - prototype_example_t ex_prototype = {{ - {"U_a", {{"a", 1.f}, {"b", 2.f}}}, - {"U_b", {{"a", 3.f}, {"b", 4.f}}}, - }, - vwtest::simple_label(0.5, 1.0)}; - create_flatbuffer_and_validate(*all, ex_prototype); -} - -TEST(FlatbufferParser, ValidateTestAffordances_Example_Unlabeled) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); - - prototype_example_t ex_prototype = {{ - {"U_a", {{"a", 1.f}, {"b", 2.f}}}, - {"U_b", {{"a", 3.f}, {"b", 4.f}}}, - }}; - create_flatbuffer_and_validate(*all, ex_prototype); -} - -TEST(FlatbufferParser, ValidateTestAffordances_Example_CBShared) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); - - prototype_example_t ex_prototype = {{ - {"U_a", {{"a", 1.f}, {"b", 2.f}}}, - {"U_b", {{"a", 3.f}, {"b", 4.f}}}, - }, - vwtest::cb_label_shared(), "tag1"}; - create_flatbuffer_and_validate(*all, ex_prototype); -} - -TEST(FlatbufferParser, ValidateTestAffordances_Example_CB) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); - - prototype_example_t ex_prototype = {{ - {"T_a", {{"a", 5.f}, {"b", 6.f}}}, - {"T_b", {{"a", 7.f}, {"b", 8.f}}}, - }, - vwtest::cb_label({1, 1, 0.5f}), "tag1"}; - create_flatbuffer_and_validate(*all, ex_prototype); -} - -TEST(FlatbufferParser, ValidateTestAffordances_MultiExample) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); - - prototype_multiexample_t multiex_prototype = {{ - {{ - {"U_a", {{"a", 1.f}, {"b", 2.f}}}, - {"U_b", {{"a", 3.f}, {"b", 4.f}}}, - }, - vwtest::cb_label_shared(), "tag1"}, - { - { - {"T_a", {{"a", 5.f}, {"b", 6.f}}}, - {"T_b", {{"a", 7.f}, {"b", 8.f}}}, - }, - vwtest::cb_label({{1, 1, 0.5f}}), - }, - }}; - create_flatbuffer_and_validate(*all, multiex_prototype); -} - -TEST(FlatbufferParser, ValidateTestAffordances_ExampleCollectionMultiline) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); - - example_data_generator data_gen; - prototype_example_collection_t prototype = data_gen.create_cb_adf_log(2, 2, 0.5f); - - create_flatbuffer_and_validate(*all, prototype); + run_parse_and_verify_test(*all, labeled_example); } diff --git a/vowpalwabbit/fb_parser/tests/prototype_example.h b/vowpalwabbit/fb_parser/tests/prototype_example.h index be84a33f2d4..78c0647424c 100644 --- a/vowpalwabbit/fb_parser/tests/prototype_example.h +++ b/vowpalwabbit/fb_parser/tests/prototype_example.h @@ -10,8 +10,8 @@ #include "vw/fb_parser/parse_example_flatbuffer.h" #ifndef VWFB_BUILDERS_ONLY -#include -#include +# include +# include #endif namespace fb = VW::parsers::flatbuffer; @@ -39,11 +39,11 @@ struct prototype_example_t return count; } - template + template Offset create_flatbuffer(flatbuffers::FlatBufferBuilder& builder, VW::workspace& w) const { std::vector> fb_namespaces; - for (auto& ns : namespaces) { fb_namespaces.push_back(ns.create_flatbuffer<>(builder, w)); } + for (auto& ns : namespaces) { fb_namespaces.push_back(ns.create_flatbuffer(builder, w)); } Offset>> fb_namespaces_vector = builder.CreateVector(fb_namespaces); @@ -57,25 +57,25 @@ struct prototype_example_t } #ifndef VWFB_BUILDERS_ONLY - template + template void verify(VW::workspace& w, const fb::Example* e) const { for (size_t i = 0; i < namespaces.size(); i++) { - namespaces[i].verify(w, e->namespaces()->Get(i)); + namespaces[i].verify(w, e->namespaces()->Get(i)); } label.verify(w, e); } - template + template void verify(VW::workspace& w, const VW::example& e) const { EXPECT_EQ(e.indices.size(), count_indices()); for (size_t i = 0; i < namespaces.size(); i++) { - namespaces[i].verify(w, namespaces[i].feature_group, e); + namespaces[i].verify(w, namespaces[i].feature_group, e); } label.verify(w, e); @@ -87,11 +87,11 @@ struct prototype_multiexample_t { std::vector examples; - template + template Offset create_flatbuffer(flatbuffers::FlatBufferBuilder& builder, VW::workspace& w) const { std::vector> fb_examples; - for (auto& ex : examples) { fb_examples.push_back(ex.create_flatbuffer(builder, w)); } + for (auto& ex : examples) { fb_examples.push_back(ex.create_flatbuffer(builder, w)); } Offset>> fb_examples_vector = builder.CreateVector(fb_examples); @@ -99,20 +99,23 @@ struct prototype_multiexample_t } #ifndef VWFB_BUILDERS_ONLY - template + template void verify(VW::workspace& w, const fb::MultiExample* e) const { EXPECT_EQ(e->examples()->size(), examples.size()); - for (size_t i = 0; i < examples.size(); i++) { examples[i].verify(w, e->examples()->Get(i)); } + for (size_t i = 0; i < examples.size(); i++) + { + examples[i].verify(w, e->examples()->Get(i)); + } } - template + template void verify(VW::workspace& w, const VW::multi_ex& e) const { EXPECT_EQ(e.size(), examples.size()); - for (size_t i = 0; i < examples.size(); i++) { examples[i].verify(w, *e[i]); } + for (size_t i = 0; i < examples.size(); i++) { examples[i].verify(w, *e[i]); } } #endif }; @@ -125,16 +128,16 @@ struct prototype_example_collection_t std::vector multi_examples; bool is_multiline; - template + template Offset create_flatbuffer(flatbuffers::FlatBufferBuilder& builder, VW::workspace& w) const { std::vector> fb_examples; - for (auto& ex : examples) { fb_examples.push_back(ex.create_flatbuffer(builder, w)); } + for (auto& ex : examples) { fb_examples.push_back(ex.create_flatbuffer(builder, w)); } std::vector> fb_multi_examples; for (auto& ex : multi_examples) { - fb_multi_examples.push_back(ex.create_flatbuffer(builder, w)); + fb_multi_examples.push_back(ex.create_flatbuffer(builder, w)); } Offset>> fb_examples_vector = builder.CreateVector(fb_examples); @@ -144,34 +147,37 @@ struct prototype_example_collection_t } #ifndef VWFB_BUILDERS_ONLY - template + template void verify(VW::workspace& w, const fb::ExampleCollection* e) const { EXPECT_EQ(e->examples()->size(), examples.size()); EXPECT_EQ(e->multi_examples()->size(), multi_examples.size()); - for (size_t i = 0; i < examples.size(); i++) { examples[i].verify(w, e->examples()->Get(i)); } + for (size_t i = 0; i < examples.size(); i++) + { + examples[i].verify(w, e->examples()->Get(i)); + } for (size_t i = 0; i < multi_examples.size(); i++) { - multi_examples[i].verify(w, e->multi_examples()->Get(i)); + multi_examples[i].verify(w, e->multi_examples()->Get(i)); } } - template + template void verify_singleline(VW::workspace& w, const VW::multi_ex& e) const { EXPECT_EQ(is_multiline, false); - for (size_t i = 0; i < examples.size(); i++) { examples[i].verify(w, *e[i]); } + for (size_t i = 0; i < examples.size(); i++) { examples[i].verify(w, *e[i]); } } - template + template void verify_multiline(VW::workspace& w, const std::vector& e) const { EXPECT_EQ(is_multiline, true); - for (size_t i = 0; i < multi_examples.size(); i++) { multi_examples[i].verify(w, e[i]); } + for (size_t i = 0; i < multi_examples.size(); i++) { multi_examples[i].verify(w, e[i]); } } #endif }; diff --git a/vowpalwabbit/fb_parser/tests/prototype_example_root.h b/vowpalwabbit/fb_parser/tests/prototype_example_root.h index 3838f3db496..8102ce62f2b 100644 --- a/vowpalwabbit/fb_parser/tests/prototype_example_root.h +++ b/vowpalwabbit/fb_parser/tests/prototype_example_root.h @@ -7,8 +7,8 @@ #include "prototype_example.h" #ifndef VWFB_BUILDERS_ONLY -#include -#include +# include +# include #endif namespace fb = VW::parsers::flatbuffer; @@ -17,55 +17,55 @@ using namespace flatbuffers; namespace vwtest { -template +template inline Offset create_example_root( FlatBufferBuilder& builder, VW::workspace& vw, const prototype_example_t& example) { - auto fb_example = example.create_flatbuffer(builder, vw); + auto fb_example = example.create_flatbuffer(builder, vw); return fb::CreateExampleRoot(builder, fb::ExampleType_Example, fb_example.Union()); } #ifndef VWFB_BUILDERS_ONLY -template +template inline void verify_example_root(VW::workspace& vw, const fb::ExampleRoot* root, const prototype_example_t& expected) { EXPECT_EQ(root->example_obj_type(), fb::ExampleType_Example); auto example = root->example_obj_as_Example(); - expected.verify(vw, example); + expected.verify(vw, example); } -template +template inline void verify_example_root( VW::workspace& vw, std::vector examples, const prototype_example_t& expected) { EXPECT_EQ(examples.size(), 1); EXPECT_EQ(examples[0].size(), 1); - expected.verify(vw, *(examples[0][0])); + expected.verify(vw, *(examples[0][0])); } #endif -template +template inline Offset create_example_root( FlatBufferBuilder& builder, VW::workspace& vw, const prototype_multiexample_t& multiex) { - auto fb_multiex = multiex.create_flatbuffer(builder, vw); + auto fb_multiex = multiex.create_flatbuffer(builder, vw); return fb::CreateExampleRoot(builder, fb::ExampleType_MultiExample, fb_multiex.Union()); } #ifndef VWFB_BUILDERS_ONLY -template +template inline void verify_example_root( VW::workspace& vw, const fb::ExampleRoot* root, const prototype_multiexample_t& expected) { EXPECT_EQ(root->example_obj_type(), fb::ExampleType_MultiExample); auto multiex = root->example_obj_as_MultiExample(); - expected.verify(vw, multiex); + expected.verify(vw, multiex); } -template +template inline void verify_example_root( VW::workspace& vw, std::vector examples, const prototype_multiexample_t& expected) { @@ -73,30 +73,30 @@ inline void verify_example_root( EXPECT_EQ(examples.size(), 1 - expecting_none); EXPECT_EQ(examples[0].size(), expected.examples.size()); - expected.verify(vw, examples[0]); + expected.verify(vw, examples[0]); } #endif -template +template inline Offset create_example_root( FlatBufferBuilder& builder, VW::workspace& vw, const prototype_example_collection_t& collection) { - auto fb_collection = collection.create_flatbuffer(builder, vw); + auto fb_collection = collection.create_flatbuffer(builder, vw); return fb::CreateExampleRoot(builder, fb::ExampleType_ExampleCollection, fb_collection.Union()); } #ifndef VWFB_BUILDERS_ONLY -template +template inline void verify_example_root( VW::workspace& vw, const fb::ExampleRoot* root, const prototype_example_collection_t& expected) { EXPECT_EQ(root->example_obj_type(), fb::ExampleType_ExampleCollection); auto collection = root->example_obj_as_ExampleCollection(); - expected.verify(vw, collection); + expected.verify(vw, collection); } -template +template inline void verify_example_root( VW::workspace& vw, std::vector examples, const prototype_example_collection_t& expected) { @@ -106,12 +106,12 @@ inline void verify_example_root( if (expected.is_multiline) { EXPECT_EQ(examples.size(), expected.multi_examples.size()); - expected.verify_multiline(vw, examples); + expected.verify_multiline(vw, examples); } else { EXPECT_EQ(examples[0].size(), expected.examples.size()); - expected.verify_singleline(vw, examples[0]); + expected.verify_singleline(vw, examples[0]); } } #endif diff --git a/vowpalwabbit/fb_parser/tests/prototype_label.h b/vowpalwabbit/fb_parser/tests/prototype_label.h index 4a52f0ecdf2..3ed1447585a 100644 --- a/vowpalwabbit/fb_parser/tests/prototype_label.h +++ b/vowpalwabbit/fb_parser/tests/prototype_label.h @@ -11,8 +11,8 @@ #include "vw/fb_parser/parse_example_flatbuffer.h" #ifndef VWFB_BUILDERS_ONLY -#include -#include +# include +# include #endif namespace fb = VW::parsers::flatbuffer; @@ -86,7 +86,7 @@ struct prototype_label_t prototype_label_t no_label(); -prototype_label_t simple_label(float label, float weight, float initial = 0.f); +prototype_label_t simple_label(float label, float weight = 1.f, float initial = 0.f); prototype_label_t cb_label(std::vector costs, float weight = 1.0f); prototype_label_t cb_label(VW::cb_class single_class, float weight = 1.0f); diff --git a/vowpalwabbit/fb_parser/tests/prototype_namespace.h b/vowpalwabbit/fb_parser/tests/prototype_namespace.h index 5f446e0c085..072c016c0ee 100644 --- a/vowpalwabbit/fb_parser/tests/prototype_namespace.h +++ b/vowpalwabbit/fb_parser/tests/prototype_namespace.h @@ -9,8 +9,8 @@ #include "vw/fb_parser/parse_example_flatbuffer.h" #ifndef VWFB_BUILDERS_ONLY -#include -#include +# include +# include #endif namespace fb = VW::parsers::flatbuffer; @@ -19,6 +19,20 @@ using namespace flatbuffers; namespace vwtest { +enum FeatureSerialization +{ + ExcludeFeatureNames, + IncludeFeatureNames, + ExcludeFeatureHash +}; + +constexpr bool include_hashes(FeatureSerialization serialization) { return serialization != ExcludeFeatureHash; } + +constexpr bool include_feature_names(FeatureSerialization serialization) +{ + return serialization != ExcludeFeatureNames; +} + struct feature_t { feature_t(std::string name, float value) : has_name(true), name(name), value(value), hash(0) {} @@ -74,7 +88,7 @@ struct prototype_namespace_t uint64_t hash; uint8_t feature_group; - template + template Offset create_flatbuffer(FlatBufferBuilder& builder, VW::workspace& w) const { // When building these objects, we interpret the presence of a string as a signal to @@ -88,13 +102,17 @@ struct prototype_namespace_t for (const auto& f : features) { - if VW_STD17_CONSTEXPR (include_feature_names) + if VW_STD17_CONSTEXPR (include_feature_names(feature_serialization)) { feature_names.push_back(f.has_name ? builder.CreateString(f.name) : Offset() /* nullptr */); } + if VW_STD17_CONSTEXPR (include_hashes(feature_serialization)) + { + feature_hashes.push_back(f.has_name ? VW::hash_feature(w, f.name, hash) : f.hash); + } + feature_values.push_back(f.value); - feature_hashes.push_back(f.has_name ? VW::hash_feature(w, f.name, hash) : f.hash); } const auto name_offset = has_name ? builder.CreateString(name) : Offset(); @@ -108,9 +126,14 @@ struct prototype_namespace_t } #ifndef VWFB_BUILDERS_ONLY - template + template void verify(VW::workspace& w, const fb::Namespace* ns) const { + constexpr bool expect_feature_names = include_feature_names(feature_serialization); + constexpr bool expect_feature_hashes = include_hashes(feature_serialization); + static_assert( + expect_feature_names || expect_feature_hashes, "At least one of feature names or hashes must be included"); + uint64_t hash = this->hash; if (has_name) { @@ -123,30 +146,30 @@ struct prototype_namespace_t EXPECT_EQ(ns->hash(), feature_group); if VW_STD17_CONSTEXPR (expect_feature_names) { EXPECT_EQ(ns->feature_names()->size(), features.size()); } + if VW_STD17_CONSTEXPR (expect_feature_hashes) { EXPECT_EQ(ns->feature_hashes()->size(), features.size()); } EXPECT_EQ(ns->feature_values()->size(), features.size()); - EXPECT_EQ(ns->feature_hashes()->size(), features.size()); for (size_t i = 0; i < features.size(); i++) { - if (features[i].has_name) - { - EXPECT_EQ(ns->feature_names()->Get(i)->str(), features[i].name); - EXPECT_EQ(ns->feature_hashes()->Get(i), VW::hash_feature(w, features[i].name, hash)); - } - else - { - EXPECT_EQ(ns->feature_names()->Get(i), nullptr); - EXPECT_EQ(ns->feature_hashes()->Get(i), features[i].hash); - } + if VW_STD17_CONSTEXPR (expect_feature_names) { EXPECT_EQ(ns->feature_names()->Get(i)->str(), features[i].name); } + + const uint64_t expected_hash = + features[i].has_name ? VW::hash_feature(w, features[i].name, hash) : features[i].hash; + if VW_STD17_CONSTEXPR (expect_feature_hashes) { EXPECT_EQ(ns->feature_hashes()->Get(i), expected_hash); } EXPECT_EQ(ns->feature_values()->Get(i), features[i].value); } } - template + template void verify(VW::workspace& w, const size_t, const VW::example& e) const { + constexpr bool expect_feature_names = include_feature_names(feature_serialization); + constexpr bool expect_feature_hashes = include_hashes(feature_serialization); + static_assert( + expect_feature_names || expect_feature_hashes, "At least one of feature names or hashes must be included"); + uint64_t hash = this->hash; if (has_name) { hash = VW::hash_space(w, name); } @@ -175,13 +198,10 @@ struct prototype_namespace_t for (size_t i_f = extent.begin_index, i = 0; i_f < extent.end_index && i < this->features.size(); i_f++, i++) { auto& f = this->features[i]; - if (f.has_name) - { - EXPECT_EQ(features.indices[i_f], VW::hash_feature(w, f.name, hash)); + if VW_STD17_CONSTEXPR (expect_feature_names) { EXPECT_EQ(features.space_names[i_f].name, f.name); } - if VW_STD17_CONSTEXPR (expect_feature_names) { EXPECT_EQ(features.space_names[i_f].name, f.name); } - } - else { EXPECT_EQ(features.indices[i_f], f.hash); } + const uint64_t expected_hash = f.has_name ? VW::hash_feature(w, f.name, hash) : f.hash; + if VW_STD17_CONSTEXPR (expect_feature_hashes) { EXPECT_EQ(features.indices[i_f], expected_hash); } EXPECT_EQ(features.values[i_f], f.value); } diff --git a/vowpalwabbit/fb_parser/tests/prototype_typemappings.h b/vowpalwabbit/fb_parser/tests/prototype_typemappings.h new file mode 100644 index 00000000000..455eae702ff --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_typemappings.h @@ -0,0 +1,44 @@ +#include "prototype_example_root.h" +#include "vw/fb_parser/generated/example_generated.h" + +#pragma once + +namespace vwtest +{ +template +struct fb_type +{ +}; + +template <> +struct fb_type +{ + using type = VW::parsers::flatbuffer::Namespace; +}; + +template <> +struct fb_type +{ + using type = VW::parsers::flatbuffer::Example; +}; + +template <> +struct fb_type +{ + using type = VW::parsers::flatbuffer::MultiExample; +}; + +template <> +struct fb_type +{ + using type = VW::parsers::flatbuffer::ExampleCollection; +}; + +using union_t = void; + +template <> +struct fb_type +{ + using type = union_t; +}; +} // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/read_span_tests.cc b/vowpalwabbit/fb_parser/tests/read_span_tests.cc new file mode 100644 index 00000000000..3b2162ecfa9 --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/read_span_tests.cc @@ -0,0 +1,131 @@ + +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#include "example_data_generator.h" +#include "prototype_typemappings.h" +#include "vw/common/future_compat.h" +#include "vw/common/string_view.h" +#include "vw/core/constant.h" +#include "vw/core/error_constants.h" +#include "vw/core/vw.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" +#include "vw/test_common/test_common.h" + +#include +#include + +#include +#include + +USE_PROTOTYPE_MNEMONICS + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; +// using namespace vwtest; + +namespace vwtest +{ +inline void verify_multi_ex(VW::workspace& w, const prototype_example_t& single_ex, VW::multi_ex& multi_ex) +{ + ASSERT_EQ(multi_ex.size(), 1); + + prototype_multiexample_t validator; + validator.examples.push_back(single_ex); + + validator.verify(w, multi_ex); +} + +inline void verify_multi_ex(VW::workspace& w, const prototype_multiexample_t& validator, VW::multi_ex& multi_ex) +{ + validator.verify(w, multi_ex); +} + +inline void verify_multi_ex( + VW::workspace& w, const prototype_example_collection_t& ex_collection, const VW::multi_ex& multi_ex) +{ + // we expect ex_collection to either have a set of singleexamples, or a single multiexample + if (ex_collection.examples.size() > 0) + { + ASSERT_EQ(multi_ex.size(), ex_collection.examples.size()); + ASSERT_EQ(ex_collection.multi_examples.size(), 0); + + prototype_multiexample_t validator = {ex_collection.examples}; + validator.verify(w, multi_ex); + } + else + { + ASSERT_EQ(ex_collection.multi_examples.size(), 1); + ASSERT_EQ(multi_ex.size(), ex_collection.multi_examples[0].examples.size()); + ASSERT_EQ(ex_collection.examples.size(), 0); + + ex_collection.multi_examples[0].verify(w, multi_ex); + } +} +} // namespace vwtest + +template ::type> +void create_flatbuffer_span_and_validate(VW::workspace& w, const T& prototype) +{ + // This is what we expect to see when we use read_span_flatbuffer, since this is intended + // to be used for inference, and we would prefer not to force consumers of the API to have + // to hash the input feature names manually. + constexpr vwtest::FeatureSerialization serialization = vwtest::FeatureSerialization::ExcludeFeatureHash; + + VW::example_factory_t ex_fac = [&w]() -> VW::example& { return VW::get_unused_example(&w); }; + + FlatBufferBuilder builder; + Offset example_root = vwtest::create_example_root(builder, w, prototype); + + builder.FinishSizePrefixed(example_root); + + const uint8_t* buffer = builder.GetBufferPointer(); + flatbuffers::uoffset_t size = builder.GetSize(); + + VW::multi_ex parsed_examples; + VW::parsers::flatbuffer::read_span_flatbuffer(&w, buffer, size, ex_fac, parsed_examples); + + verify_multi_ex(w, prototype, parsed_examples); +} + +TEST(FlatbufferParser, ReadSpanFlatbuffer_SingleExample) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_example_t prototype = { + {data_gen.create_namespace("A", 3, 4), data_gen.create_namespace("B", 2, 5)}, vwtest::simple_label(1.0f)}; + + create_flatbuffer_span_and_validate(*all, prototype); +} + +TEST(FlatbufferParser, ReadSpanFlatbuffer_MultiExample) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_multiexample_t prototype = data_gen.create_cb_adf_example(3, 1, "tag"); + + create_flatbuffer_span_and_validate(*all, prototype); +} + +TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionSinglelines) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_example_collection_t prototype = data_gen.create_simple_log(3, 3, 4); + + create_flatbuffer_span_and_validate(*all, prototype); +} + +TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionMultiline) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_example_collection_t prototype = data_gen.create_cb_adf_log(1, 3, 4); + + create_flatbuffer_span_and_validate(*all, prototype); +}