Skip to content

Commit

Permalink
refactor: make flat_example an implementation detail of ksvm (#4505)
Browse files Browse the repository at this point in the history
* refactor!: make flat_example an implementation detail of ksvm

* Update memory_tree.cc

* Absorb flat_example into svm_example

* revert "Absorb flat_example into svm_example"

This reverts commit b063feb.
  • Loading branch information
jackgerrits authored Mar 3, 2023
1 parent f08f1ec commit 69bf346
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 201 deletions.
41 changes: 2 additions & 39 deletions vowpalwabbit/core/include/vw/core/example.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,26 +143,8 @@ class example : public example_predict // core example datatype.

class workspace;

// A "flattened" example: all of an example's features collected into a single
// feature group, together with a copy of the label and some per-example
// metadata. Instances are produced by flatten_example / flatten_sort_example
// and released with free_flatten_example.
class flat_example
{
public:
// Label. Copied from the source example when flattening; when loaded from a
// model it is reset with the label parser's default_label and then read from
// the cached label.
polylabel l;
// Reduction features copied from the source example; read and written
// together with the label by the label parser.
reduction_features ex_reduction_features;

VW::v_array<char> tag; // An identifier for the example.

// Copied from the source example when flattening.
size_t example_counter;
// Copied from the source example when flattening.
uint64_t ft_offset;
// Serialized with the model. NOTE(review): flatten_example does not appear to
// assign this field -- confirm who is expected to set it.
float global_weight;

size_t num_features; // precomputed, cause it's fast&easy.
// Assigned by flatten_sort_example from the result of collision_cleanup.
float total_sum_feat_sq; // precomputed, cause it's kind of fast & easy.
features fs; // all the features
};

flat_example* flatten_example(VW::workspace& all, example* ec);
flat_example* flatten_sort_example(VW::workspace& all, example* ec);
void free_flatten_example(flat_example* fec);
// TODO: make workspace and example const
void flatten_features(VW::workspace& all, example& ec, features& fs);

/// Reports the value of the example's is_newline flag.
inline bool example_is_newline(const example& ec)
{
  const bool newline_flag = ec.is_newline;
  return newline_flag;
}

Expand Down Expand Up @@ -194,13 +176,6 @@ void truncate_example_namespace(VW::example& ec, VW::namespace_index ns, const f
void append_example_namespaces_from_example(VW::example& target, const VW::example& source);
void truncate_example_namespaces_from_example(VW::example& target, const VW::example& source);
} // namespace details

namespace model_utils
{
size_t read_model_field(io_buf& io, flat_example& fe, VW::label_parser& lbl_parser);
size_t write_model_field(io_buf& io, const flat_example& fe, const std::string& upstream_name, bool text,
VW::label_parser& lbl_parser, uint64_t parse_mask);
} // namespace model_utils
} // namespace VW

// Deprecated compat definitions
Expand All @@ -209,18 +184,6 @@ using polylabel VW_DEPRECATED("polylabel moved into VW namespace") = VW::polylab
using polyprediction VW_DEPRECATED("polyprediction moved into VW namespace") = VW::polyprediction;
using example VW_DEPRECATED("example moved into VW namespace") = VW::example;
using multi_ex VW_DEPRECATED("multi_ex moved into VW namespace") = VW::multi_ex;
using flat_example VW_DEPRECATED("flat_example moved into VW namespace") = VW::flat_example;

// Global-namespace compatibility shim; forwards to VW::flatten_example.
VW_DEPRECATED("flatten_example moved into VW namespace")
inline VW::flat_example* flatten_example(VW::workspace& all, VW::example* ec)
{
  VW::flat_example* const flattened = VW::flatten_example(all, ec);
  return flattened;
}

// Global-namespace compatibility shim; forwards to VW::flatten_sort_example.
VW_DEPRECATED("flatten_sort_example moved into VW namespace")
inline VW::flat_example* flatten_sort_example(VW::workspace& all, VW::example* ec)
{
  VW::flat_example* const flattened = VW::flatten_sort_example(all, ec);
  return flattened;
}
// Global-namespace compatibility shim; forwards to VW::free_flatten_example.
VW_DEPRECATED("free_flatten_example moved into VW namespace")
inline void free_flatten_example(VW::flat_example* fec)
{
  VW::free_flatten_example(fec);
}

// Global-namespace compatibility shim; forwards to VW::example_is_newline.
VW_DEPRECATED("example_is_newline moved into VW namespace")
inline bool example_is_newline(const VW::example& ec)
{
  const bool is_newline = VW::example_is_newline(ec);
  return is_newline;
}
Expand Down
4 changes: 4 additions & 0 deletions vowpalwabbit/core/include/vw/core/feature_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,10 @@ class features
return all_extents_complete;
}
};

