feat: add a training loss calculation to the predict method of PLT reduction #4534

Merged
11 changes: 7 additions & 4 deletions demo/plt/README.md
@@ -1,5 +1,5 @@
Probabilistic Label Tree demo
-------------------------------
-----------------------------

This demo presents PLT for applications of logarithmic-time multilabel classification.
It uses the Mediamill dataset from the [LIBLINEAR datasets repository](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html)
@@ -12,19 +12,22 @@ The datasets and parameters can be easily edited in the script. The script requi
## PLT options
```
--plt Probabilistic Label Tree with <k> labels
--kary_tree use <k>-ary tree. By default = 2 (binary tree)
--threshold predict labels with conditional marginal probability greater than <thr> threshold"
--kary_tree use <k>-ary tree. By default = 2 (binary tree),
higher values usually give better results, but increase training time
--threshold predict labels with conditional marginal probability greater than <thr> threshold
--top_k predict top-<k> labels instead of labels above threshold
```
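
For illustration, a minimal pair of commands (hypothetical file and model names, assuming a dataset in VW format with 1000 labels) that trains a PLT over a 16-ary tree and then predicts the top 5 labels for each test example might look like:

```
vw train.vw -c --plt 1000 --kary_tree 16 --loss_function logistic --passes 3 -f plt.model
vw test.vw -t -i plt.model --top_k 5 -p predictions.txt
```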


## Tips for using PLT
PLT accelerates training and prediction for a large number of classes;
if you have fewer than 10000 classes, you should probably use OAA.
If you have a huge number of labels and features at the same time,
you will need as many bits (`-b`) as you can afford computationally for the best performance.
You may also consider using `--sgd` instead of the default adaptive, normalized, and invariant updates to
free more memory for feature weights, which may lead to better performance.
You may also consider using `--holdout_off` if you have many rare labels in your data.
If you have many rare labels in your data, you should train with `--holdout_off`, which disables the use of the holdout (validation) set for early stopping.
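
Putting these tips together, a sketch of a training command for a dataset with many rare labels (hypothetical file name and values; tune `-b`, the learning rate, and the number of passes for your data) could be:

```
vw train.vw -c --plt 1000 --kary_tree 16 --loss_function logistic -l 0.5 --passes 3 -b 28 --sgd --holdout_off -f plt.model
```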


## References

24 changes: 18 additions & 6 deletions demo/plt/plt_demo.py
@@ -6,7 +6,7 @@
# This demo example demonstrates the usage of the PLT reduction on a few popular multilabel datasets.

# Select dataset
dataset = "mediamill_exp1" # should be in ["mediamill_exp1", "eurlex", "rcv1x", "wiki10", "amazonCat"]
dataset = "eurlex" # should be in ["mediamill_exp1", "eurlex", "rcv1x", "wiki10", "amazonCat"]

# Select reduction
reduction = "plt" # should be in ["plt", "multilabel_oaa"]
@@ -18,10 +18,16 @@
output_model = f"{dataset}_{reduction}_model"

# Parameters
kary_tree = 16
l = 0.5
passes = 3
other_training_params = "--holdout_off"
kary_tree = 16 # arity of the tree,
# higher values usually give better results, but increase training time

passes = 5 # number of passes over the dataset,
# you may want to tune this value for other datasets

l = 0.5 # learning rate

other_training_params = "--holdout_off" # because these datasets have many rare labels,
# disabling the holdout set improves the final performance

# dict with params for different datasets (k and b)
params_dict = {
@@ -34,14 +40,20 @@
if dataset in params_dict:
k, b = params_dict[dataset]
else:
print(f"Dataset {dataset} is not supported for this demo.")
print(f"Dataset {dataset} is not supported by this demo.")

# Download dataset (source: http://manikvarma.org/downloads/XC/XMLRepository.html)
# Datasets were transformed to VW's format,
# and feature values were normalized (this helps with performance).
if not os.path.exists(train_data):
print("Downloading train dataset:")
os.system("wget http://www.cs.put.poznan.pl/mwydmuch/data/{}".format(train_data))

if not os.path.exists(test_data):
print("Downloading test dataset:")
os.system("wget http://www.cs.put.poznan.pl/mwydmuch/data/{}".format(test_data))


print(f"\nTraining Vowpal Wabbit {reduction} on {dataset} dataset:\n")
start = time.time()
train_cmd = f"vw {train_data} -c --{reduction} {k} --loss_function logistic -l {l} --passes {passes} -b {b} -f {output_model} {other_training_params}"
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_predict.stderr
@@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = MULTILABELS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 2,1 2
0.000000 0.000000 4 4.0 3,4 4,3 2
0.000000 0.000000 8 8.0 8 8 2
1.655385 1.655385 1 1.0 0,1 1 2
2.160409 2.665433 2 2.0 1,2 2,1 2
1.881014 1.601619 4 4.0 3,4 4,3 2
1.582159 1.283305 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 1.375444
total feature number = 20
hamming loss = 0.200000
micro-precision = 1.000000
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_predict_probabilities.stderr
@@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = ACTION_PROBS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 2,1 2
0.000000 0.000000 4 4.0 3,4 4,3 2
0.000000 0.000000 8 8.0 8 8 2
1.655385 1.655385 1 1.0 0,1 1 2
2.160409 2.665433 2 2.0 1,2 2,1 2
1.881014 1.601619 4 4.0 3,4 4,3 2
1.582159 1.283305 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 1.375444
total feature number = 20
hamming loss = 0.200000
micro-precision = 1.000000
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_sgd_predict.stderr
@@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = MULTILABELS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 0,1 2
0.000000 0.000000 2 2.0 1,2 2,1,8 2
0.000000 0.000000 4 4.0 3,4 8 2
0.000000 0.000000 8 8.0 8 1,8 2
0.249245 0.249245 1 1.0 0,1 0,1 2
2.069833 3.890422 2 2.0 1,2 2,1,8 2
3.383714 4.697595 4 4.0 3,4 8 2
3.461616 3.539518 8 8.0 8 1,8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 2.905217
total feature number = 20
hamming loss = 1.700000
micro-precision = 0.562500
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_sgd_top1_predict.stderr
@@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = MULTILABELS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 2 2
0.000000 0.000000 4 4.0 3,4 8 2
0.000000 0.000000 8 8.0 8 8 2
0.249245 0.249245 1 1.0 0,1 1 2
2.069833 3.890422 2 2.0 1,2 2 2
3.383714 4.697595 4 4.0 3,4 8 2
3.461616 3.539518 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 2.905217
total feature number = 20
p@1 = 0.600000
r@1 = 0.450000
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_top1_predict.stderr
@@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = MULTILABELS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 1 2
0.000000 0.000000 4 4.0 3,4 3 2
0.000000 0.000000 8 8.0 8 8 2
1.655385 1.655385 1 1.0 0,1 1 2
2.160409 2.665433 2 2.0 1,2 1 2
1.881014 1.601619 4 4.0 3,4 3 2
1.582159 1.283305 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 1.375444
total feature number = 20
p@1 = 1.000000
r@1 = 0.625000
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_top1_predict_probabilities.stderr
@@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = ACTION_PROBS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 1 2
0.000000 0.000000 4 4.0 3,4 3 2
0.000000 0.000000 8 8.0 8 8 2
1.655385 1.655385 1 1.0 0,1 1 2
2.160409 2.665433 2 2.0 1,2 1 2
1.881014 1.601619 4 4.0 3,4 3 2
1.582159 1.283305 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 1.375444
total feature number = 20
p@1 = 1.000000
r@1 = 0.625000
86 changes: 58 additions & 28 deletions vowpalwabbit/core/src/reductions/plt.cc
@@ -90,15 +90,8 @@ inline float learn_node(plt& p, uint32_t n, learner& base, VW::example& ec)
return ec.loss;
}

void learn(plt& p, learner& base, VW::example& ec)
void get_nodes_to_update(plt& p, VW::multilabel_label& multilabels)
{
auto multilabels = std::move(ec.l.multilabels);
VW::polyprediction pred = std::move(ec.pred);

double t = p.all->sd->t;
double weighted_holdout_examples = p.all->sd->weighted_holdout_examples;
p.all->sd->weighted_holdout_examples = 0;

p.positive_nodes.clear();
p.negative_nodes.clear();

@@ -139,7 +132,16 @@ void learn(plt& p, learner& base, VW::example& ec)
}
}
else { p.negative_nodes.insert(0); }
}

void learn(plt& p, learner& base, VW::example& ec)
{
auto multilabels = std::move(ec.l.multilabels);
VW::polyprediction pred = std::move(ec.pred);

get_nodes_to_update(p, multilabels);

double t = p.all->sd->t;
float loss = 0;
ec.l.simple = {1.f};
ec.ex_reduction_features.template get<VW::simple_label_reduction_features>().reset_to_default();
@@ -149,8 +151,6 @@ void learn(plt& p, learner& base, VW::example& ec)
for (auto& n : p.negative_nodes) { loss += learn_node(p, n, base, ec); }

p.all->sd->t = t;
p.all->sd->weighted_holdout_examples = weighted_holdout_examples;

ec.loss = loss;
ec.pred = std::move(pred);
ec.l.multilabels = std::move(multilabels);
@@ -166,12 +166,37 @@ inline float predict_node(uint32_t n, learner& base, VW::example& ec)
return sigmoid(ec.partial_prediction);
}

inline float evaluate_node(uint32_t n, learner& base, VW::example& ec)
{
base.predict(ec, n);
return ec.loss;
}

template <bool threshold>
void predict(plt& p, learner& base, VW::example& ec)
{
auto multilabels = std::move(ec.l.multilabels);
VW::polyprediction pred = std::move(ec.pred);

// if true labels are present (e.g., predicting on holdout set),
// calculate training loss without updating base learner/weights
if (multilabels.label_v.size() > 0)
{
get_nodes_to_update(p, multilabels);

float loss = 0;
ec.l.simple = {1.f};
ec.ex_reduction_features.template get<VW::simple_label_reduction_features>().reset_to_default();
for (auto& n : p.positive_nodes) { loss += evaluate_node(n, base, ec); }

ec.l.simple.label = -1.f;
for (auto& n : p.negative_nodes) { loss += evaluate_node(n, base, ec); }

ec.loss = loss;
}
else { ec.loss = 0; }

// start real prediction
if (p.probabilities) { pred.a_s.clear(); }
pred.multilabels.label_v.clear();

@@ -220,18 +245,21 @@ void predict(plt& p, learner& base, VW::example& ec)
}
}

// calculate evaluation measures
uint32_t tp = 0;
uint32_t pred_size = pred.multilabels.label_v.size();

for (uint32_t i = 0; i < pred_size; ++i)
// if there are true labels, calculate evaluation measures
if (p.true_labels.size() > 0)
{
uint32_t pred_label = pred.multilabels.label_v[i];
if (p.true_labels.count(pred_label)) { ++tp; }
uint32_t tp = 0;
uint32_t pred_size = pred.multilabels.label_v.size();

for (uint32_t i = 0; i < pred_size; ++i)
{
uint32_t pred_label = pred.multilabels.label_v[i];
if (p.true_labels.count(pred_label)) { ++tp; }
}
p.tp += tp;
p.fp += static_cast<uint32_t>(pred_size) - tp;
p.fn += static_cast<uint32_t>(p.true_labels.size()) - tp;
}
p.tp += tp;
p.fp += static_cast<uint32_t>(pred_size) - tp;
p.fn += static_cast<uint32_t>(p.true_labels.size()) - tp;
}

// top-k prediction
@@ -270,21 +298,23 @@ void predict(plt& p, learner& base, VW::example& ec)
}
}

// calculate precision and recall at
float tp_at = 0;
for (size_t i = 0; i < p.top_k; ++i)
// if there are true labels, calculate precision and recall at k
if (p.true_labels.size() > 0)
{
uint32_t pred_label = pred.multilabels.label_v[i];
if (p.true_labels.count(pred_label)) { tp_at += 1; }
p.p_at[i] += tp_at / (i + 1);
if (p.true_labels.size() > 0) { p.r_at[i] += tp_at / p.true_labels.size(); }
float tp_at = 0;
for (size_t i = 0; i < p.top_k; ++i)
{
uint32_t pred_label = pred.multilabels.label_v[i];
if (p.true_labels.count(pred_label)) { tp_at += 1; }
p.p_at[i] += tp_at / (i + 1);
if (p.true_labels.size() > 0) { p.r_at[i] += tp_at / p.true_labels.size(); }
}
}
}

++p.ec_count;
p.node_queue.clear();

ec.loss = 0;
ec.pred = std::move(pred);
ec.l.multilabels = std::move(multilabels);
}