diff --git a/vowpalwabbit/core/include/vw/core/error_data.h b/vowpalwabbit/core/include/vw/core/error_data.h index aea7e246b32..5865de563c7 100644 --- a/vowpalwabbit/core/include/vw/core/error_data.h +++ b/vowpalwabbit/core/include/vw/core/error_data.h @@ -25,6 +25,9 @@ ERROR_CODE_DEFINITION( ERROR_CODE_DEFINITION( 13, fb_parser_size_mismatch_ft_names_ft_values, "Size of feature names and feature values do not match. ") ERROR_CODE_DEFINITION(14, unknown_label_type, "Label type in Flatbuffer not understood. ") +ERROR_CODE_DEFINITION(15, fb_parser_span_misaligned, "Input Flatbuffer span is not aligned to an 8-byte boundary. ") +ERROR_CODE_DEFINITION(16, fb_parser_span_length_mismatch, "Input Flatbuffer span does not match flatbuffer size prefix. ") + // TODO: This is temporary until we switch to the new error handling mechanism. ERROR_CODE_DEFINITION(10000, vw_exception, "vw_exception: ") diff --git a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h index b7f45a7c92a..176d5bfbb75 100644 --- a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h +++ b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h @@ -27,8 +27,8 @@ namespace flatbuffer int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples); -bool read_span_flatbuffer( - VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples, example_sink_f example_sink = nullptr); +int read_span_flatbuffer( + VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples, example_sink_f example_sink = nullptr, VW::experimental::api_status* status = nullptr); class parser { diff --git a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc index fbbcccfdaaf..b1c882e055f 100644 --- a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc +++ b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc @@ -45,8 +45,8 @@ int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl return static_cast(status.get_error_code() == VW::experimental::error_code::success); } -bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, - VW::multi_ex& examples, example_sink_f example_sink) +int read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, + VW::multi_ex& examples, example_sink_f example_sink, VW::experimental::api_status* status) { int a = 0; a++; @@ -64,7 +64,6 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length // thus context.size() = sizeof(length) + length io_buf unused; - // TODO: How do we report errors out of here? (This is a general API problem with the parsers) size_t address = reinterpret_cast(span); if (address % 8 != 0) { @@ -72,8 +71,8 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length sstream << "fb_parser error: flatbuffer data not aligned to 8 bytes" << std::endl; sstream << " span => @" << std::hex << address << std::dec << " % " << 8 << " = " << address % 8 << " (vs desired = " << 0 << ")"; - THROW(sstream.str()); - return false; + + RETURN_ERROR_LS(status, fb_parser_span_misaligned) << sstream.str(); } flatbuffers::uoffset_t flatbuffer_object_size = @@ -84,8 +83,8 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length sstream << "fb_parser error: flatbuffer size prefix does not match actual size" << std::endl; sstream << " span => @" << std::hex << address << std::dec << " size_prefix = " << flatbuffer_object_size << " length = " << length; - THROW(sstream.str()); - return false; + + RETURN_ERROR_LS(status, fb_parser_span_length_mismatch) << sstream.str(); } VW::multi_ex temp_ex; @@ -108,10 +107,8 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length else { temp_ex.push_back(&example_factory()); } bool has_more = true; - VW::experimental::api_status status; - do { - switch (all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, &status)) + switch (int result = all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, status)) { case VW::experimental::error_code::success: has_more = true; @@ -120,10 +117,7 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length has_more = false; break; default: - std::stringstream sstream; - sstream << "Error parsing examples: " << std::endl; - THROW(sstream.str()); - return false; + RETURN_IF_FAIL(result); } has_more &= !temp_ex[0]->is_newline; @@ -135,7 +129,7 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length } } while (has_more); - return true; + return VW::experimental::error_code::success; } const VW::parsers::flatbuffer::ExampleRoot* parser::data() { return _data; } @@ -562,6 +556,7 @@ int parser::parse_flat_label( break; } case Label_NONE: + case Label_no_label: break; default: if (_active_collection && _active_multi_ex) diff --git a/vowpalwabbit/fb_parser/tests/example_data_generator.h b/vowpalwabbit/fb_parser/tests/example_data_generator.h index b474d3b0c44..60e5dbd856d 100644 --- a/vowpalwabbit/fb_parser/tests/example_data_generator.h +++ b/vowpalwabbit/fb_parser/tests/example_data_generator.h @@ -5,17 +5,25 @@ #pragma once #include "flatbuffers/flatbuffers.h" +#include "vw/fb_parser/generated/example_generated.h" + #include "prototype_example.h" #include "prototype_example_root.h" #include "prototype_label.h" #include "prototype_namespace.h" #include "vw/common/hash.h" #include "vw/common/random.h" +#include "vw/common/future_compat.h" + +#include "vw/core/error_constants.h" #include USE_PROTOTYPE_MNEMONICS_EX +using namespace flatbuffers; +namespace fb = VW::parsers::flatbuffer; + namespace vwtest { @@ -40,8 +48,86 @@ class example_data_generator prototype_example_collection_t create_simple_log( uint8_t num_examples, uint8_t numeric_features, uint8_t string_features); +public: + enum NamespaceErrors + { + BAD_NAMESPACE_NO_ERROR = 0, + BAD_NAMESPACE_NAME_HASH_MISSING = 1, + BAD_NAMESPACE_FEATURE_VALUES_MISSING = 2, + BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH = 4, + BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH = 8, + BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING = 16, + }; + + template + Offset create_bad_namespace(FlatBufferBuilder& builder, VW::workspace& w); + private: VW::rand_state rng; }; +template +Offset example_data_generator::create_bad_namespace(FlatBufferBuilder& builder, VW::workspace& w) +{ + prototype_namespace_t ns = create_namespace("BadNamespace", 1, 1); + if VW_STD17_CONSTEXPR (errors == NamespaceErrors::BAD_NAMESPACE_NO_ERROR) return ns.create_flatbuffer(builder, w); + + constexpr bool include_ns_name_hash = !(errors & NamespaceErrors::BAD_NAMESPACE_NAME_HASH_MISSING); + constexpr bool include_feature_values = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_MISSING); + + constexpr bool include_feature_hashes = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING); + constexpr bool skip_a_feature_hash = (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH); + static_assert(!skip_a_feature_hash || include_feature_hashes, "Cannot skip a feature hash if they are not included"); + + constexpr bool include_feature_names = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING); + constexpr bool skip_a_feature_name = (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH); + static_assert(!skip_a_feature_name || include_feature_names, "Cannot skip a feature name if they are not included"); + + std::vector> feature_names; + std::vector feature_values; + std::vector feature_hashes; + + for (size_t i = 0; i < ns.features.size(); i++) + { + const auto& f = ns.features[i]; + + if VW_STD17_CONSTEXPR (include_feature_names && (!skip_a_feature_name || i == 1)) + { + feature_names.push_back(builder.CreateString(f.name)); + } + + if VW_STD17_CONSTEXPR (include_feature_values) feature_values.push_back(f.value); + + if VW_STD17_CONSTEXPR (include_feature_hashes && (!skip_a_feature_hash || i == 0)) + { + feature_hashes.push_back(f.hash); + } + } + + Offset name_offset = Offset(); + if (include_ns_name_hash) + { + name_offset = builder.CreateString(ns.name); + } + + // This function attempts to, insofar as possible, generate a layout that looks like it could have + // been created using the normal serialization code: In this case, that means that the strings for + // the feature names are serialized into the builder before a call to CreateNamespaceDirect is made, + // which is where the feature_names vector is allocated. + Offset>> feature_names_offset = include_feature_names ? builder.CreateVector(feature_names) : Offset>>(); + Offset> feature_values_offset = include_feature_values ? builder.CreateVector(feature_values) : Offset>(); + Offset> feature_hashes_offset = include_feature_hashes ? builder.CreateVector(feature_hashes) : Offset>(); + + fb::NamespaceBuilder ns_builder(builder); + + if VW_STD17_CONSTEXPR (include_ns_name_hash) ns_builder.add_full_hash(VW::hash_space(w, ns.name)); + if VW_STD17_CONSTEXPR (include_feature_hashes) ns_builder.add_feature_hashes(feature_hashes_offset); + if VW_STD17_CONSTEXPR (include_feature_values) ns_builder.add_feature_values(feature_values_offset); + if VW_STD17_CONSTEXPR (include_feature_names) ns_builder.add_feature_names(feature_names_offset); + if VW_STD17_CONSTEXPR (include_ns_name_hash) ns_builder.add_name(name_offset); + + ns_builder.add_hash(ns.feature_group); + return ns_builder.Finish(); +} + } // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc index 35547b0f43e..3bf849f442e 100644 --- a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc +++ b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc @@ -9,6 +9,7 @@ #include "prototype_namespace.h" #include "vw/common/future_compat.h" #include "vw/common/string_view.h" +#include "vw/core/api_status.h" #include "vw/core/constant.h" #include "vw/core/error_constants.h" #include "vw/core/example.h" diff --git a/vowpalwabbit/fb_parser/tests/prototype_typemappings.h b/vowpalwabbit/fb_parser/tests/prototype_typemappings.h index 455eae702ff..1a5ca09e21b 100644 --- a/vowpalwabbit/fb_parser/tests/prototype_typemappings.h +++ b/vowpalwabbit/fb_parser/tests/prototype_typemappings.h @@ -20,18 +20,24 @@ template <> struct fb_type { using type = VW::parsers::flatbuffer::Example; + + constexpr static fb::ExampleType root_type = fb::ExampleType::ExampleType_Example; }; template <> struct fb_type { using type = VW::parsers::flatbuffer::MultiExample; + + constexpr static fb::ExampleType root_type = fb::ExampleType::ExampleType_MultiExample; }; template <> struct fb_type { using type = VW::parsers::flatbuffer::ExampleCollection; + + constexpr static fb::ExampleType root_type = fb::ExampleType::ExampleType_ExampleCollection; }; using union_t = void; diff --git a/vowpalwabbit/fb_parser/tests/read_span_tests.cc b/vowpalwabbit/fb_parser/tests/read_span_tests.cc index acbee2f529d..f1e62d98789 100644 --- a/vowpalwabbit/fb_parser/tests/read_span_tests.cc +++ b/vowpalwabbit/fb_parser/tests/read_span_tests.cc @@ -3,8 +3,8 @@ // individual contributors. All rights reserved. Released under a BSD (revised) // license as described in the file LICENSE. -#include "example_data_generator.h" #include "prototype_typemappings.h" +#include "example_data_generator.h" #include "vw/common/future_compat.h" #include "vw/common/string_view.h" #include "vw/core/constant.h" @@ -131,3 +131,91 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionMultiline) create_flatbuffer_span_and_validate(*all, prototype); } + +template +void finish_flatbuffer_and_expect_error(FlatBufferBuilder& builder, Offset root, VW::workspace& w) +{ + VW::example_factory_t ex_fac = [&w]() -> VW::example& { return VW::get_unused_example(&w); }; + + builder.FinishSizePrefixed(root); + + const uint8_t* buffer = builder.GetBufferPointer(); + flatbuffers::uoffset_t size = builder.GetSize(); + + VW::multi_ex parsed_examples; + EXPECT_EQ(VW::parsers::flatbuffer::read_span_flatbuffer(&w, buffer, size, ex_fac, parsed_examples), error_code); +} + +using namespace_factory_f = std::function(FlatBufferBuilder&, VW::workspace&)>; + +Offset create_bad_ns_root_example(FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac) +{ + std::vector> namespaces = { ns_fac(builder, w) }; + + Offset<> label_offset = fb::Createno_label(builder).Union(); + return fb::CreateExample(builder, builder.CreateVector(namespaces), fb::Label_no_label, label_offset); +} + +Offset create_bad_ns_root_multiex(FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac) +{ + std::vector> examples = { create_bad_ns_root_example(builder, w, ns_fac) }; + + return fb::CreateMultiExample(builder, builder.CreateVector(examples)); +} + +template ::type> +using builder_f = Offset (*)(FlatBufferBuilder&, VW::workspace&, namespace_factory_f); + +template +Offset create_bad_ns_root_collection(FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac) +{ + if VW_STD17_CONSTEXPR (multiline) + { + auto inner_examples = { create_bad_ns_root_multiex(builder, w, ns_fac) }; + return fb::CreateExampleCollection(builder, builder.CreateVector(std::vector>()), builder.CreateVector(inner_examples), multiline); + } + else + { + auto inner_examples = { create_bad_ns_root_example(builder, w, ns_fac) }; + return fb::CreateExampleCollection(builder, builder.CreateVector(inner_examples), builder.CreateVector(std::vector>()), multiline); + } +} + +template +void create_flatbuffer_span_and_expect_error(VW::workspace& w, namespace_factory_f ns_fac, builder_f root_builder) +{ + FlatBufferBuilder builder; + Offset<> data_obj = root_builder(builder, w, ns_fac).Union(); + + Offset root_obj = fb::CreateExampleRoot(builder, root_type, data_obj); + + finish_flatbuffer_and_expect_error(builder, root_obj, w); +} + +using NamespaceErrors = vwtest::example_data_generator::NamespaceErrors; +template +void run_bad_namespace_test(VW::workspace& w) +{ + vwtest::example_data_generator data_gen; + + static_assert(errors != NamespaceErrors::BAD_NAMESPACE_NO_ERROR, "This test is intended to test bad namespaces"); + namespace_factory_f ns_fac = [&data_gen](FlatBufferBuilder& builder, VW::workspace& w) -> Offset { + return data_gen.create_bad_namespace(builder, w); + }; + + create_flatbuffer_span_and_expect_error(w, ns_fac, &create_bad_ns_root_example); + create_flatbuffer_span_and_expect_error(w, ns_fac, &create_bad_ns_root_multiex); + + //create_flatbuffer_span_and_expect_error(w, ns_fac, &create_bad_ns_root_collection); + + //create_flatbuffer_span_and_expect_error(w, ns_fac, &create_bad_ns_root_collection); +} + +TEST(FlatbufferParser, BadNamespace_FeatureValuesMissing) +{ + namespace err = VW::experimental::error_code; + + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + run_bad_namespace_test(*all); +}