Skip to content

Commit

Permalink
feat: add a training loss calculation to the predict method of PLT re…
Browse files Browse the repository at this point in the history
…duction (#4534)

* add a training loss calculation to the predict method of PLT reduction

* update PLT demo

* update the tests for PLT reduction

* disable the calculation of additional evaluation measures in PLT reduction when true labels are not available

* apply black formating to plt_demo.py

* remove unnecessary reset of weighted_holdout_examples variable in PLT reduction

* revert the change of the path to the exe in plt_demo.py

* apply black formating again to plt_demo.py

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
  • Loading branch information
mwydmuch and jackgerrits authored Mar 20, 2023
1 parent f7a197e commit adcaff2
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 68 deletions.
11 changes: 7 additions & 4 deletions demo/plt/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Probabilistic Label Tree demo
-------------------------------
-----------------------------

This demo presents PLT for applications of logarithmic time multilabel classification.
It uses Mediamill dataset from the [LIBLINEAR datasets repository](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html)
Expand All @@ -12,19 +12,22 @@ The datasets and parameters can be easily edited in the script. The script requi
## PLT options
```
--plt Probabilistic Label Tree with <k> labels
--kary_tree use <k>-ary tree. By default = 2 (binary tree)
--threshold predict labels with conditional marginal probability greater than <thr> threshold"
--kary_tree use <k>-ary tree. By default = 2 (binary tree),
higher values usually give better results, but increase training time
--threshold predict labels with conditional marginal probability greater than <thr> threshold
--top_k predict top-<k> labels instead of labels above threshold
```


## Tips for using PLT
PLT accelerates training and prediction for a large number of classes,
if you have less than 10000 classes, you should probably use OAA.
If you have a huge number of labels and features at the same time,
you will need as many bits (`-b`) as can afford computationally for the best performance.
You may also consider using `--sgd` instead of the default adaptive, normalized, and invariant updates to
gain more memory for feature weights, which may lead to better performance.
You may also consider using `--holdout_off` if you have many rare labels in your data.
If you have many rare labels in your data, you should train with `--holdout_off`, which disables the use of a holdout (validation) dataset for early stopping.


## References

Expand Down
24 changes: 18 additions & 6 deletions demo/plt/plt_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# This is demo example that demonstrates usage of PLT reduction on few popular multilabel datasets.

# Select dataset
dataset = "mediamill_exp1" # should be in ["mediamill_exp1", "eurlex", "rcv1x", "wiki10", "amazonCat"]
dataset = "eurlex" # should be in ["mediamill_exp1", "eurlex", "rcv1x", "wiki10", "amazonCat"]

# Select reduction
reduction = "plt" # should be in ["plt", "multilabel_oaa"]
Expand All @@ -18,10 +18,16 @@
output_model = f"{dataset}_{reduction}_model"

# Parameters
kary_tree = 16
l = 0.5
passes = 3
other_training_params = "--holdout_off"
kary_tree = 16 # arity of the tree,
# higher values usually give better results, but increase training time

passes = 5 # number of passes over the dataset,
# for some datasets you might want to change number of passes

l = 0.5 # learning rate

other_training_params = "--holdout_off" # because these datasets have many rare labels,
# disabling the holdout set improves the final performance

# dict with params for different datasets (k and b)
params_dict = {
Expand All @@ -34,14 +40,20 @@
if dataset in params_dict:
k, b = params_dict[dataset]
else:
print(f"Dataset {dataset} is not supported for this demo.")
print(f"Dataset {dataset} is not supported by this demo.")

# Download dataset (source: http://manikvarma.org/downloads/XC/XMLRepository.html)
# Datasets were transformed to VW's format,
# and features values were normalized (this helps with performance).
if not os.path.exists(train_data):
print("Downloading train dataset:")
os.system("wget http://www.cs.put.poznan.pl/mwydmuch/data/{}".format(train_data))

if not os.path.exists(test_data):
print("Downloading test dataset:")
os.system("wget http://www.cs.put.poznan.pl/mwydmuch/data/{}".format(test_data))


print(f"\nTraining Vowpal Wabbit {reduction} on {dataset} dataset:\n")
start = time.time()
train_cmd = f"vw {train_data} -c --{reduction} {k} --loss_function logistic -l {l} --passes {passes} -b {b} -f {output_model} {other_training_params}"
Expand Down
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_predict.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = MULTILABELS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 2,1 2
0.000000 0.000000 4 4.0 3,4 4,3 2
0.000000 0.000000 8 8.0 8 8 2
1.655385 1.655385 1 1.0 0,1 1 2
2.160409 2.665433 2 2.0 1,2 2,1 2
1.881014 1.601619 4 4.0 3,4 4,3 2
1.582159 1.283305 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 1.375444
total feature number = 20
hamming loss = 0.200000
micro-precision = 1.000000
Expand Down
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_predict_probabilities.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = ACTION_PROBS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 2,1 2
0.000000 0.000000 4 4.0 3,4 4,3 2
0.000000 0.000000 8 8.0 8 8 2
1.655385 1.655385 1 1.0 0,1 1 2
2.160409 2.665433 2 2.0 1,2 2,1 2
1.881014 1.601619 4 4.0 3,4 4,3 2
1.582159 1.283305 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 1.375444
total feature number = 20
hamming loss = 0.200000
micro-precision = 1.000000
Expand Down
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_sgd_predict.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = MULTILABELS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 0,1 2
0.000000 0.000000 2 2.0 1,2 2,1,8 2
0.000000 0.000000 4 4.0 3,4 8 2
0.000000 0.000000 8 8.0 8 1,8 2
0.249245 0.249245 1 1.0 0,1 0,1 2
2.069833 3.890422 2 2.0 1,2 2,1,8 2
3.383714 4.697595 4 4.0 3,4 8 2
3.461616 3.539518 8 8.0 8 1,8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 2.905217
total feature number = 20
hamming loss = 1.700000
micro-precision = 0.562500
Expand Down
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_sgd_top1_predict.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = MULTILABELS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 2 2
0.000000 0.000000 4 4.0 3,4 8 2
0.000000 0.000000 8 8.0 8 8 2
0.249245 0.249245 1 1.0 0,1 1 2
2.069833 3.890422 2 2.0 1,2 2 2
3.383714 4.697595 4 4.0 3,4 8 2
3.461616 3.539518 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 2.905217
total feature number = 20
p@1 = 0.600000
r@1 = 0.450000
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_top1_predict.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = MULTILABELS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 1 2
0.000000 0.000000 4 4.0 3,4 3 2
0.000000 0.000000 8 8.0 8 8 2
1.655385 1.655385 1 1.0 0,1 1 2
2.160409 2.665433 2 2.0 1,2 1 2
1.881014 1.601619 4 4.0 3,4 3 2
1.582159 1.283305 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 1.375444
total feature number = 20
p@1 = 1.000000
r@1 = 0.625000
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_top1_predict_probabilities.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = ACTION_PROBS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 1 2
0.000000 0.000000 4 4.0 3,4 3 2
0.000000 0.000000 8 8.0 8 8 2
1.655385 1.655385 1 1.0 0,1 1 2
2.160409 2.665433 2 2.0 1,2 1 2
1.881014 1.601619 4 4.0 3,4 3 2
1.582159 1.283305 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 1.375444
total feature number = 20
p@1 = 1.000000
r@1 = 0.625000
86 changes: 58 additions & 28 deletions vowpalwabbit/core/src/reductions/plt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,8 @@ inline float learn_node(plt& p, uint32_t n, learner& base, VW::example& ec)
return ec.loss;
}

void learn(plt& p, learner& base, VW::example& ec)
void get_nodes_to_update(plt& p, VW::multilabel_label& multilabels)
{
auto multilabels = std::move(ec.l.multilabels);
VW::polyprediction pred = std::move(ec.pred);

double t = p.all->sd->t;
double weighted_holdout_examples = p.all->sd->weighted_holdout_examples;
p.all->sd->weighted_holdout_examples = 0;

p.positive_nodes.clear();
p.negative_nodes.clear();

Expand Down Expand Up @@ -139,7 +132,16 @@ void learn(plt& p, learner& base, VW::example& ec)
}
}
else { p.negative_nodes.insert(0); }
}

