Skip to content

Commit

Permalink
feat: Add additional error reporting to read_span_flatbuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
lokitoth committed Feb 14, 2024
1 parent 5df10a5 commit b7c7748
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 18 deletions.
3 changes: 3 additions & 0 deletions vowpalwabbit/core/include/vw/core/error_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ ERROR_CODE_DEFINITION(
// Each entry is (numeric code, identifier, default message text).
ERROR_CODE_DEFINITION(
13, fb_parser_size_mismatch_ft_names_ft_values, "Size of feature names and feature values do not match. ")
ERROR_CODE_DEFINITION(14, unknown_label_type, "Label type in Flatbuffer not understood. ")
// Returned by read_span_flatbuffer when the address of the input span is not a multiple of 8.
ERROR_CODE_DEFINITION(15, fb_parser_span_misaligned, "Input Flatbuffer span is not aligned to an 8-byte boundary. ")
// Returned by read_span_flatbuffer when the caller-supplied length disagrees with the size prefix in the buffer.
ERROR_CODE_DEFINITION(16, fb_parser_span_length_mismatch, "Input Flatbuffer span does not match flatbuffer size prefix. ")


// TODO: This is temporary until we switch to the new error handling mechanism.
ERROR_CODE_DEFINITION(10000, vw_exception, "vw_exception: ")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ namespace flatbuffer
int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples);


bool read_span_flatbuffer(
VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples, example_sink_f example_sink = nullptr);
/// Parses examples out of a size-prefixed flatbuffer held in an in-memory span.
///
/// @param all              Workspace whose flatbuffer converter performs the parsing.
/// @param span             Start of the buffer. Its address must be a multiple of 8, otherwise
///                         fb_parser_span_misaligned is returned.
/// @param length           Size of the span as seen by the caller. It is checked against the size prefix
///                         stored in the buffer; a mismatch yields fb_parser_span_length_mismatch.
/// @param example_factory  Called to obtain the example objects that parsed data is written into.
/// @param examples         Output collection for the parsed examples.
/// @param example_sink     Optional (defaults to nullptr). NOTE(review): its exact role is not visible from
///                         this declaration - confirm against the definition.
/// @param status           Optional (defaults to nullptr). Handed to the error-reporting macros and to the
///                         underlying parser so a failure description can be recorded.
/// @return VW::experimental::error_code::success when parsing completes; otherwise one of the codes above or
///         the code reported by the underlying parser.
int read_span_flatbuffer(
VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples, example_sink_f example_sink = nullptr, VW::experimental::api_status* status = nullptr);

class parser
{
Expand Down
25 changes: 10 additions & 15 deletions vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl
return static_cast<int>(status.get_error_code() == VW::experimental::error_code::success);
}

bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory,
VW::multi_ex& examples, example_sink_f example_sink)
int read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory,
VW::multi_ex& examples, example_sink_f example_sink, VW::experimental::api_status* status)
{
int a = 0;
a++;
Expand All @@ -64,16 +64,15 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length
// thus context.size() = sizeof(length) + length
io_buf unused;

// TODO: How do we report errors out of here? (This is a general API problem with the parsers)
size_t address = reinterpret_cast<size_t>(span);
if (address % 8 != 0)
{
std::stringstream sstream;
sstream << "fb_parser error: flatbuffer data not aligned to 8 bytes" << std::endl;
sstream << " span => @" << std::hex << address << std::dec << " % " << 8 << " = " << address % 8
<< " (vs desired = " << 0 << ")";
THROW(sstream.str());
return false;

RETURN_ERROR_LS(status, fb_parser_span_misaligned) << sstream.str();
}

flatbuffers::uoffset_t flatbuffer_object_size =
Expand All @@ -84,8 +83,8 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length
sstream << "fb_parser error: flatbuffer size prefix does not match actual size" << std::endl;
sstream << " span => @" << std::hex << address << std::dec << " size_prefix = " << flatbuffer_object_size
<< " length = " << length;
THROW(sstream.str());
return false;

RETURN_ERROR_LS(status, fb_parser_span_length_mismatch) << sstream.str();
}

