Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add a training loss calculation to the predict method of PLT reduction #4534

Merged
11 changes: 7 additions & 4 deletions demo/plt/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Probabilistic Label Tree demo
-------------------------------
-----------------------------

This demo presents PLT for applications of logarithmic time multilabel classification.
It uses Mediamill dataset from the [LIBLINEAR datasets repository](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html)
Expand All @@ -12,19 +12,22 @@ The datasets and parameters can be easily edited in the script. The script requi
## PLT options
```
--plt Probabilistic Label Tree with <k> labels
--kary_tree use <k>-ary tree. By default = 2 (binary tree)
--threshold predict labels with conditional marginal probability greater than <thr> threshold"
--kary_tree use <k>-ary tree. By default = 2 (binary tree),
higher values usually give better results, but increase training time
--threshold predict labels with conditional marginal probability greater than <thr> threshold
--top_k predict top-<k> labels instead of labels above threshold
```


## Tips for using PLT
PLT accelerates training and prediction for a large number of classes,
if you have less than 10000 classes, you should probably use OAA.
If you have a huge number of labels and features at the same time,
you will need as many bits (`-b`) as can afford computationally for the best performance.
You may also consider using `--sgd` instead of default adaptive, normalized, and invariant updates to
gain more memory for feature weights, which may lead to better performance.
You may also consider using `--holdout_off` if you have many rare labels in your data.
If you have many rare labels in your data, you should train with `--holdout_off`, which disables the use of the holdout (validation) dataset for early stopping.


## References

Expand Down
28 changes: 18 additions & 10 deletions demo/plt/plt_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
test_data = f"{dataset}_test.vw"
output_model = f"{dataset}_{reduction}_model"

# Parameters
kary_tree = 16
l = 0.5
passes = 3
other_training_params = "--holdout_off"
# Parameters (for some datasets you might want to change number of passes)
kary_tree = 16 # arity of the tree,
# higher values usually give better results, but increase training time
passes = 5 # number of passes over the dataset
l = 0.5 # learning rate
other_training_params = "--holdout_off" # because these datasets have many rare labels,
# disabling the holdout set improves final performance

