Skip to content

Commit

Permalink
refactor: [workspace] split 'all' into multiple structs. split config…
Browse files Browse the repository at this point in the history
… and runtime vars. (#4493)
  • Loading branch information
lalo authored Mar 20, 2023
1 parent adcaff2 commit a9b305f
Show file tree
Hide file tree
Showing 141 changed files with 2,296 additions and 1,831 deletions.
38 changes: 19 additions & 19 deletions cs/cli/vowpalwabbit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ VowpalWabbit::VowpalWabbit(VowpalWabbitSettings^ settings)
}

if (settings->ParallelOptions != nullptr)
{ m_vw->selected_all_reduce_type = all_reduce_type::THREAD;
{ m_vw->runtime_config.selected_all_reduce_type = all_reduce_type::THREAD;
auto total = settings->ParallelOptions->MaxDegreeOfParallelism;

if (settings->Root == nullptr)
{ m_vw->all_reduce.reset(new all_reduce_threads(total, settings->Node));
{ m_vw->runtime_state.all_reduce.reset(new all_reduce_threads(total, settings->Node));
}
else
{ auto parent_all_reduce = (all_reduce_threads*)settings->Root->m_vw->all_reduce.get();
{ auto parent_all_reduce = (all_reduce_threads*)settings->Root->m_vw->runtime_state.all_reduce.get();

m_vw->all_reduce.reset(new all_reduce_threads(parent_all_reduce, total, settings->Node));
m_vw->runtime_state.all_reduce.reset(new all_reduce_threads(parent_all_reduce, total, settings->Node));
}
}

Expand All @@ -64,9 +64,9 @@ void VowpalWabbit::Driver()
}

void VowpalWabbit::RunMultiPass()
{ if (m_vw->numpasses > 1)
{ if (m_vw->runtime_config.numpasses > 1)
{ try
{ m_vw->do_reset_source = true;
{ m_vw->runtime_state.do_reset_source = true;
VW::start_parser(*m_vw);
LEARNER::generic_driver(*m_vw);
VW::end_parser(*m_vw);
Expand All @@ -79,17 +79,17 @@ VowpalWabbitPerformanceStatistics^ VowpalWabbit::PerformanceStatistics::get()
{ // see parse_args.cc:finish(...)
auto stats = gcnew VowpalWabbitPerformanceStatistics();

if (m_vw->current_pass == 0)
if (m_vw->passes_config.current_pass == 0)
{ stats->NumberOfExamplesPerPass = m_vw->sd->example_number;
}
else
{ stats->NumberOfExamplesPerPass = m_vw->sd->example_number / m_vw->current_pass;
{ stats->NumberOfExamplesPerPass = m_vw->sd->example_number / m_vw->passes_config.current_pass;
}

stats->WeightedExampleSum = m_vw->sd->weighted_examples();
stats->WeightedLabelSum = m_vw->sd->weighted_labels;

if (m_vw->holdout_set_off)
if (m_vw->passes_config.holdout_set_off)
if (m_vw->sd->weighted_labeled_examples > 0)
stats->AverageLoss = m_vw->sd->sum_loss / m_vw->sd->weighted_labeled_examples;
else
Expand All @@ -100,7 +100,7 @@ VowpalWabbitPerformanceStatistics^ VowpalWabbit::PerformanceStatistics::get()
stats->AverageLoss = m_vw->sd->holdout_best_loss;

float best_constant; float best_constant_loss;
if (get_best_constant(*m_vw->loss, *m_vw->sd, best_constant, best_constant_loss))
if (get_best_constant(*m_vw->loss_config.loss, *m_vw->sd, best_constant, best_constant_loss))
{ stats->BestConstant = best_constant;
if (best_constant_loss != FLT_MIN)
{ stats->BestConstantLoss = best_constant_loss;
Expand All @@ -124,7 +124,7 @@ uint64_t VowpalWabbit::HashSpace(String^ s)
}

uint64_t VowpalWabbit::HashFeature(String^ s, size_t u)
{ auto newHash = m_hasher(s, u) & m_vw->parse_mask;
{ auto newHash = m_hasher(s, u) & m_vw->runtime_state.parse_mask;

#ifdef _DEBUG
auto oldHash = HashFeatureNative(s, u);
Expand Down Expand Up @@ -321,7 +321,7 @@ List<VowpalWabbitExample^>^ VowpalWabbit::ParseDecisionServiceJson(cli::array<By

VW::parsers::json::decision_service_interaction interaction;

if (m_vw->audit)
if (m_vw->output_config.audit)
VW::parsers::json::read_line_decision_service_json<true>(*m_vw, examples, reinterpret_cast<char*>(data), length, copyJson, std::bind(get_example_from_pool, &state), &interaction);
else
VW::parsers::json::read_line_decision_service_json<false>(*m_vw, examples, reinterpret_cast<char*>(data), length, copyJson, std::bind(get_example_from_pool, &state), &interaction);
Expand Down Expand Up @@ -385,7 +385,7 @@ List<VowpalWabbitExample^>^ VowpalWabbit::ParseDecisionServiceJson(cli::array<By

interior_ptr<ParseJsonState^> state_ptr = &state;

if (m_vw->audit)
if (m_vw->output_config.audit)
VW::parsers::json::read_line_json<true>(*m_vw, examples, reinterpret_cast<char*>(valueHandle.AddrOfPinnedObject().ToPointer()), (size_t)bytes->Length, std::bind(get_example_from_pool, &state));
else
VW::parsers::json::read_line_json<false>(*m_vw, examples, reinterpret_cast<char*>(valueHandle.AddrOfPinnedObject().ToPointer()), (size_t)bytes->Length, std::bind(get_example_from_pool, &state));
Expand Down Expand Up @@ -793,15 +793,15 @@ VowpalWabbitExample^ VowpalWabbit::GetOrCreateNativeExample()
{ try
{
auto ex = new VW::example;
m_vw->example_parser->lbl_parser.default_label(ex->l);
m_vw->parser_runtime.example_parser->lbl_parser.default_label(ex->l);
return gcnew VowpalWabbitExample(this, ex);
}
CATCHRETHROW
}

try
{ VW::empty_example(*m_vw, *ex->m_example);
m_vw->example_parser->lbl_parser.default_label(ex->m_example->l);
m_vw->parser_runtime.example_parser->lbl_parser.default_label(ex->m_example->l);

return ex;
}
Expand Down Expand Up @@ -833,9 +833,9 @@ void VowpalWabbit::ReturnExampleToPool(VowpalWabbitExample^ ex)
}

cli::array<List<VowpalWabbitFeature^>^>^ VowpalWabbit::GetTopicAllocation(int top)
{ uint64_t length = (uint64_t)1 << m_vw->num_bits;
{ uint64_t length = (uint64_t)1 << m_vw->initial_weights_config.num_bits;
// using jagged array to enable LINQ
auto K = (int)m_vw->lda;
auto K = (int)m_vw->reduction_state.lda;
auto allocation = gcnew cli::array<List<VowpalWabbitFeature^>^>(K);

// TODO: better way of peaking into lda?
Expand All @@ -858,10 +858,10 @@ cli::array<List<VowpalWabbitFeature^>^>^ VowpalWabbit::GetTopicAllocation(int to
template<typename T>
cli::array<cli::array<float>^>^ VowpalWabbit::FillTopicAllocation(T& weights)
{
uint64_t length = (uint64_t)1 << m_vw->num_bits;
uint64_t length = (uint64_t)1 << m_vw->initial_weights_config.num_bits;

// using jagged array to enable LINQ
auto K = (int)m_vw->lda;
auto K = (int)m_vw->reduction_state.lda;
auto allocation = gcnew cli::array<cli::array<float>^>(K);
for (int k = 0; k < K; k++)
allocation[k] = gcnew cli::array<float>((int)length);
Expand Down
17 changes: 8 additions & 9 deletions cs/cli/vw_arguments.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,17 @@ public ref class VowpalWabbitArguments
float m_power_t;

internal : VowpalWabbitArguments(VW::workspace* vw)
: m_data(gcnew String(vw->data_filename.c_str()))
, m_finalRegressor(gcnew String(vw->final_regressor_name.c_str()))
, m_testonly(!vw->training)
, m_passes((int)vw->numpasses)
: m_data(gcnew String(vw->parser_runtime.data_filename.c_str()))
, m_finalRegressor(gcnew String(vw->output_model_config.final_regressor_name.c_str()))
, m_testonly(!vw->runtime_config.training)
, m_passes((int)vw->runtime_config.numpasses)
{
auto options = vw->options.get();

if (vw->initial_regressors.size() > 0)
if (vw->initial_weights_config.initial_regressors.size() > 0)
{ m_regressors = gcnew List<String^>;

for (auto& r : vw->initial_regressors)
m_regressors->Add(gcnew String(r.c_str()));
for (auto& r : vw->initial_weights_config.initial_regressors) m_regressors->Add(gcnew String(r.c_str()));
}

VW::config::cli_options_serializer serializer;
Expand All @@ -66,8 +65,8 @@ public ref class VowpalWabbitArguments
m_numberOfActions = (int)options->get_typed_option<uint32_t>("cb").value();
}

m_learning_rate = vw->eta;
m_power_t = vw->power_t;
m_learning_rate = vw->update_rule_config.eta;
m_power_t = vw->update_rule_config.power_t;
}

public:
Expand Down
6 changes: 3 additions & 3 deletions cs/cli/vw_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ void VowpalWabbitBase::InternalDispose()
try
{ if (m_vw != nullptr)
{
VW::details::reset_source(*m_vw, m_vw->num_bits);
VW::details::reset_source(*m_vw, m_vw->initial_weights_config.num_bits);

// make sure don't try to free m_vw twice in case VW::finish throws.
VW::workspace* vw_tmp = m_vw;
Expand Down Expand Up @@ -187,7 +187,7 @@ void VowpalWabbitBase::Reload([System::Runtime::InteropServices::Optional] Strin

try
{
VW::details::reset_source(*m_vw, m_vw->num_bits);
VW::details::reset_source(*m_vw, m_vw->initial_weights_config.num_bits);

auto buffer = std::make_shared<std::vector<char>>();
{
Expand Down Expand Up @@ -225,7 +225,7 @@ void VowpalWabbitBase::ID::set(String^ value)
}

void VowpalWabbitBase::SaveModel()
{ std::string name = m_vw->final_regressor_name;
{ std::string name = m_vw->output_model_config.final_regressor_name;
if (name.empty())
{ return;
}
Expand Down
6 changes: 3 additions & 3 deletions cs/cli/vw_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ bool VowpalWabbitExample::IsNewLine::get()

ILabel^ VowpalWabbitExample::Label::get()
{ ILabel^ label;
auto lp = m_owner->Native->m_vw->example_parser->lbl_parser;
auto lp = m_owner->Native->m_vw->parser_runtime.example_parser->lbl_parser;
if (!memcmp(&lp, &VW::simple_label_parser_global, sizeof(lp)))
label = gcnew SimpleLabel();
else if (!memcmp(&lp, &VW::cb_label_parser_global, sizeof(lp)))
Expand Down Expand Up @@ -103,7 +103,7 @@ void VowpalWabbitExample::Label::set(ILabel^ label)
label->UpdateExample(m_owner->Native->m_vw, m_example);

// we need to update the example weight as setup_example() can be called prior to this call.
m_example->weight = m_owner->Native->m_vw->example_parser->lbl_parser.get_weight(m_example->l, m_example->ex_reduction_features);
m_example->weight = m_owner->Native->m_vw->parser_runtime.example_parser->lbl_parser.get_weight(m_example->l, m_example->ex_reduction_features);
}

void VowpalWabbitExample::MakeEmpty(VowpalWabbit^ vw)
Expand Down Expand Up @@ -389,7 +389,7 @@ uint64_t VowpalWabbitFeature::WeightIndex::get()
throw gcnew InvalidOperationException("VowpalWabbitFeature must be initialized with example");

VW::workspace* vw = m_example->Owner->Native->m_vw;
return ((m_weight_index + m_example->m_example->ft_offset) >> vw->weights.stride_shift()) & vw->parse_mask;
return ((m_weight_index + m_example->m_example->ft_offset) >> vw->weights.stride_shift()) & vw->runtime_state.parse_mask;
}

float VowpalWabbitFeature::Weight::get()
Expand Down
4 changes: 2 additions & 2 deletions cs/cli/vw_prediction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ cli::array<float>^ VowpalWabbitTopicPredictionFactory::Create(VW::workspace* vw,
{ if (ex == nullptr)
throw gcnew ArgumentNullException("ex");

auto values = gcnew cli::array<float>(vw->lda);
Marshal::Copy(IntPtr(ex->pred.scalars.begin()), values, 0, vw->lda);
auto values = gcnew cli::array<float>(vw->reduction_state.lda);
Marshal::Copy(IntPtr(ex->pred.scalars.begin()), values, 0, vw->reduction_state.lda);

return values;
}
Expand Down
23 changes: 13 additions & 10 deletions cs/vw.net.native/vw.net.arguments.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
API void GetWorkspaceBasicArguments(
vw_net_native::workspace_context* workspace, vw_net_native::vw_basic_arguments_t* args)
{
args->is_test_only = !workspace->vw->training;
args->num_passes = (int)workspace->vw->numpasses;
args->learning_rate = workspace->vw->eta;
args->power_t = workspace->vw->power_t;
args->is_test_only = !workspace->vw->runtime_config.training;
args->num_passes = (int)workspace->vw->runtime_config.numpasses;
args->learning_rate = workspace->vw->update_rule_config.eta;
args->power_t = workspace->vw->update_rule_config.power_t;

if (workspace->vw->options->was_supplied("cb"))
{
Expand All @@ -19,12 +19,12 @@ API void GetWorkspaceBasicArguments(

API const char* GetWorkspaceDataFilename(vw_net_native::workspace_context* workspace)
{
return workspace->vw->data_filename.c_str();
return workspace->vw->parser_runtime.data_filename.c_str();
}

API const char* GetFinalRegressorFilename(vw_net_native::workspace_context* workspace)
{
return workspace->vw->final_regressor_name.c_str();
return workspace->vw->output_model_config.final_regressor_name.c_str();
}

API char* SerializeCommandLine(vw_net_native::workspace_context* workspace)
Expand All @@ -42,20 +42,23 @@ API char* SerializeCommandLine(vw_net_native::workspace_context* workspace)

API size_t GetInitialRegressorFilenamesCount(vw_net_native::workspace_context* workspace)
{
return workspace->vw->initial_regressors.size();
return workspace->vw->initial_weights_config.initial_regressors.size();
}

API vw_net_native::dotnet_size_t GetInitialRegressorFilenames(
vw_net_native::workspace_context* workspace, const char** filenames, vw_net_native::dotnet_size_t count)
{
std::vector<std::string>& initial_regressors = workspace->vw->initial_regressors;
std::vector<std::string>& initial_regressors = workspace->vw->initial_weights_config.initial_regressors;
size_t size = initial_regressors.size();
if ((size_t)count < size)
{
return vw_net_native::size_to_neg_dotnet_size(size); // Not enough space in destination buffer
}

for (size_t i = 0; i < size; i++) { filenames[i] = workspace->vw->initial_regressors[i].c_str(); }
for (size_t i = 0; i < size; i++)
{
filenames[i] = workspace->vw->initial_weights_config.initial_regressors[i].c_str();
}

return workspace->vw->initial_regressors.size();
return workspace->vw->initial_weights_config.initial_regressors.size();
}
9 changes: 5 additions & 4 deletions cs/vw.net.native/vw.net.example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
API VW::example* CreateExample(vw_net_native::workspace_context* workspace)
{
auto* ex = new VW::example;
workspace->vw->example_parser->lbl_parser.default_label(ex->l);
workspace->vw->parser_runtime.example_parser->lbl_parser.default_label(ex->l);
return ex;
}

Expand Down Expand Up @@ -189,12 +189,13 @@ API void MakeIntoNewlineExample(vw_net_native::workspace_context* workspace, VW:

API void MakeLabelDefault(vw_net_native::workspace_context* workspace, VW::example* example)
{
workspace->vw->example_parser->lbl_parser.default_label(example->l);
workspace->vw->parser_runtime.example_parser->lbl_parser.default_label(example->l);
}

API void UpdateExampleWeight(vw_net_native::workspace_context* workspace, VW::example* example)
{
example->weight = workspace->vw->example_parser->lbl_parser.get_weight(example->l, example->ex_reduction_features);
example->weight =
workspace->vw->parser_runtime.example_parser->lbl_parser.get_weight(example->l, example->ex_reduction_features);
}

API vw_net_native::namespace_enumerator* CreateNamespaceEnumerator(
Expand Down Expand Up @@ -256,7 +257,7 @@ API VW::feature_index GetShiftedWeightIndex(
vw_net_native::workspace_context* workspace, VW::example* example, VW::feature_index feature_index)
{
VW::workspace* vw = workspace->vw;
return ((feature_index + example->ft_offset) >> vw->weights.stride_shift()) & vw->parse_mask;
return ((feature_index + example->ft_offset) >> vw->weights.stride_shift()) & vw->runtime_state.parse_mask;
}

API float GetWeight(vw_net_native::workspace_context* workspace, VW::example* example, VW::feature_index feature_index)
Expand Down
9 changes: 6 additions & 3 deletions cs/vw.net.native/vw.net.predictions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,17 @@ API vw_net_native::dotnet_size_t GetPredictionActionScores(
return vw_net_native::v_copy_to_managed(ex->pred.a_s, values, count);
}

API size_t GetPredictionTopicProbsCount(VW::workspace* vw, VW::example* ex) { return static_cast<size_t>(vw->lda); }
API size_t GetPredictionTopicProbsCount(VW::workspace* vw, VW::example* ex)
{
return static_cast<size_t>(vw->reduction_state.lda);
}

API vw_net_native::dotnet_size_t GetPredictionTopicProbs(
VW::workspace* vw, VW::example* ex, float* values, vw_net_native::dotnet_size_t count)
{
if (count < vw->lda)
if (count < vw->reduction_state.lda)
{
return vw_net_native::size_to_neg_dotnet_size(vw->lda); // not enough space in the output array
return vw_net_native::size_to_neg_dotnet_size(vw->reduction_state.lda); // not enough space in the output array
}

const v_array<float>& scalars = ex->pred.scalars;
Expand Down
Loading

0 comments on commit a9b305f

Please sign in to comment.