VW::multi_ex temp_ex;
Expand All @@ -108,10 +107,8 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length
else { temp_ex.push_back(&example_factory()); }

bool has_more = true;
VW::experimental::api_status status;

do {
switch (all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, &status))
switch (int result = all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, status))
{
case VW::experimental::error_code::success:
has_more = true;
Expand All @@ -120,10 +117,7 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length
has_more = false;
break;
default:
std::stringstream sstream;
sstream << "Error parsing examples: " << std::endl;
THROW(sstream.str());
return false;
RETURN_IF_FAIL(result);
}

has_more &= !temp_ex[0]->is_newline;
Expand All @@ -135,7 +129,7 @@ bool read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length
}
} while (has_more);

return true;
return VW::experimental::error_code::success;
}

const VW::parsers::flatbuffer::ExampleRoot* parser::data() { return _data; }
Expand Down Expand Up @@ -562,6 +556,7 @@ int parser::parse_flat_label(
break;
}
case Label_NONE:
case Label_no_label:
break;
default:
if (_active_collection && _active_multi_ex)
Expand Down
86 changes: 86 additions & 0 deletions vowpalwabbit/fb_parser/tests/example_data_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,25 @@
#pragma once

#include "flatbuffers/flatbuffers.h"
#include "vw/fb_parser/generated/example_generated.h"

#include "prototype_example.h"
#include "prototype_example_root.h"
#include "prototype_label.h"
#include "prototype_namespace.h"
#include "vw/common/hash.h"
#include "vw/common/random.h"
#include "vw/common/future_compat.h"

#include "vw/core/error_constants.h"

#include <vector>

USE_PROTOTYPE_MNEMONICS_EX

using namespace flatbuffers;
namespace fb = VW::parsers::flatbuffer;

namespace vwtest
{

Expand All @@ -40,8 +48,86 @@ class example_data_generator
prototype_example_collection_t create_simple_log(
uint8_t num_examples, uint8_t numeric_features, uint8_t string_features);

public:
// Flags selecting which defect(s) create_bad_namespace<>() injects into the serialized namespace.
// The values are distinct bits and are tested with bitwise AND, so this must stay an unscoped enum.
enum NamespaceErrors
{
// No defect: create_bad_namespace falls back to the regular serialization path.
BAD_NAMESPACE_NO_ERROR = 0,
// Omit both the namespace name and its full hash.
BAD_NAMESPACE_NAME_HASH_MISSING = 1,
// Omit the feature_values vector.
BAD_NAMESPACE_FEATURE_VALUES_MISSING = 2,
// Emit a hash only for the feature at index 0, so hashes and values differ in count.
BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH = 4,
// Emit a name only for the feature at index 1, so names and values differ in count.
BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH = 8,
// Omit both the feature_hashes and the feature_names vectors.
BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING = 16,
};

template <NamespaceErrors errors = NamespaceErrors::BAD_NAMESPACE_NO_ERROR>
Offset<fb::Namespace> create_bad_namespace(FlatBufferBuilder& builder, VW::workspace& w);

private:
VW::rand_state rng;
};