void learn(plt& p, learner& base, VW::example& ec)
{
auto multilabels = std::move(ec.l.multilabels);
VW::polyprediction pred = std::move(ec.pred);

get_nodes_to_update(p, multilabels);

double t = p.all->sd->t;
float loss = 0;
ec.l.simple = {1.f};
ec.ex_reduction_features.template get<VW::simple_label_reduction_features>().reset_to_default();
Expand All @@ -149,8 +151,6 @@ void learn(plt& p, learner& base, VW::example& ec)
for (auto& n : p.negative_nodes) { loss += learn_node(p, n, base, ec); }

p.all->sd->t = t;
p.all->sd->weighted_holdout_examples = weighted_holdout_examples;

ec.loss = loss;
ec.pred = std::move(pred);
ec.l.multilabels = std::move(multilabels);
Expand All @@ -166,12 +166,37 @@ inline float predict_node(uint32_t n, learner& base, VW::example& ec)
return sigmoid(ec.partial_prediction);
}

// Evaluates tree node n on example ec: runs the base learner's predict
// (no weight update) and returns the resulting loss. Used by predict()
// to compute a training-style loss on examples that carry true labels
// (e.g., a holdout set) without mutating the model.
inline float evaluate_node(uint32_t n, learner& base, VW::example& ec)
{
base.predict(ec, n);
return ec.loss;
}