/// Both fs1 and fs2 must be sorted.
/// Most often used with VW::flatten_features
float features_dot_product(const features& fs1, const features& fs2);
} // namespace VW

using feature_value VW_DEPRECATED("Moved into VW namespace. Will be removed in VW 10.") = VW::feature_value;
Expand Down
82 changes: 13 additions & 69 deletions vowpalwabbit/core/src/example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ float VW::example::get_total_sum_feat_sq()

float collision_cleanup(VW::features& fs)
{
// Input must be sorted.
assert(std::is_sorted(fs.indices.begin(), fs.indices.end()));

// This loops over the sequence of feature values and their indexes
// when an index is repeated this combines them by adding their values.
// This assumes that fs is sorted (callers such as `flatten_features` sort before calling this).
Expand Down Expand Up @@ -105,46 +108,23 @@ void vec_ffs_store(full_features_and_source& p, float fx, uint64_t fi)
}
namespace VW
{
flat_example* flatten_example(VW::workspace& all, example* ec)
{
flat_example& fec = VW::details::calloc_or_throw<flat_example>();
fec.l = ec->l;
fec.tag = ec->tag;
fec.ex_reduction_features = ec->ex_reduction_features;
fec.example_counter = ec->example_counter;
fec.ft_offset = ec->ft_offset;
fec.num_features = ec->num_features;

void flatten_features(VW::workspace& all, example& ec, features& fs)
{
fs.clear();
full_features_and_source ffs;
ffs.fs = std::move(fs);
ffs.stride_shift = all.weights.stride_shift();
if (all.weights.not_null())
{ // TODO:temporary fix. all.weights is not initialized at this point in some cases.
{
// TODO:temporary fix. all.weights is not initialized at this point in some cases.
ffs.mask = all.weights.mask() >> all.weights.stride_shift();
}
else { ffs.mask = static_cast<uint64_t>(LONG_MAX) >> all.weights.stride_shift(); }
VW::foreach_feature<full_features_and_source, uint64_t, vec_ffs_store>(all, *ec, ffs);

std::swap(fec.fs, ffs.fs);

return &fec;
}

flat_example* flatten_sort_example(VW::workspace& all, example* ec)
{
flat_example* fec = flatten_example(all, ec);
fec->fs.sort(all.parse_mask);
fec->total_sum_feat_sq = collision_cleanup(fec->fs);
return fec;
}

void free_flatten_example(flat_example* fec)
{
// note: The label memory should be freed by freeing the original example.
if (fec)
{
fec->fs.~features();
free(fec);
}
VW::foreach_feature<full_features_and_source, uint64_t, vec_ffs_store>(all, ec, ffs);
ffs.fs.sort(all.parse_mask);
ffs.fs.sum_feat_sq = collision_cleanup(ffs.fs);
fs = std::move(ffs.fs);
}

void return_multiple_example(VW::workspace& all, VW::multi_ex& examples)
Expand Down Expand Up @@ -213,42 +193,6 @@ void truncate_example_namespaces_from_example(VW::example& target, const VW::exa
}
}
} // namespace details

namespace model_utils
{
/// Reads a flat_example from io, field by field, in the same order that
/// write_model_field emits them.
///
/// @param io          buffer to read from
/// @param fe          destination; its label is reset to the parser default first
/// @param lbl_parser  label parser used to default and read the cached label
/// @return            sum of the byte counts reported by each individual read
size_t read_model_field(io_buf& io, flat_example& fe, VW::label_parser& lbl_parser)
{
  // Reset the label before the cached label is read into it.
  lbl_parser.default_label(fe.l);

  size_t total_read = 0;
  total_read += lbl_parser.read_cached_label(fe.l, fe.ex_reduction_features, io);
  total_read += read_model_field(io, fe.tag);
  total_read += read_model_field(io, fe.example_counter);
  total_read += read_model_field(io, fe.ft_offset);
  total_read += read_model_field(io, fe.global_weight);
  total_read += read_model_field(io, fe.num_features);
  total_read += read_model_field(io, fe.total_sum_feat_sq);

  // The cached index is consumed from the stream but its value is not used afterwards.
  unsigned char cached_index = 0;
  total_read += ::VW::parsers::cache::details::read_cached_index(io, cached_index);

  bool features_sorted = true;
  total_read += ::VW::parsers::cache::details::read_cached_features(io, fe.fs, features_sorted);

  return total_read;
}

/// Writes a flat_example to io: label, scalar fields, a cache index of 0, then the features.
///
/// @param io             buffer to write to
/// @param fe             example to serialize
/// @param upstream_name  prefix for the per-field names passed to the field writers
/// @param text           forwarded to the field writers
/// @param lbl_parser     label parser used to cache the label
/// @param parse_mask     mask forwarded to cache_features
/// @return               sum of the byte counts of the tag and scalar field writes only;
///                       the label, index and feature writes are not added to it
size_t write_model_field(io_buf& io, const flat_example& fe, const std::string& upstream_name, bool text,
    VW::label_parser& lbl_parser, uint64_t parse_mask)
{
  // Label goes first; whatever cache_label returns is not counted.
  lbl_parser.cache_label(fe.l, fe.ex_reduction_features, io, upstream_name + "_label", text);

  size_t total_written = 0;
  total_written += write_model_field(io, fe.tag, upstream_name + "_tag", text);
  total_written += write_model_field(io, fe.example_counter, upstream_name + "_example_counter", text);
  total_written += write_model_field(io, fe.ft_offset, upstream_name + "_ft_offset", text);
  total_written += write_model_field(io, fe.global_weight, upstream_name + "_global_weight", text);
  total_written += write_model_field(io, fe.num_features, upstream_name + "_num_features", text);
  total_written += write_model_field(io, fe.total_sum_feat_sq, upstream_name + "_total_sum_feat_sq", text);

  // All features live in one group, written under index 0.
  ::VW::parsers::cache::details::cache_index(io, 0);
  ::VW::parsers::cache::details::cache_features(io, fe.fs, parse_mask);

  return total_written;
}
}  // namespace model_utils
} // namespace VW