/// Serializes a namespace named "BadNamespace" into `builder`, deliberately leaving out or truncating the
/// fields selected by `errors` so that parser error paths can be exercised.
///
/// @tparam errors  Bitwise combination of NamespaceErrors flags. BAD_NAMESPACE_NO_ERROR serializes the
///                 namespace through the regular prototype_namespace_t::create_flatbuffer path.
/// @param builder  Flatbuffer builder that receives the strings, vectors and the Namespace table.
/// @param w        Workspace used to hash the namespace name.
/// @return Offset of the finished fb::Namespace table inside `builder`.
template <example_data_generator::NamespaceErrors errors>
Offset<fb::Namespace> example_data_generator::create_bad_namespace(FlatBufferBuilder& builder, VW::workspace& w)
{
  // NOTE(review): the two trailing arguments are presumably the numeric / string feature counts (as in
  // create_simple_log), which would give two features in total - confirm against create_namespace.
  prototype_namespace_t ns = create_namespace("BadNamespace", 1, 1);
  if VW_STD17_CONSTEXPR (errors == NamespaceErrors::BAD_NAMESPACE_NO_ERROR) return ns.create_flatbuffer(builder, w);

  constexpr bool include_ns_name_hash = (errors & NamespaceErrors::BAD_NAMESPACE_NAME_HASH_MISSING) == 0;
  constexpr bool include_feature_values = (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_MISSING) == 0;

  constexpr bool include_feature_hashes =
      (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING) == 0;
  constexpr bool skip_a_feature_hash =
      (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH) != 0;
  static_assert(!skip_a_feature_hash || include_feature_hashes, "Cannot skip a feature hash if they are not included");

  constexpr bool include_feature_names =
      (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING) == 0;
  constexpr bool skip_a_feature_name = (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH) != 0;
  static_assert(!skip_a_feature_name || include_feature_names, "Cannot skip a feature name if they are not included");

  std::vector<Offset<String>> feature_names;
  std::vector<float> feature_values;
  std::vector<uint64_t> feature_hashes;

  for (size_t i = 0; i < ns.features.size(); i++)
  {
    const auto& f = ns.features[i];

    // The loop index is a runtime value and therefore must not take part in an `if constexpr`
    // condition: under C++17 that makes every instantiation with a *_MISMATCH flag ill-formed.
    // Only the template-derived flag is tested at compile time; the index test is a plain `if`.
    if VW_STD17_CONSTEXPR (include_feature_names)
    {
      // In mismatch mode only the feature at index 1 keeps its name.
      if (!skip_a_feature_name || i == 1) { feature_names.push_back(builder.CreateString(f.name)); }
    }

    if VW_STD17_CONSTEXPR (include_feature_values) feature_values.push_back(f.value);

    if VW_STD17_CONSTEXPR (include_feature_hashes)
    {
      // In mismatch mode only the feature at index 0 keeps its hash.
      if (!skip_a_feature_hash || i == 0) { feature_hashes.push_back(f.hash); }
    }
  }

  Offset<String> name_offset = Offset<String>();
  if VW_STD17_CONSTEXPR (include_ns_name_hash) { name_offset = builder.CreateString(ns.name); }

  // This function attempts to, insofar as possible, generate a layout that looks like it could have
  // been created using the normal serialization code: In this case, that means that the strings for
  // the feature names are serialized into the builder before a call to CreateNamespaceDirect is made,
  // which is where the feature_names vector is allocated.
  Offset<Vector<Offset<String>>> feature_names_offset =
      include_feature_names ? builder.CreateVector(feature_names) : Offset<Vector<Offset<String>>>();
  Offset<Vector<float>> feature_values_offset =
      include_feature_values ? builder.CreateVector(feature_values) : Offset<Vector<float>>();
  Offset<Vector<uint64_t>> feature_hashes_offset =
      include_feature_hashes ? builder.CreateVector(feature_hashes) : Offset<Vector<uint64_t>>();

  // All strings and vectors are created above; only scalar/offset fields are added while the table is open.
  fb::NamespaceBuilder ns_builder(builder);

  if VW_STD17_CONSTEXPR (include_ns_name_hash) ns_builder.add_full_hash(VW::hash_space(w, ns.name));
  if VW_STD17_CONSTEXPR (include_feature_hashes) ns_builder.add_feature_hashes(feature_hashes_offset);
  if VW_STD17_CONSTEXPR (include_feature_values) ns_builder.add_feature_values(feature_values_offset);
  if VW_STD17_CONSTEXPR (include_feature_names) ns_builder.add_feature_names(feature_names_offset);
  if VW_STD17_CONSTEXPR (include_ns_name_hash) ns_builder.add_name(name_offset);

  ns_builder.add_hash(ns.feature_group);
  return ns_builder.Finish();
}

} // namespace vwtest
1 change: 1 addition & 0 deletions vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "prototype_namespace.h"
#include "vw/common/future_compat.h"
#include "vw/common/string_view.h"
#include "vw/core/api_status.h"
#include "vw/core/constant.h"
#include "vw/core/error_constants.h"
#include "vw/core/example.h"
Expand Down
6 changes: 6 additions & 0 deletions vowpalwabbit/fb_parser/tests/prototype_typemappings.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,24 @@ template <>
struct fb_type<prototype_example_t>
{
  // ExampleRoot discriminator value that goes with `type`.
  static constexpr fb::ExampleType root_type = fb::ExampleType::ExampleType_Example;