template <bool threshold>
void predict(plt& p, learner& base, VW::example& ec)
{
auto multilabels = std::move(ec.l.multilabels);
VW::polyprediction pred = std::move(ec.pred);

// if true labels are present (e.g., predicting on holdout set),
// calculate training loss without updating base learner/weights
if (multilabels.label_v.size() > 0)
{
get_nodes_to_update(p, multilabels);

float loss = 0;
ec.l.simple = {1.f};
ec.ex_reduction_features.template get<VW::simple_label_reduction_features>().reset_to_default();
for (auto& n : p.positive_nodes) { loss += evaluate_node(n, base, ec); }

ec.l.simple.label = -1.f;
for (auto& n : p.negative_nodes) { loss += evaluate_node(n, base, ec); }

ec.loss = loss;
}
else { ec.loss = 0; }

// start real prediction
if (p.probabilities) { pred.a_s.clear(); }
pred.multilabels.label_v.clear();

Expand Down Expand Up @@ -220,18 +245,21 @@ void predict(plt& p, learner& base, VW::example& ec)
}
}

// calculate evaluation measures
uint32_t tp = 0;
uint32_t pred_size = pred.multilabels.label_v.size();

for (uint32_t i = 0; i < pred_size; ++i)
// if there are true labels, calculate evaluation measures
if (p.true_labels.size() > 0)
{
uint32_t pred_label = pred.multilabels.label_v[i];
if (p.true_labels.count(pred_label)) { ++tp; }
uint32_t tp = 0;
uint32_t pred_size = pred.multilabels.label_v.size();

for (uint32_t i = 0; i < pred_size; ++i)
{
uint32_t pred_label = pred.multilabels.label_v[i];
if (p.true_labels.count(pred_label)) { ++tp; }
}
p.tp += tp;
p.fp += static_cast<uint32_t>(pred_size) - tp;
p.fn += static_cast<uint32_t>(p.true_labels.size()) - tp;
}
p.tp += tp;
p.fp += static_cast<uint32_t>(pred_size) - tp;
p.fn += static_cast<uint32_t>(p.true_labels.size()) - tp;
}

// top-k prediction
Expand Down Expand Up @@ -270,21 +298,23 @@ void predict(plt& p, learner& base, VW::example& ec)
}
}

// calculate precision and recall at
float tp_at = 0;
for (size_t i = 0; i < p.top_k; ++i)
// if there are true labels, calculate precision and recall at k
if (p.true_labels.size() > 0)
{
uint32_t pred_label = pred.multilabels.label_v[i];
if (p.true_labels.count(pred_label)) { tp_at += 1; }
p.p_at[i] += tp_at / (i + 1);
if (p.true_labels.size() > 0) { p.r_at[i] += tp_at / p.true_labels.size(); }
float tp_at = 0;
for (size_t i = 0; i < p.top_k; ++i)
{
uint32_t pred_label = pred.multilabels.label_v[i];
if (p.true_labels.count(pred_label)) { tp_at += 1; }
p.p_at[i] += tp_at / (i + 1);
if (p.true_labels.size() > 0) { p.r_at[i] += tp_at / p.true_labels.size(); }
}
}
}

++p.ec_count;
p.node_queue.clear();

ec.loss = 0;
ec.pred = std::move(pred);
ec.l.multilabels = std::move(multilabels);
}
Expand Down

0 comments on commit adcaff2

Please sign in to comment.