namespace VW
Expand Down
25 changes: 25 additions & 0 deletions vowpalwabbit/core/src/feature_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,28 @@ void VW::features::end_ns_extent()
}
}
}

/// Dot product of two sparse feature groups, matching entries by index.
///
/// Both inputs must have their indices sorted in ascending order (checked by
/// assert). The two index lists are walked in lockstep: the side holding the
/// smaller index advances, and when the indices agree the product of the two
/// values is accumulated and both sides advance.
///
/// @return sum of value products over matched indices; 0 when fs2 has no indices
float VW::features_dot_product(const features& fs1, const features& fs2)
{
  assert(std::is_sorted(fs1.indices.begin(), fs1.indices.end()));
  assert(std::is_sorted(fs2.indices.begin(), fs2.indices.end()));

  if (fs2.indices.empty()) { return 0.f; }

  float total = 0;
  size_t pos1 = 0;
  size_t pos2 = 0;
  while (pos1 < fs1.size() && pos2 < fs2.size())
  {
    const uint64_t index1 = fs1.indices[pos1];
    const uint64_t index2 = fs2.indices[pos2];

    if (index1 < index2)
    {
      // fs1 is behind: skip its entry.
      ++pos1;
    }
    else if (index1 > index2)
    {
      // fs2 is behind: skip its entry and compare again against the same fs1 entry.
      ++pos2;
    }
    else
    {
      total += fs1.values[pos1] * fs2.values[pos2];
      ++pos1;
      ++pos2;
    }
  }
  return total;
}
13 changes: 7 additions & 6 deletions vowpalwabbit/core/src/reductions/eigen_memory_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "vw/config/options.h"
#include "vw/core/array_parameters.h"
#include "vw/core/example.h"
#include "vw/core/feature_group.h"
#include "vw/core/learner.h"
#include "vw/core/memory.h"
#include "vw/core/model_utils.h"
Expand Down Expand Up @@ -75,14 +76,14 @@ emt_example::emt_example(VW::workspace& all, VW::example* ex)
std::vector<std::vector<VW::namespace_index>> base_interactions;

ex->interactions = &base_interactions;
auto* ex1 = VW::flatten_sort_example(all, ex);
for (auto& f : ex1->fs) { base.emplace_back(f.index(), f.value()); }
VW::free_flatten_example(ex1);
VW::features fs;
VW::flatten_features(all, *ex, fs);
for (auto& f : fs) { base.emplace_back(f.index(), f.value()); }

fs.clear();
ex->interactions = full_interactions;
auto* ex2 = VW::flatten_sort_example(all, ex);
for (auto& f : ex2->fs) { full.emplace_back(f.index(), f.value()); }
VW::free_flatten_example(ex2);
VW::flatten_features(all, *ex, fs);
for (auto& f : fs) { full.emplace_back(f.index(), f.value()); }
}

emt_lru::emt_lru(uint64_t max_size) : max_size(max_size) {}
Expand Down
Loading

0 comments on commit 69bf346

Please sign in to comment.