# dict with params for different datasets (k and b)
params_dict = {
Expand All @@ -34,17 +36,23 @@
if dataset in params_dict:
k, b = params_dict[dataset]
else:
print(f"Dataset {dataset} is not supported for this demo.")
print(f"Dataset {dataset} is not supported by this demo.")

# Download dataset (source: http://manikvarma.org/downloads/XC/XMLRepository.html)
# Datasets were transformed to VW's format,
# and features values were normalized (this helps with performance).
if not os.path.exists(train_data):
os.system("wget http://www.cs.put.poznan.pl/mwydmuch/data/{}".format(train_data))
print("Downloading train dataset:")
os.system("wget http://www.cs.put.poznan.pl/mwydmuch/data/{}".format(train_data))

if not os.path.exists(test_data):
print("Downloading test dataset:")
os.system("wget http://www.cs.put.poznan.pl/mwydmuch/data/{}".format(test_data))


print(f"\nTraining Vowpal Wabbit {reduction} on {dataset} dataset:\n")
start = time.time()
train_cmd = f"vw {train_data} -c --{reduction} {k} --loss_function logistic -l {l} --passes {passes} -b {b} -f {output_model} {other_training_params}"
train_cmd = f"../../build/vowpalwabbit/cli/vw {train_data} -c --{reduction} {k} --loss_function logistic -l {l} --passes {passes} -b {b} -f {output_model} {other_training_params}"
if reduction == "plt":
train_cmd += f" --kary_tree {kary_tree}"
print(train_cmd)
Expand All @@ -55,7 +63,7 @@

print("\nTesting with probability threshold = 0.5 (default prediction mode)\n")
start = time.time()
test_threshold_cmd = f"vw {test_data} -i {output_model} --loss_function logistic -t"
test_threshold_cmd = f"../../build/vowpalwabbit/cli/vw {test_data} -i {output_model} --loss_function logistic -t"
if reduction == "plt":
test_threshold_cmd += " --threshold 0.5"
print(test_threshold_cmd)
Expand All @@ -68,7 +76,7 @@
print("\nTesting with top-5 prediction\n")
start = time.time()
test_topk_cmd = (
f"vw {test_data} -i {output_model} --loss_function logistic --top_k 5 -t"
f"../../build/vowpalwabbit/cli/vw {test_data} -i {output_model} --loss_function logistic --top_k 5 -t"
)
print(test_topk_cmd)
os.system(test_topk_cmd)
Expand Down
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_predict.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = MULTILABELS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 2,1 2
0.000000 0.000000 4 4.0 3,4 4,3 2
0.000000 0.000000 8 8.0 8 8 2
1.655385 1.655385 1 1.0 0,1 1 2
2.160409 2.665433 2 2.0 1,2 2,1 2
1.881014 1.601619 4 4.0 3,4 4,3 2
1.582159 1.283305 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 1.375444
total feature number = 20
hamming loss = 0.200000
micro-precision = 1.000000
Expand Down
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_predict_probabilities.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = ACTION_PROBS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 2,1 2
0.000000 0.000000 4 4.0 3,4 4,3 2
0.000000 0.000000 8 8.0 8 8 2
1.655385 1.655385 1 1.0 0,1 1 2
2.160409 2.665433 2 2.0 1,2 2,1 2
1.881014 1.601619 4 4.0 3,4 4,3 2
1.582159 1.283305 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 1.375444
total feature number = 20
hamming loss = 0.200000
micro-precision = 1.000000
Expand Down
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_sgd_predict.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = MULTILABELS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 0,1 2
0.000000 0.000000 2 2.0 1,2 2,1,8 2
0.000000 0.000000 4 4.0 3,4 8 2
0.000000 0.000000 8 8.0 8 1,8 2
0.249245 0.249245 1 1.0 0,1 0,1 2
2.069833 3.890422 2 2.0 1,2 2,1,8 2
3.383714 4.697595 4 4.0 3,4 8 2
3.461616 3.539518 8 8.0 8 1,8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 2.905217
total feature number = 20
hamming loss = 1.700000
micro-precision = 0.562500
Expand Down
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_sgd_top1_predict.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = MULTILABELS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 2 2
0.000000 0.000000 4 4.0 3,4 8 2
0.000000 0.000000 8 8.0 8 8 2
0.249245 0.249245 1 1.0 0,1 1 2
2.069833 3.890422 2 2.0 1,2 2 2
3.383714 4.697595 4 4.0 3,4 8 2
3.461616 3.539518 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 2.905217
total feature number = 20
p@1 = 0.600000
r@1 = 0.450000
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_top1_predict.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = MULTILABELS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 1 2
0.000000 0.000000 4 4.0 3,4 3 2
0.000000 0.000000 8 8.0 8 8 2
1.655385 1.655385 1 1.0 0,1 1 2
2.160409 2.665433 2 2.0 1,2 1 2
1.881014 1.601619 4 4.0 3,4 3 2
1.582159 1.283305 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 1.375444
total feature number = 20
p@1 = 1.000000
r@1 = 0.625000
10 changes: 5 additions & 5 deletions test/train-sets/ref/plt_top1_predict_probabilities.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Input label = MULTILABEL
Output pred = ACTION_PROBS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 0,1 1 2
0.000000 0.000000 2 2.0 1,2 1 2
0.000000 0.000000 4 4.0 3,4 3 2
0.000000 0.000000 8 8.0 8 8 2
1.655385 1.655385 1 1.0 0,1 1 2
2.160409 2.665433 2 2.0 1,2 1 2
1.881014 1.601619 4 4.0 3,4 3 2
1.582159 1.283305 8 8.0 8 8 2

finished run
number of examples = 10
weighted example sum = 10.000000
weighted label sum = 0.000000
average loss = 0.000000
average loss = 1.375444
total feature number = 20
p@1 = 1.000000
r@1 = 0.625000
87 changes: 61 additions & 26 deletions vowpalwabbit/core/src/reductions/plt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,8 @@ inline float learn_node(plt& p, uint32_t n, learner& base, VW::example& ec)
return ec.loss;
}

void learn(plt& p, learner& base, VW::example& ec)
void get_nodes_to_update(plt& p, VW::multilabel_label& multilabels)
{
auto multilabels = std::move(ec.l.multilabels);
VW::polyprediction pred = std::move(ec.pred);

double t = p.all->sd->t;
double weighted_holdout_examples = p.all->sd->weighted_holdout_examples;
p.all->sd->weighted_holdout_examples = 0;

p.positive_nodes.clear();
p.negative_nodes.clear();

Expand Down Expand Up @@ -139,6 +132,18 @@ void learn(plt& p, learner& base, VW::example& ec)
}
}
else { p.negative_nodes.insert(0); }
}

