diff --git a/demo/plt/README.md b/demo/plt/README.md index 060eeb13436..74b015841f4 100644 --- a/demo/plt/README.md +++ b/demo/plt/README.md @@ -1,5 +1,5 @@ Probabilistic Label Tree demo -------------------------------- +----------------------------- This demo presents PLT for applications of logarithmic time multilabel classification. It uses Mediamill dataset from the [LIBLINEAR datasets repository](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html) @@ -12,11 +12,13 @@ The datasets and paremeters can be easliy edited in the script. The script requi ## PLT options ``` --plt Probabilistic Label Tree with labels ---kary_tree use <k>-ary tree. By default = 2 (binary tree) ---threshold predict labels with conditional marginal probability greater than threshold" +--kary_tree use <k>-ary tree. By default = 2 (binary tree), + higher values usually give better results, but increase training time +--threshold predict labels with conditional marginal probability greater than threshold --top_k predict top-<k> labels instead of labels above threshold ``` + ## Tips for using PLT PLT accelerates training and prediction for a large number of classes, if you have less than 10000 classes, you should probably use OAA. @@ -24,7 +26,8 @@ If you have a huge number of labels and features at the same time, you will need as many bits (`-b`) as can afford computationally for the best performance. You may also consider using `--sgd` instead of default adaptive, normalized, and invariant updates to gain more memory for feature weights what may lead to better performance. -You may also consider using `--holdout_off` if you have many rare labels in your data. +If you have many rare labels in your data, you should train with `--holdout_off`, which disables the use of the holdout (validation) dataset for early stopping. 
+ ## References diff --git a/demo/plt/plt_demo.py b/demo/plt/plt_demo.py index 6de9f0e1f01..b92cb49bdc2 100644 --- a/demo/plt/plt_demo.py +++ b/demo/plt/plt_demo.py @@ -6,7 +6,7 @@ # This is demo example that demonstrates usage of PLT reduction on few popular multilabel datasets. # Select dataset -dataset = "mediamill_exp1" # should be in ["mediamill_exp1", "eurlex", "rcv1x", "wiki10", "amazonCat"] +dataset = "eurlex" # should be in ["mediamill_exp1", "eurlex", "rcv1x", "wiki10", "amazonCat"] # Select reduction reduction = "plt" # should be in ["plt", "multilabel_oaa"] @@ -18,10 +18,16 @@ output_model = f"{dataset}_{reduction}_model" # Parameters -kary_tree = 16 -l = 0.5 -passes = 3 -other_training_params = "--holdout_off" +kary_tree = 16 # arity of the tree, +# higher values usually give better results, but increase training time + +passes = 5 # number of passes over the dataset, +# for some datasets you might want to change the number of passes + +l = 0.5 # learning rate + +other_training_params = "--holdout_off" # because these datasets have many rare labels, +# disabling the holdout set improves the final performance # dict with params for different datasets (k and b) params_dict = { @@ -34,14 +40,20 @@ if dataset in params_dict: k, b = params_dict[dataset] else: - print(f"Dataset {dataset} is not supported for this demo.") + print(f"Dataset {dataset} is not supported by this demo.") # Download dataset (source: http://manikvarma.org/downloads/XC/XMLRepository.html) +# Datasets were transformed to VW's format, +# and feature values were normalized (this helps with performance). 
if not os.path.exists(train_data): + print("Downloading train dataset:") os.system("wget http://www.cs.put.poznan.pl/mwydmuch/data/{}".format(train_data)) + if not os.path.exists(test_data): + print("Downloading test dataset:") os.system("wget http://www.cs.put.poznan.pl/mwydmuch/data/{}".format(test_data)) + print(f"\nTraining Vowpal Wabbit {reduction} on {dataset} dataset:\n") start = time.time() train_cmd = f"vw {train_data} -c --{reduction} {k} --loss_function logistic -l {l} --passes {passes} -b {b} -f {output_model} {other_training_params}" diff --git a/test/train-sets/ref/plt_predict.stderr b/test/train-sets/ref/plt_predict.stderr index d3dea9b4016..aa105fe6728 100644 --- a/test/train-sets/ref/plt_predict.stderr +++ b/test/train-sets/ref/plt_predict.stderr @@ -15,16 +15,16 @@ Input label = MULTILABEL Output pred = MULTILABELS average since example example current current current loss last counter weight label predict features -0.000000 0.000000 1 1.0 0,1 1 2 -0.000000 0.000000 2 2.0 1,2 2,1 2 -0.000000 0.000000 4 4.0 3,4 4,3 2 -0.000000 0.000000 8 8.0 8 8 2 +1.655385 1.655385 1 1.0 0,1 1 2 +2.160409 2.665433 2 2.0 1,2 2,1 2 +1.881014 1.601619 4 4.0 3,4 4,3 2 +1.582159 1.283305 8 8.0 8 8 2 finished run number of examples = 10 weighted example sum = 10.000000 weighted label sum = 0.000000 -average loss = 0.000000 +average loss = 1.375444 total feature number = 20 hamming loss = 0.200000 micro-precision = 1.000000 diff --git a/test/train-sets/ref/plt_predict_probabilities.stderr b/test/train-sets/ref/plt_predict_probabilities.stderr index 99042112722..b9b92409720 100644 --- a/test/train-sets/ref/plt_predict_probabilities.stderr +++ b/test/train-sets/ref/plt_predict_probabilities.stderr @@ -15,16 +15,16 @@ Input label = MULTILABEL Output pred = ACTION_PROBS average since example example current current current loss last counter weight label predict features -0.000000 0.000000 1 1.0 0,1 1 2 -0.000000 0.000000 2 2.0 1,2 2,1 2 -0.000000 0.000000 4 4.0 3,4 4,3 2 
-0.000000 0.000000 8 8.0 8 8 2 +1.655385 1.655385 1 1.0 0,1 1 2 +2.160409 2.665433 2 2.0 1,2 2,1 2 +1.881014 1.601619 4 4.0 3,4 4,3 2 +1.582159 1.283305 8 8.0 8 8 2 finished run number of examples = 10 weighted example sum = 10.000000 weighted label sum = 0.000000 -average loss = 0.000000 +average loss = 1.375444 total feature number = 20 hamming loss = 0.200000 micro-precision = 1.000000 diff --git a/test/train-sets/ref/plt_sgd_predict.stderr b/test/train-sets/ref/plt_sgd_predict.stderr index d6451175cd0..6947e934701 100644 --- a/test/train-sets/ref/plt_sgd_predict.stderr +++ b/test/train-sets/ref/plt_sgd_predict.stderr @@ -15,16 +15,16 @@ Input label = MULTILABEL Output pred = MULTILABELS average since example example current current current loss last counter weight label predict features -0.000000 0.000000 1 1.0 0,1 0,1 2 -0.000000 0.000000 2 2.0 1,2 2,1,8 2 -0.000000 0.000000 4 4.0 3,4 8 2 -0.000000 0.000000 8 8.0 8 1,8 2 +0.249245 0.249245 1 1.0 0,1 0,1 2 +2.069833 3.890422 2 2.0 1,2 2,1,8 2 +3.383714 4.697595 4 4.0 3,4 8 2 +3.461616 3.539518 8 8.0 8 1,8 2 finished run number of examples = 10 weighted example sum = 10.000000 weighted label sum = 0.000000 -average loss = 0.000000 +average loss = 2.905217 total feature number = 20 hamming loss = 1.700000 micro-precision = 0.562500 diff --git a/test/train-sets/ref/plt_sgd_top1_predict.stderr b/test/train-sets/ref/plt_sgd_top1_predict.stderr index e3f9e7673fa..63af2af8111 100644 --- a/test/train-sets/ref/plt_sgd_top1_predict.stderr +++ b/test/train-sets/ref/plt_sgd_top1_predict.stderr @@ -15,16 +15,16 @@ Input label = MULTILABEL Output pred = MULTILABELS average since example example current current current loss last counter weight label predict features -0.000000 0.000000 1 1.0 0,1 1 2 -0.000000 0.000000 2 2.0 1,2 2 2 -0.000000 0.000000 4 4.0 3,4 8 2 -0.000000 0.000000 8 8.0 8 8 2 +0.249245 0.249245 1 1.0 0,1 1 2 +2.069833 3.890422 2 2.0 1,2 2 2 +3.383714 4.697595 4 4.0 3,4 8 2 +3.461616 3.539518 8 8.0 8 8 2 
finished run number of examples = 10 weighted example sum = 10.000000 weighted label sum = 0.000000 -average loss = 0.000000 +average loss = 2.905217 total feature number = 20 p@1 = 0.600000 r@1 = 0.450000 diff --git a/test/train-sets/ref/plt_top1_predict.stderr b/test/train-sets/ref/plt_top1_predict.stderr index 88099c5566b..c15396706a4 100644 --- a/test/train-sets/ref/plt_top1_predict.stderr +++ b/test/train-sets/ref/plt_top1_predict.stderr @@ -15,16 +15,16 @@ Input label = MULTILABEL Output pred = MULTILABELS average since example example current current current loss last counter weight label predict features -0.000000 0.000000 1 1.0 0,1 1 2 -0.000000 0.000000 2 2.0 1,2 1 2 -0.000000 0.000000 4 4.0 3,4 3 2 -0.000000 0.000000 8 8.0 8 8 2 +1.655385 1.655385 1 1.0 0,1 1 2 +2.160409 2.665433 2 2.0 1,2 1 2 +1.881014 1.601619 4 4.0 3,4 3 2 +1.582159 1.283305 8 8.0 8 8 2 finished run number of examples = 10 weighted example sum = 10.000000 weighted label sum = 0.000000 -average loss = 0.000000 +average loss = 1.375444 total feature number = 20 p@1 = 1.000000 r@1 = 0.625000 diff --git a/test/train-sets/ref/plt_top1_predict_probabilities.stderr b/test/train-sets/ref/plt_top1_predict_probabilities.stderr index a8a7b8425b3..fbd5fc04b16 100644 --- a/test/train-sets/ref/plt_top1_predict_probabilities.stderr +++ b/test/train-sets/ref/plt_top1_predict_probabilities.stderr @@ -15,16 +15,16 @@ Input label = MULTILABEL Output pred = ACTION_PROBS average since example example current current current loss last counter weight label predict features -0.000000 0.000000 1 1.0 0,1 1 2 -0.000000 0.000000 2 2.0 1,2 1 2 -0.000000 0.000000 4 4.0 3,4 3 2 -0.000000 0.000000 8 8.0 8 8 2 +1.655385 1.655385 1 1.0 0,1 1 2 +2.160409 2.665433 2 2.0 1,2 1 2 +1.881014 1.601619 4 4.0 3,4 3 2 +1.582159 1.283305 8 8.0 8 8 2 finished run number of examples = 10 weighted example sum = 10.000000 weighted label sum = 0.000000 -average loss = 0.000000 +average loss = 1.375444 total feature number = 20 p@1 = 
1.000000 r@1 = 0.625000 diff --git a/vowpalwabbit/core/src/reductions/plt.cc b/vowpalwabbit/core/src/reductions/plt.cc index 8db6e0dc56f..175f0f2f17b 100644 --- a/vowpalwabbit/core/src/reductions/plt.cc +++ b/vowpalwabbit/core/src/reductions/plt.cc @@ -90,15 +90,8 @@ inline float learn_node(plt& p, uint32_t n, learner& base, VW::example& ec) return ec.loss; } -void learn(plt& p, learner& base, VW::example& ec) +void get_nodes_to_update(plt& p, VW::multilabel_label& multilabels) { - auto multilabels = std::move(ec.l.multilabels); - VW::polyprediction pred = std::move(ec.pred); - - double t = p.all->sd->t; - double weighted_holdout_examples = p.all->sd->weighted_holdout_examples; - p.all->sd->weighted_holdout_examples = 0; - p.positive_nodes.clear(); p.negative_nodes.clear(); @@ -139,7 +132,16 @@ void learn(plt& p, learner& base, VW::example& ec) } } else { p.negative_nodes.insert(0); } +} +void learn(plt& p, learner& base, VW::example& ec) +{ + auto multilabels = std::move(ec.l.multilabels); + VW::polyprediction pred = std::move(ec.pred); + + get_nodes_to_update(p, multilabels); + + double t = p.all->sd->t; float loss = 0; ec.l.simple = {1.f}; ec.ex_reduction_features.template get().reset_to_default(); @@ -149,8 +151,6 @@ void learn(plt& p, learner& base, VW::example& ec) for (auto& n : p.negative_nodes) { loss += learn_node(p, n, base, ec); } p.all->sd->t = t; - p.all->sd->weighted_holdout_examples = weighted_holdout_examples; - ec.loss = loss; ec.pred = std::move(pred); ec.l.multilabels = std::move(multilabels); @@ -166,12 +166,37 @@ inline float predict_node(uint32_t n, learner& base, VW::example& ec) return sigmoid(ec.partial_prediction); } +inline float evaluate_node(uint32_t n, learner& base, VW::example& ec) +{ + base.predict(ec, n); + return ec.loss; +} + template void predict(plt& p, learner& base, VW::example& ec) { auto multilabels = std::move(ec.l.multilabels); VW::polyprediction pred = std::move(ec.pred); + // if true labels are present (e.g., 
predicting on holdout set), + // calculate training loss without updating base learner/weights + if (multilabels.label_v.size() > 0) + { + get_nodes_to_update(p, multilabels); + + float loss = 0; + ec.l.simple = {1.f}; + ec.ex_reduction_features.template get().reset_to_default(); + for (auto& n : p.positive_nodes) { loss += evaluate_node(n, base, ec); } + + ec.l.simple.label = -1.f; + for (auto& n : p.negative_nodes) { loss += evaluate_node(n, base, ec); } + + ec.loss = loss; + } + else { ec.loss = 0; } + + // start real prediction if (p.probabilities) { pred.a_s.clear(); } pred.multilabels.label_v.clear(); @@ -220,18 +245,21 @@ void predict(plt& p, learner& base, VW::example& ec) } } - // calculate evaluation measures - uint32_t tp = 0; - uint32_t pred_size = pred.multilabels.label_v.size(); - - for (uint32_t i = 0; i < pred_size; ++i) + // if there are true labels, calculate evaluation measures + if (p.true_labels.size() > 0) { - uint32_t pred_label = pred.multilabels.label_v[i]; - if (p.true_labels.count(pred_label)) { ++tp; } + uint32_t tp = 0; + uint32_t pred_size = pred.multilabels.label_v.size(); + + for (uint32_t i = 0; i < pred_size; ++i) + { + uint32_t pred_label = pred.multilabels.label_v[i]; + if (p.true_labels.count(pred_label)) { ++tp; } + } + p.tp += tp; + p.fp += static_cast(pred_size) - tp; + p.fn += static_cast(p.true_labels.size()) - tp; } - p.tp += tp; - p.fp += static_cast(pred_size) - tp; - p.fn += static_cast(p.true_labels.size()) - tp; } // top-k prediction @@ -270,21 +298,23 @@ void predict(plt& p, learner& base, VW::example& ec) } } - // calculate precision and recall at - float tp_at = 0; - for (size_t i = 0; i < p.top_k; ++i) + // if there are true labels, calculate precision and recall at k + if (p.true_labels.size() > 0) { - uint32_t pred_label = pred.multilabels.label_v[i]; - if (p.true_labels.count(pred_label)) { tp_at += 1; } - p.p_at[i] += tp_at / (i + 1); - if (p.true_labels.size() > 0) { p.r_at[i] += tp_at / p.true_labels.size(); 
} + float tp_at = 0; + for (size_t i = 0; i < p.top_k; ++i) + { + uint32_t pred_label = pred.multilabels.label_v[i]; + if (p.true_labels.count(pred_label)) { tp_at += 1; } + p.p_at[i] += tp_at / (i + 1); + if (p.true_labels.size() > 0) { p.r_at[i] += tp_at / p.true_labels.size(); } + } } } ++p.ec_count; p.node_queue.clear(); - ec.loss = 0; ec.pred = std::move(pred); ec.l.multilabels = std::move(multilabels); }