Jack Gerrits edited this page Aug 3, 2022 · 9 revisions

Model merging takes several compatible VW models and merges them into a single model that approximately represents all of the models combined. This will probably never be as effective as a single model trained on all of the data sequentially. However, when training on all of the data sequentially is not feasible, the speedup from parallel computation can make a merged model that sees all of the data more effective than a single model trained on only a subset of it.

Generally speaking, merging is a weighted average of all given models, where each model's weight is based on the relative amount of data it processed. Values which act as counters are summed instead of averaged.

In the case of the GD reduction, when `save_resume` is in use, the adaptive values are used to compute a per-parameter weighted average. For all other averaged values in a model, the number of examples seen by each model determines its weight in the average.

If a reduction defines a `save_load` function, this implies that the reduction has training state which is persisted. Therefore, a rule of thumb is that if a reduction defines `save_load` it should also define `merge`. A warning is emitted if any reduction in the stack has a `save_load` but no `merge`, and an error is emitted if the base reduction in the stack has no `merge`, since merging definitely cannot work in that case.

## Signatures

The signature of the merge function depends on whether the reduction is a base reduction. Ideally, all merge functions would use the non-base signature, but since base learners use the weights and other state in `VW::workspace`, this is not currently feasible.

```cpp
using ReductionDataT = void;  // placeholder for the reduction's state type

// Base reduction
using merge_with_all_fn = void (*)(const std::vector<float>& example_counts,
    const std::vector<const VW::workspace*>& all_workspaces, const std::vector<const ReductionDataT*>& all_data,
    VW::workspace& output_workspace, ReductionDataT& output_data);

// Non-base reduction
using merge_fn = void (*)(const std::vector<float>& example_counts,
    const std::vector<const ReductionDataT*>& all_data, ReductionDataT& output_data);
```

The merge function is then set on the respective learner builder during construction.

`merge` is then exposed by the learner interface.