  // Generated flatbuffer table that a prototype_example_t maps to.
  using type = VW::parsers::flatbuffer::Example;
};

template <>
struct fb_type<prototype_multiexample_t>
{
  // ExampleRoot discriminator value that goes with `type`.
  static constexpr fb::ExampleType root_type = fb::ExampleType::ExampleType_MultiExample;

  // Generated flatbuffer table that a prototype_multiexample_t maps to.
  using type = VW::parsers::flatbuffer::MultiExample;
};

template <>
struct fb_type<prototype_example_collection_t>
{
  // ExampleRoot discriminator value that goes with `type`.
  static constexpr fb::ExampleType root_type = fb::ExampleType::ExampleType_ExampleCollection;

  // Generated flatbuffer table that a prototype_example_collection_t maps to.
  using type = VW::parsers::flatbuffer::ExampleCollection;
};

using union_t = void;
Expand Down
90 changes: 89 additions & 1 deletion vowpalwabbit/fb_parser/tests/read_span_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#include "example_data_generator.h"
#include "prototype_typemappings.h"
#include "example_data_generator.h"
#include "vw/common/future_compat.h"
#include "vw/common/string_view.h"
#include "vw/core/constant.h"
Expand Down Expand Up @@ -131,3 +131,91 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionMultiline)

create_flatbuffer_span_and_validate(*all, prototype);
}

// Finalizes `builder` as a size-prefixed flatbuffer rooted at `root`, hands the resulting bytes to
// read_span_flatbuffer, and expects the call to return `error_code`.
template <int error_code>
void finish_flatbuffer_and_expect_error(FlatBufferBuilder& builder, Offset<fb::ExampleRoot> root, VW::workspace& w)
{
  builder.FinishSizePrefixed(root);

  const uint8_t* const span_begin = builder.GetBufferPointer();
  const flatbuffers::uoffset_t span_length = builder.GetSize();

  // Examples needed by the parser are drawn from the workspace.
  VW::example_factory_t make_example = [&w]() -> VW::example& { return VW::get_unused_example(&w); };

  VW::multi_ex parsed_examples;
  const auto result =
      VW::parsers::flatbuffer::read_span_flatbuffer(&w, span_begin, span_length, make_example, parsed_examples);
  EXPECT_EQ(result, error_code);
}

// Callback that serializes one fb::Namespace into the supplied builder and returns its offset.
using namespace_factory_f = std::function<Offset<fb::Namespace>(FlatBufferBuilder&, VW::workspace&)>;

// Serializes an fb::Example carrying no label and exactly one namespace, the one produced by `ns_fac`.
Offset<fb::Example> create_bad_ns_root_example(FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac)
{
  std::vector<Offset<fb::Namespace>> ns_offsets;
  ns_offsets.push_back(ns_fac(builder, w));

  // The label table and the namespace vector are both completed before the Example table is started.
  const Offset<> no_label = fb::Createno_label(builder).Union();
  const auto ns_vector = builder.CreateVector(ns_offsets);

  return fb::CreateExample(builder, ns_vector, fb::Label_no_label, no_label);
}

// Serializes an fb::MultiExample wrapping a single example built by create_bad_ns_root_example.
Offset<fb::MultiExample> create_bad_ns_root_multiex(FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac)
{
  std::vector<Offset<fb::Example>> example_offsets;
  example_offsets.push_back(create_bad_ns_root_example(builder, w, ns_fac));

  const auto examples_vector = builder.CreateVector(example_offsets);
  return fb::CreateMultiExample(builder, examples_vector);
}

// Pointer to a function that serializes a root object into a builder using a namespace factory.
// T is the prototype type; FB_t defaults to the flatbuffer type that vwtest::fb_type maps T to.
template <typename T, typename FB_t = typename vwtest::fb_type<T>::type>
using builder_f = Offset<FB_t> (*)(FlatBufferBuilder&, VW::workspace&, namespace_factory_f);