void learn(plt& p, learner& base, VW::example& ec)
{
auto multilabels = std::move(ec.l.multilabels);
VW::polyprediction pred = std::move(ec.pred);

double t = p.all->sd->t;
double weighted_holdout_examples = p.all->sd->weighted_holdout_examples;
p.all->sd->weighted_holdout_examples = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why reset this to zero?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. This is there from the initial implementation, it is not a change I introduced in this PR, but I don't remember why it's here. Does it somehow impact the update of the base classifier? That probably could be a reason for "resetting" this variable. If not, this can probably be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, it seems that it doesn't impact the training so I think I will remove these.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!


get_nodes_to_update(p, multilabels);

float loss = 0;
ec.l.simple = {1.f};
Expand Down Expand Up @@ -166,12 +171,37 @@ inline float predict_node(uint32_t n, learner& base, VW::example& ec)
return sigmoid(ec.partial_prediction);
}

// Evaluates a single tree node in predict-only mode: asks the base learner
// for a prediction on `ec` against node `n` (no weight update) and returns
// the loss the base learner recorded on the example. Used to accumulate a
// training-style loss during prediction when true labels are available.
inline float evaluate_node(uint32_t n, learner& base, VW::example& ec)
{
  // Predict-only call; the base learner fills in ec.loss as a side effect.
  base.predict(ec, n);
  const float node_loss = ec.loss;
  return node_loss;
}

template <bool threshold>
void predict(plt& p, learner& base, VW::example& ec)
{
auto multilabels = std::move(ec.l.multilabels);
VW::polyprediction pred = std::move(ec.pred);

// if true labels are present (e.g., predicting on holdout set),
// calculate training loss without updating base learner/weights
if (multilabels.label_v.size() > 0)
{
get_nodes_to_update(p, multilabels);

float loss = 0;
ec.l.simple = {1.f};
ec.ex_reduction_features.template get<VW::simple_label_reduction_features>().reset_to_default();
for (auto& n : p.positive_nodes) { loss += evaluate_node(n, base, ec); }

ec.l.simple.label = -1.f;
for (auto& n : p.negative_nodes) { loss += evaluate_node(n, base, ec); }

ec.loss = loss;
}
else { ec.loss = 0; }

// start real prediction
if (p.probabilities) { pred.a_s.clear(); }
pred.multilabels.label_v.clear();

Expand Down Expand Up @@ -220,18 +250,21 @@ void predict(plt& p, learner& base, VW::example& ec)
}
}

// calculate evaluation measures
uint32_t tp = 0;
uint32_t pred_size = pred.multilabels.label_v.size();

for (uint32_t i = 0; i < pred_size; ++i)
// if there are true labels, calculate evaluation measures
if (p.true_labels.size() > 0)
{
uint32_t pred_label = pred.multilabels.label_v[i];
if (p.true_labels.count(pred_label)) { ++tp; }
uint32_t tp = 0;
uint32_t pred_size = pred.multilabels.label_v.size();

for (uint32_t i = 0; i < pred_size; ++i)
{
uint32_t pred_label = pred.multilabels.label_v[i];
if (p.true_labels.count(pred_label)) { ++tp; }
}
p.tp += tp;
p.fp += static_cast<uint32_t>(pred_size) - tp;
p.fn += static_cast<uint32_t>(p.true_labels.size()) - tp;
}
p.tp += tp;
p.fp += static_cast<uint32_t>(pred_size) - tp;
p.fn += static_cast<uint32_t>(p.true_labels.size()) - tp;
}

// top-k prediction
Expand Down Expand Up @@ -270,21 +303,23 @@ void predict(plt& p, learner& base, VW::example& ec)
}
}

// calculate precision and recall at
float tp_at = 0;
for (size_t i = 0; i < p.top_k; ++i)
// if there are true labels, calculate precision and recall at k
if (p.true_labels.size() > 0)
{
uint32_t pred_label = pred.multilabels.label_v[i];
if (p.true_labels.count(pred_label)) { tp_at += 1; }
p.p_at[i] += tp_at / (i + 1);
if (p.true_labels.size() > 0) { p.r_at[i] += tp_at / p.true_labels.size(); }
float tp_at = 0;
for (size_t i = 0; i < p.top_k; ++i)
{
uint32_t pred_label = pred.multilabels.label_v[i];
if (p.true_labels.count(pred_label)) { tp_at += 1; }
p.p_at[i] += tp_at / (i + 1);
if (p.true_labels.size() > 0) { p.r_at[i] += tp_at / p.true_labels.size(); }
}
}
}

++p.ec_count;
p.node_queue.clear();

ec.loss = 0;
ec.pred = std::move(pred);
ec.l.multilabels = std::move(multilabels);
}
Expand Down