[ENH] Improve test framework for v1 metrics #1907
Conversation
For object_instance or tag-based retrieval to work, the class needs to have a get_test_params method, and possibly inherit from skbase BaseObject. It may be that this is not going to be compatible with the LightningMetric - so I think there are a few options here:
- is it possible to make the skbase retrieval work without adding skbase BaseObject to the pytorch-forecasting Metric? To do this, you can say object_type_filter = Metric (the class).
- another option would be not using the skbase framework, and instead using all_objects from the _registry, then combining this with pytest.mark.parametrize etc. (sketched below).
- finally, a third option is to introduce the pkg architecture for metrics as well - similar to the models, where a scikit-base class is the index for the metrics.
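A minimal sketch of the second option, assuming the _registry exposes all_objects with some filter for metrics - the import path and keyword arguments below are assumptions for illustration, not the actual API:

```python
# Hypothetical sketch: collect metric classes via the registry lookup and
# parametrize the common tests over them, instead of skbase-based retrieval.
import pytest

from pytorch_forecasting._registry import all_objects  # import path assumed

# filter keyword and return format are assumptions for illustration
ALL_METRIC_CLASSES = all_objects(object_types="metric", return_names=False)


@pytest.mark.parametrize("metric_cls", ALL_METRIC_CLASSES)
def test_metric_can_be_instantiated(metric_cls):
    """Every registered metric should instantiate with default arguments."""
    metric = metric_cls()
    assert metric is not None
```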
Codecov Report
❌ Patch coverage is
Additional details and impacted files:
@@ Coverage Diff @@
## main #1907 +/- ##
=======================================
Coverage ? 87.06%
=======================================
Files ? 136
Lines ? 8618
Branches ? 0
=======================================
Hits ? 7503
Misses ? 1115
Partials ? 0
@PranavBhatP, do you have an opinion on the options? Perhaps others that I do not see?
Currently I have hardcoded the process of obtaining the metrics' classes, as I want to get all the tests working first. I will mostly be taking a route similar to the one below; yet to explore more efficient options.
ok, that is probably the smartest route - use
@fkiraly @phoeenniixx They cover all the basic functionalities for metrics in v1. Currently, the way I'm handling the tests is to loop over all the existing data input formats in the package with all metrics, without going into the specifics of which loss function accepts which input format. With this design,
METRIC_COMPATIBILITY = {
# Point metrics
"MAE": {
"point": True,
"quantile": False,
"packed_sequence_2d": True,
"packed_sequence_3d": False,
"weighted_2d": True,
"weighted_3d": False,
"distribution": False,
"classification": False,
},
...
}
Note: I'm not sure if this current design and the suggested improvements are good practice. I'm open to critical comments and improvements, since I'm new to testing metrics and might be missing something crucial. With the vast number of metric-specific cases present, I think this design does a decent job of ensuring that at least the basic contracts are held in place, but again, the trade-off is the large number of tests skipped, simply because they pair incompatible data formats and metrics. There is still work left to implement metric-specific tests (for losses like
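For illustration, a hedged sketch of how such a matrix could gate the parametrized tests; the fixture names and the make_inputs helper below are placeholders, not the actual implementation:

```python
# Hypothetical sketch: skip (metric, data format) pairings that the
# hardcoded compatibility matrix above marks as unsupported.
import pytest

DATA_FORMATS = ["point", "quantile", "packed_sequence_2d", "distribution"]


@pytest.mark.parametrize("data_format", DATA_FORMATS)
def test_metric_accepts_format(metric_name, metric_instance, data_format):
    compat = METRIC_COMPATIBILITY.get(metric_name, {})
    if not compat.get(data_format, False):
        pytest.skip(f"{metric_name} does not support {data_format} inputs")
    y_pred, y_true = make_inputs(data_format)  # placeholder data helper
    loss = metric_instance.loss(y_pred, y_true)
    assert loss is not None
```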
Very nice! I will comment on the compatibility matrix first: based on our discussion, I was of the impression that there were only 3 classes of compatibilities: point prediction losses, quantile losses, distribution losses. Is this no longer a working assumption? Because if it is still valid, the compatibility matrix could be dealt with by:
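If the three-class assumption holds, a hedged sketch of how the matrix could collapse into a single tag per metric - the tag name, values, and pkg class here are illustrative, not an agreed API:

```python
from skbase.base import BaseObject


class MAEPkg(BaseObject):
    """Illustrative pkg container carrying a single loss-family tag."""

    _tags = {"metric_type": "point"}  # one of "point", "quantile", "distribution"


def expected_pred_ndim(pkg_cls):
    """Map the loss family to the expected number of y_pred dimensions."""
    metric_type = pkg_cls.get_class_tag("metric_type", "point")
    return {"point": 2, "quantile": 3, "distribution": 3}[metric_type]
```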
Very nice!
I left some comments above, two recurring ones:
- we should avoid try/except inside tests to catch exceptions, since we want to raise the exact exception if something fails.
- if possible right now, we should also avoid stepouts for individual cases. Though I realize it might be more prudent to reduce the amount of stepouts step by step.
- I think we should also add some negative cases, things that we know should fail - for example, to test error handling. The pytest pattern here is with pytest.raises (see the sketch below).
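A minimal sketch of such a negative case, assuming mismatched batch sizes are rejected; the chosen metric, shapes, and exception type are assumptions for illustration:

```python
# Hypothetical negative test: feeding y_pred and y_true with mismatched batch
# sizes should raise (torch raises RuntimeError when shapes cannot broadcast).
import pytest
import torch

from pytorch_forecasting.metrics import MAE


def test_mae_rejects_mismatched_batch_sizes():
    metric = MAE()
    y_pred = torch.randn(4, 10)
    y_true = torch.randn(3, 10)  # different batch size on purpose
    with pytest.raises(RuntimeError):
        metric.loss(y_pred, y_true)
```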
I have made changes:
Great, thanks for the rework!
I am wondering, though, if there is a misunderstanding about "splitting the tests".
- when I say "splitting", I am specifically referring to splitting one method starting with test_sth into multiple methods starting with test, all of them self-contained - not making many private methods and having them called from a single test, in sequence (see the sketch below).
- in this vein, I would not consider the current large all-in-one test to be "split" at the current state of the PR.
- Also, I thought there would be 7 or 8 such methods in total, not 4.
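To illustrate the intended shape of the split, a hedged sketch with placeholder fixture and test names (not the actual methods in this PR):

```python
import torch

# Hypothetical sketch: each contract is its own self-contained test that can
# run and fail independently; object_instance and y_pred_y_true are
# placeholder fixtures.


def test_metric_instantiation(object_instance):
    assert object_instance is not None


def test_metric_loss_is_finite(object_instance, y_pred_y_true):
    y_pred, y_true = y_pred_y_true
    loss = object_instance.loss(y_pred, y_true)
    assert bool(torch.isfinite(loss).all())


def test_metric_to_prediction_shape(object_instance, y_pred_y_true):
    y_pred, _ = y_pred_y_true
    point = object_instance.to_prediction(y_pred)
    assert point.shape[:2] == y_pred.shape[:2]
```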
Okay, I get what you are saying. I think there was a misunderstanding on my end; I will implement each test separately if required. I just assumed calling them in sequence would make sense, since we would re-use the same metric instance and the same data across the set of "private" methods.
There are 6 steps inside
yes, exactly! And those, I thought, should simply be separate tests that can run (and fail) independently.
@fkiraly I think the PR looks neat now. I have found all non-conformances (listed in
Great!
I have one change request, related to the data/scenario dispatch via _setup_metric_test_scenario: we should make this more idiomatic using tags, and leaner:
- we should remove the model training in the fixture generation altogether. I would just create the dict of tensors from scratch. Otherwise we are basically doing an integration test hidden in the fixture, where we run the entire workflow of model training and prediction.
- we should replace requires_data_type with a simple get_tag call.
- I would move _setup_metric_test_scenario into one or two _generate methods that produce y_pred and y_true (see the sketch below).
- I think there is a place for integration testing - we could pick a model that works with all losses, and only vary the metric, in a separate test? Though I would not want to introduce feature or scope creep into this PR.
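A hedged sketch of the leaner dispatch - the tag name, tensor shapes, and method names below are illustrative assumptions, not the final implementation:

```python
import torch

# Hypothetical sketch: build y_pred / y_true from scratch, with the shape of
# y_pred driven by a tag on the metric pkg instead of the v1 data loader.
# Tag name and shapes are assumptions for illustration.
BATCH_SIZE, PRED_LEN, N_QUANTILES = 4, 10, 7


def _generate_y_true():
    return torch.randn(BATCH_SIZE, PRED_LEN)


def _generate_y_pred(metric_pkg):
    metric_type = metric_pkg.get_class_tag("metric_type", "point")
    if metric_type == "quantile":
        return torch.randn(BATCH_SIZE, PRED_LEN, N_QUANTILES)
    if metric_type == "distribution":
        # e.g. distribution parameters stacked along the last dimension
        return torch.randn(BATCH_SIZE, PRED_LEN, 2)
    return torch.randn(BATCH_SIZE, PRED_LEN)
```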
There must have been a misunderstanding. There is no model being trained in fixture generation; rather, we use an instance of the
Yes, my mistake. May I refine my comment then: I do not like that we are using the v1 data loader. In particular, since we may not change metrics for v2, but the data loading will change substantially, the current testing strategy may lead to substantial changes to the test framework and possibly unknown issues with the metrics. If instead we can create dicts of tensors from scratch, this would be avoided.
Further, the "unexpected errors" you describe feel suspect. This should not be the case if we have understood the metrics API properly, so it now feels extra risky to consider this complete. What exactly fails if you try to generate the same outputs from scratch?
Yes, what you say makes sense, but this PR's title also specifies testing in the sense of v1. It might be a pain to switch to the v2 version of the dataloader to test the metrics (even though they remain the same in v2). I will try to see if it's possible to do this with a dict and keep it independent of both the v1 and v2 datasets and dataloaders.
These errors are related to how the data is handled before it is passed to the metric. Normally, we need to preprocess all the data using the
Agreed on the first part, but not sure if I agree with the second:
If so, then it is an error of the models API and not a problem with the metric. We do not need to replicate all behaviour of the
I've greatly simplified the data generation in the fixtures using dictionaries; the change is a simple refactor.
Should I do this in a separate PR?
yes, I would say, separate PR.
Reference Issues/PRs
Fixes #1904.
What does this implement/fix? Explain your changes.
Implements a skbase fixture generator class for common tests for all metrics in v1. Common testing is successful on all metrics except one - MASE (point). These special metrics are addressed in their own tests under test_metrics.py.
This PR also makes changes to the source location of SkbaseBaseObject (_BaseObject), which is now shared by test_all_estimators and test_all_metrics.
This PR has 4 aspects:
- BasePtMetric: an _SkbaseBaseObject class that allows the metric package containers to be discovered for testing. It implements the class methods for retrieving the name, the class of the metric, and the set of test_params for testing a metric. It also implements the method which retrieves the exact name of the data generation pytest fixture for each metric (see the sketch below).
- Individual modules for the package containers of each metric. These derive from BasePtMetric and contain overridden methods from the base class.
- TestAllPtMetrics is the main testing class of the library. It implements the integration test for metrics and the test for reduction of the computed loss.
- MetricFixtureGenerator is a child class of BaseFixtureGenerator for generating fixtures for the tests. It implements a separate case of _generate_object_instance to allow parametrization of the test_params passed into a metric during its initialisation.
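For orientation, a hedged sketch of what one such package container might look like; the class name, import path, tag, and method names are assumptions for illustration, not the exact code in this PR:

```python
# Hypothetical sketch of a metric pkg container: an skbase-style index object
# exposing the metric class, its test parameters, and a loss-family tag.
from pytorch_forecasting.tests._base import BasePtMetric  # import path assumed


class MAE_pkg(BasePtMetric):
    """Package container for the MAE point metric."""

    _tags = {"metric_type": "point"}  # illustrative tag

    @classmethod
    def get_cls(cls):
        """Return the metric class indexed by this container."""
        from pytorch_forecasting.metrics import MAE

        return MAE

    @classmethod
    def get_test_params(cls):
        """Parameter sets used to instantiate the metric in common tests."""
        return [{}]  # further parameter sets could be added here
```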
Exceptions and skipped tests due to non-conformance:
- to_prediction requires us to pass its own prediction_length and not the max_prediction_length from the model.
- to_quantiles is non-conformant with the expectation for a point metric: it returns a 3d quantile dimension (equal to len(quantiles)) when a 2d tensor is passed into a call of to_quantiles with PoissonLoss. For point metrics this is usually (batch_size, prediction_length, 1).
- The loss() method API, where loss() returns a scalar value, unlike all the remaining metrics, which return either a 2d tensor or a 3d tensor (distribution - a 2d tensor with log prob calculated for batch and timesteps).

PR checklist
pre-commit install. To run hooks independent of commit, execute pre-commit run --all-files.