// Serializes an fb::ExampleCollection containing one element built from `ns_fac`.
// When `multiline` is true the element is a MultiExample and the plain-example vector is empty;
// otherwise the element is an Example and the multi-example vector is empty.
template <bool multiline>
Offset<fb::ExampleCollection> create_bad_ns_root_collection(FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac)
{
  // Explicit std::vector (as in the sibling helpers) rather than a deduced std::initializer_list, so the
  // CreateVector(const std::vector<T>&) overload is selected.
  if VW_STD17_CONSTEXPR (multiline)
  {
    const std::vector<Offset<fb::MultiExample>> inner_examples = {create_bad_ns_root_multiex(builder, w, ns_fac)};
    return fb::CreateExampleCollection(builder, builder.CreateVector(std::vector<Offset<fb::Example>>()),
        builder.CreateVector(inner_examples), multiline);
  }
  else
  {
    const std::vector<Offset<fb::Example>> inner_examples = {create_bad_ns_root_example(builder, w, ns_fac)};
    return fb::CreateExampleCollection(builder, builder.CreateVector(inner_examples),
        builder.CreateVector(std::vector<Offset<fb::MultiExample>>()), multiline);
  }
}

// Builds a root object with `root_builder`, wraps it in an fb::ExampleRoot tagged with `root_type`, and
// checks that reading the finished buffer yields `error_code`.
// Note: FB_t is the prototype type handed to builder_f<>, which maps it to the flatbuffer type.
template <int error_code, typename FB_t, fb::ExampleType root_type>
void create_flatbuffer_span_and_expect_error(VW::workspace& w, namespace_factory_f ns_fac, builder_f<FB_t> root_builder)
{
  FlatBufferBuilder fbb;

  const auto payload = root_builder(fbb, w, ns_fac);
  const Offset<fb::ExampleRoot> example_root = fb::CreateExampleRoot(fbb, root_type, payload.Union());

  finish_flatbuffer_and_expect_error<error_code>(fbb, example_root, w);
}

using NamespaceErrors = vwtest::example_data_generator::NamespaceErrors;

// Serializes a namespace with the defects selected by `errors` under each supported root type and expects
// read_span_flatbuffer to report `error_code` every time.
template <NamespaceErrors errors, int error_code>
void run_bad_namespace_test(VW::workspace& w)
{
  static_assert(errors != NamespaceErrors::BAD_NAMESPACE_NO_ERROR, "This test is intended to test bad namespaces");

  vwtest::example_data_generator generator;
  namespace_factory_f make_bad_namespace = [&generator](FlatBufferBuilder& fbb,
                                               VW::workspace& workspace) -> Offset<fb::Namespace>
  { return generator.create_bad_namespace<errors>(fbb, workspace); };

  create_flatbuffer_span_and_expect_error<error_code, vwtest::example, fb::ExampleType_Example>(
      w, make_bad_namespace, &create_bad_ns_root_example);
  create_flatbuffer_span_and_expect_error<error_code, vwtest::multiex, fb::ExampleType_MultiExample>(
      w, make_bad_namespace, &create_bad_ns_root_multiex);

  // The ExampleCollection variants are currently disabled (the reason is not recorded here).
  //create_flatbuffer_span_and_expect_error<error_code, vwtest::ex_collection, fb::ExampleType_ExampleCollection>(w, ns_fac, &create_bad_ns_root_collection<false>);

  //create_flatbuffer_span_and_expect_error<error_code, vwtest::ex_collection, fb::ExampleType_ExampleCollection>(w, ns_fac, &create_bad_ns_root_collection<true>);
}

// A namespace serialized without its feature_values vector must be rejected with
// fb_parser_feature_values_missing.
TEST(FlatbufferParser, BadNamespace_FeatureValuesMissing)
{
  constexpr int expected_error = VW::experimental::error_code::fb_parser_feature_values_missing;

  auto workspace = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer"));

  run_bad_namespace_test<NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_MISSING, expected_error>(*workspace);
}

0 comments on commit b7c7748

Please sign in to comment.