[ENH] Improve test framework for v1 metrics #1907

Open · wants to merge 74 commits into base: main

Conversation

@PranavBhatP (Contributor) commented Jul 1, 2025

Reference Issues/PRs

Fixes #1904.

What does this implement/fix? Explain your changes.

Implements a skbase fixture generator class for common tests for all metrics in v1. Common testing is successful on all metrics except one: MASE (point metric). Such special metrics are addressed in their own tests under test_metrics.py.

This PR also changes the source location of _SkbaseBaseObject (_BaseObject), which is now shared by test_all_estimators and test_all_metrics.

This PR has 4 aspects:

  • BasePtMetric: an _SkbaseBaseObject class that allows the metric package containers to be discovered for testing. It implements the class methods for retrieving the name, the class of the metric, and the set of test_params for testing a metric. It also implements the method which retrieves the exact name of the data-generation pytest fixture for each metric.

  • Individual modules for the package containers of each metric. Each derives from BasePtMetric and overrides methods from the base class (a rough sketch of this design follows after this list).

  • TestAllPtMetrics is the main testing class for metrics in the library. It implements the integration tests for metrics and the test for reduction of the computed loss.

  • MetricFixtureGenerator is a child class of BaseFixtureGenerator for generating fixtures for the tests. It implements a separate case of _generate_object_instance to allow parametrization of the test_params passed into a metric during its initialisation.
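
For illustration, a minimal sketch of what one such package container might look like is below. This is an assumption-based sketch, not code from the PR: the method names name, get_metric_class and get_test_params, and the _tags entry, are illustrative.

# Hypothetical sketch of a metric package container; names and tags are
# illustrative assumptions, not the PR's actual implementation.
from pytorch_forecasting.metrics import MAE


class MAE_pkg:  # would derive from BasePtMetric in the PR's design
    """Package container that indexes the MAE metric for the test framework."""

    _tags = {"metric_type": "point"}  # could drive fixture selection

    @classmethod
    def name(cls):
        # exact name used to look up the data-generation pytest fixture
        return "MAE"

    @classmethod
    def get_metric_class(cls):
        # the metric class under test
        return MAE

    @classmethod
    def get_test_params(cls):
        # list of kwargs dicts; each dict parametrizes one metric instance
        return [{}, {"reduction": "mean"}]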

Exceptions and skipped tests due to non-conformance:

  1. MQF2DistributionLoss - skipped due to non-conformance with the design of to_prediction; it requires us to pass its own prediction_length rather than the max_prediction_length from the model.
  2. PoissonLoss - its implementation of to_quantiles is non-conformant with the expectation for point metrics. It returns a 3d tensor with a quantile dimension equal to len(quantiles) when a 2d tensor is passed into a call of to_quantiles with PoissonLoss; for point metrics this is usually (batch_size, prediction_length, 1).
  3. MultivariateNormalDistributionLoss - this metric is compliant with most of the existing contracts of the API. It is non-compliant only with the loss() method API, where loss() returns a scalar value, unlike all the remaining metrics, which return either a 2d or a 3d tensor (distribution losses return a 2d tensor with the log-prob calculated per batch and timestep).

PR checklist

  • The PR title starts with either [ENH], [MNT], [DOC], or [BUG]. [BUG] - bugfix, [MNT] - CI, test framework, [ENH] - adding or improving code, [DOC] - writing or improving documentation or docstrings.
  • Added/modified tests
  • Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with pre-commit install.
    To run hooks independent of commit, execute pre-commit run --all-files

@fkiraly (Collaborator) left a comment

for object_instance or tag-based retrieval to work, the class needs to have a get_test_params method, and possibly inherit from skbase BaseObject. It may be that this is not going to be compatible with the LightningMetric - so I think there are a few options here:

  • is it possible to make the skbase retrieval work without adding skbase BaseObject to the pytorch-forecasting Metric? To do this, you can say object_type_filter = Metric (the class).
  • another option would be not using the skbase framework, and instead using all_objects from the _registry, then combining this with pytest.mark.parametrize etc. (see the sketch after this list).
  • finally, a third option is to introduce the pkg architecture for metrics as well - similar to the models, where a scikit-base class is the index for the metrics.
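
A hedged sketch of the second option — registry retrieval combined with vanilla pytest — is below. It assumes an all_objects lookup in pytorch_forecasting._registry analogous to skbase/sktime's; the exact import path and arguments are assumptions, not confirmed by this PR.

# Hedged sketch: registry retrieval plus pytest.mark.parametrize.
# The import path and all_objects arguments are assumptions.
import pytest

from pytorch_forecasting._registry import all_objects

METRIC_CLASSES = all_objects(object_types="metric", return_names=False)


@pytest.mark.parametrize("metric_cls", METRIC_CLASSES)
def test_metric_instantiates(metric_cls):
    # every registered metric should construct with default arguments
    assert metric_cls() is not None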

codecov bot commented Jul 1, 2025

Codecov Report

❌ Patch coverage is 73.52941% with 54 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (main@3093b9f).

Files with missing lines Patch % Lines
...g/_log_normal/_log_normal_distribution_loss_pkg.py 52.38% 10 Missing ⚠️
...h_forecasting/metrics/base_metrics/_base_object.py 65.00% 7 Missing ⚠️
...implicit_quantile_network_distribution_loss_pkg.py 71.42% 4 Missing ⚠️
...ributions_pkg/_mqf2/_mqf2_distribution_loss_pkg.py 71.42% 4 Missing ⚠️
...ributions_pkg/_beta/_beta_distribution_loss_pkg.py 72.72% 3 Missing ⚠️
...nomial/_negative_binomial_distribution_loss_pkg.py 72.72% 3 Missing ⚠️
...asting/metrics/_quantile_pkg/_quantile_loss_pkg.py 70.00% 3 Missing ⚠️
...rmal/_multivariate_normal_distribution_loss_pkg.py 71.42% 2 Missing ⚠️
...tions_pkg/_normal/_normal_distribution_loss_pkg.py 71.42% 2 Missing ⚠️
...cs/_point_pkg/_cross_entropy/_cross_entropy_pkg.py 71.42% 2 Missing ⚠️
... and 7 more
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1907   +/-   ##
=======================================
  Coverage        ?   87.06%           
=======================================
  Files           ?      136           
  Lines           ?     8618           
  Branches        ?        0           
=======================================
  Hits            ?     7503           
  Misses          ?     1115           
  Partials        ?        0           
Flag Coverage Δ
cpu 87.06% <73.52%> (?)
pytest 87.06% <73.52%> (?)


@fkiraly (Collaborator) commented Jul 1, 2025

for object_instance or tag-based retrieval to work, the class needs to have a get_test_params method, and possibly inherit from skbase BaseObject. It may be that this is not going to be compatible with the LightningMetric - so I think there are a few options here:

@PranavBhatP, do you have an opinion on the options? Perhaps others that I do not see?

@PranavBhatP (Contributor, Author) commented Jul 1, 2025

@PranavBhatP, do you have an opinion on the options? Perhaps others that I do not see?

Currently I have hardcoded the process of obtaining the metrics' classes, as I want to get all the tests working first. I will mostly be taking a route similar to the one quoted below; I have yet to explore more efficient options.

another option would be not using the skbase framework, and instead using all_objects from the _registry, then combining this with pytest.mark.parametrize etc.

@fkiraly (Collaborator) commented Jul 1, 2025

OK, that is probably the smartest route - use all_objects for retrieval and vanilla pytest, and once it works, make the test framework more programmatic.

@PranavBhatP (Contributor, Author) commented Jul 1, 2025

@fkiraly @phoeenniixx
I've implemented a very rudimentary test framework for the metrics. It uses hard-coded imports (for now) and vanilla pytest. All tests are passing, and it includes a majority of the metrics (except 3-4).

They cover all the basic functionalities of metrics in v1. Currently, the way I'm handling the tests is to loop over all the existing data input formats in the package for each metric, without going into the specifics of which loss function accepts what kind of input format for y_pred. If there is an inherent incompatibility (for example, giving a 2d y_pred tensor to a quantile loss - this is not supposed to work anyway), we catch the error and simply skip the test with pytest.skip(). This is more convenient than an alternative design which would have to map every data format to the metrics. If the data format is accepted by the metric, we then proceed with the rest of the tests and assert the required conditions to ensure the metric is working fine.
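
For illustration, the loop-and-skip design described above might look roughly like the following; the metric list, data formats and shapes are illustrative assumptions, and the real tests cover many more formats.

# Hedged illustration of the loop-and-skip design: try each metric against
# each data format and skip incompatible pairs via pytest.skip.
import pytest
import torch

from pytorch_forecasting.metrics import MAE, QuantileLoss

DATA_FORMATS = {
    "point_2d": lambda: torch.randn(4, 10),
    "quantile_3d": lambda: torch.randn(4, 10, 7),
}


@pytest.mark.parametrize("metric_cls", [MAE, QuantileLoss])
@pytest.mark.parametrize("data_format", list(DATA_FORMATS))
def test_metric_basic_contract(metric_cls, data_format):
    metric = metric_cls()
    y_pred = DATA_FORMATS[data_format]()
    target = torch.randn(4, 10)
    try:
        metric.update(y_pred, target)
    except Exception:
        pytest.skip(f"{metric_cls.__name__} does not accept {data_format} input")
    assert torch.isfinite(metric.compute()).all()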

Next, with this design,

  1. I'm thinking of implementing a small counter variable within each test to check if a metric is compatible with at least one of the provided data formats. If this counter is zero after looping through all the data formats, we raise an error, since there is clearly something wrong with the data.
    (or)
  2. Using a compatibility matrix with hard-coded compatibilities between metrics and the provided data formats, and checking whether the metric fails for each case or not. This can drastically reduce the number of skipped tests, but it is very tedious to extend the matrix every time a new metric is added to the library. It would look something like this.
METRIC_COMPATIBILITY = {
    # Point metrics
    "MAE": {
        "point": True,
        "quantile": False,
        "packed_sequence_2d": True,
        "packed_sequence_3d": False,
        "weighted_2d": True,
        "weighted_3d": False,
        "distribution": False,
        "classification": False,
    },
    ...
}

Note: I'm not sure whether this current design and the suggested improvements are good practice. I'm open to critical comments and improvements, since I'm new to testing metrics and might be missing something crucial. Given the vast number of metric-specific cases present, I think this design does a decent job of ensuring at least the basic contracts are held in place, but again, the trade-off is the huge number of skipped tests, simply because tests pairing incompatible data formats and metrics are ignored.

There is still work left on implementing metric-specific tests (for losses like MASE, CrossEntropy and some distribution losses with very specific implementations) and other tests such as usage of MultiLoss. The currently ignored metrics are commented out in the get_all_metrics function, since they cannot be covered with the existing class.

@fkiraly (Collaborator) commented Jul 2, 2025

Very nice!

I will comment on the compatibility matrix first: based on our discussion, I was of the impression that there were only 3 classes of compatibilities: point prediction losses, quantile losses, distribution losses.

Is this no longer a working assumption? Because if it is still valid, the compatibility matrix could be dealt with by:

  • a single flag that has three values
  • a dict that maps single flag values to more granular compatibility dicts (see the sketch below)
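
For illustration, the flag-plus-lookup idea could look roughly like this; the flag name, values and format keys are illustrative assumptions rather than an agreed tag vocabulary.

# Hedged sketch: one three-valued metric-type flag per metric, expanded to a
# granular compatibility dict per type instead of per metric.
METRIC_TYPE = {
    "MAE": "point",
    "QuantileLoss": "quantile",
    "NormalDistributionLoss": "distribution",
}

COMPATIBILITY_BY_TYPE = {
    "point": {"point": True, "quantile": False, "distribution": False},
    "quantile": {"point": False, "quantile": True, "distribution": False},
    "distribution": {"point": False, "quantile": False, "distribution": True},
}


def supported_formats(metric_name):
    # look up which y_pred formats a metric should accept
    return COMPATIBILITY_BY_TYPE[METRIC_TYPE[metric_name]]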

@fkiraly (Collaborator) left a comment

Very nice!

I left some comments above, with a few recurring ones:

  • we should avoid try/except inside tests to catch exceptions, since we want to raise the exact exception if something fails.
  • if possible right now, we should also avoid stepouts for individual cases, though I realize it might be more prudent to reduce the number of stepouts step by step.
  • I think we should also add some negative cases, things that we know should fail - for example, to test error handling. The pytest pattern here is with pytest.raises (see the sketch below).
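
For illustration, a negative case using pytest.raises might look like the sketch below; the specific claim that QuantileLoss rejects a 1d target is an illustrative assumption, the pattern itself is the point.

# Hedged sketch of a negative test with pytest.raises.
import pytest
import torch

from pytorch_forecasting.metrics import QuantileLoss


def test_quantile_loss_rejects_bad_target():
    metric = QuantileLoss()
    y_pred = torch.randn(4, 10, len(metric.quantiles))
    bad_target = torch.randn(4)  # wrong rank on purpose
    # a real test would pin the exact exception type instead of Exception
    with pytest.raises(Exception):
        metric.loss(y_pred, bad_target)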

@PranavBhatP PranavBhatP requested a review from fkiraly August 2, 2025 09:38
@PranavBhatP (Contributor, Author) commented Aug 2, 2025

I have made the following changes:

  1. added a test exclusively for the loss method. This test actually uncovered a non-conformance in one of the metrics - MultivariateNormalDistributionLoss, but it is not a very major one. I have highlighted the cause of non-conformance in the description of the PR. Tests for the loss method are skipped on this metric.
  2. split the metric integration test into 4 parts as suggested in the change request.

@fkiraly (Collaborator) left a comment

Great, thanks for the rework!

I am wondering though if there is a misunderstanding about "splitting the tests".

  • when I say "splitting", I was specifically referring to splitting one method starting with test_sth into several methods starting with test, all being self-contained - not making many private methods and having them called from a single test, in sequence.
    • in this vein, I would not consider the current large all-in-one test to be "split" at the current state of the PR
  • Also, I thought there would be 7 or 8 such methods in total, not 4.

@PranavBhatP (Contributor, Author) commented Aug 2, 2025

when I say "splitting", I was specifically referring to splitting one method starting with test_sth into several methods starting with test, all being self-contained - not making many private methods and having them called from a single test, in sequence.
in this vein, I would not consider the current large all-in-one test to be "split" at the current state of the PR

Okay, I get what you are saying. I think there was a misunderstanding on my end; I will implement each test separately if required. I just assumed calling them in a sequence would make sense, since we would reuse the same metric instance and data for the set of "private" methods.

I thought there would be 7 or 8 such methods in total, not 4

There are 6 steps inside test_metric_functionality if you look at the sequence.

  1. check for valid metric_type
  2. check for metric reset/update/compute
  3. check for to_prediction
  4. check for to_quantiles
  5. check for composite metrics under test_composite_and_weighted_metrics
  6. check for weighted metrics under test_composite_and_weighted_metrics

@fkiraly (Collaborator) commented Aug 2, 2025

Yes, exactly! And those, I thought, should simply be separate tests that can run (and fail) independently.
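
For illustration, the suggested split could look roughly like the sketch below, with each numbered step as its own self-contained test; QuantileLoss and the shapes are illustrative, not the PR's actual fixtures.

# Hedged sketch: each step becomes a self-contained test that can fail
# independently.
import torch

from pytorch_forecasting.metrics import QuantileLoss

BATCH, HORIZON = 4, 10


def _metric_and_data():
    metric = QuantileLoss()
    y_pred = torch.randn(BATCH, HORIZON, len(metric.quantiles))
    target = torch.randn(BATCH, HORIZON)
    return metric, y_pred, target


def test_update_and_compute():
    metric, y_pred, target = _metric_and_data()
    metric.update(y_pred, target)
    assert torch.isfinite(metric.compute()).all()


def test_to_prediction():
    metric, y_pred, _ = _metric_and_data()
    assert metric.to_prediction(y_pred).shape == (BATCH, HORIZON)


def test_to_quantiles():
    metric, y_pred, _ = _metric_and_data()
    assert metric.to_quantiles(y_pred).shape == (BATCH, HORIZON, len(metric.quantiles))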

@PranavBhatP (Contributor, Author) commented Aug 3, 2025

@fkiraly I think the PR looks neat now. I have found all non-conformances (listed in _config.py) and the tests are individual now.

@PranavBhatP PranavBhatP requested a review from fkiraly August 3, 2025 15:31
@fkiraly (Collaborator) left a comment

Great!

I have one change request related to the data/scenario dispatch via _setup_metric_test_scenario: we should make this more idiomatic using tags, and leaner:

  • we should remove the model training in the fixture generation altogether. I would just create the dict of tensors from scratch. Otherwise we are basically doing an integration test hidden in the fixture, where we run the entire workflow of model training and prediction.
  • we should replace requires_data_type with a simple get_tag call.
  • I would move _setup_metric_test_scenario into one or two _generate methods that produce y_pred and y_true (see the sketch after this list).
  • I think there is a place for integration testing - we could pick a model that works with all losses, and only vary the metric, in a separate test? Though I would not want to introduce feature or scope creep into this PR.
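
For illustration, the tag-based dispatch inside a _generate method might look like the sketch below; the tag name "metric_type", its values, the get_metric_class accessor and the shapes are illustrative assumptions, not the PR's actual code.

# Hedged sketch of tag-driven data generation inside the fixture generator.
import torch


def _generate_y_pred(metric_pkg, batch_size=4, prediction_length=10):
    # read the metric-type tag from the package container (assumed tag name)
    metric_type = metric_pkg.get_tag("metric_type")
    if metric_type == "point":
        return torch.randn(batch_size, prediction_length)
    if metric_type == "quantile":
        n_quantiles = 7  # e.g. the default number of quantiles in QuantileLoss
        return torch.randn(batch_size, prediction_length, n_quantiles)
    # distribution losses expect one value per distribution parameter
    n_params = len(metric_pkg.get_metric_class().distribution_arguments)
    return torch.randn(batch_size, prediction_length, n_params)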

@PranavBhatP (Contributor, Author) commented Aug 3, 2025

we should remove the model training in the fixture generation altogether. I would just create the dict of tensors from scratch

There must have been a misunderstanding. There is no model being trained in fixture generation; rather, we use an instance of the TimeSeriesDataset (built from a pd.DataFrame) to extract x and y from the dataloader. This helps mimic how the data generation would actually work before we even pass data to the metrics. This data generation is independent of the metrics. I have tried creating a dict of tensors from scratch, but it was leading to unexpected errors.

@fkiraly (Collaborator) commented Aug 4, 2025

This helps mimic how the data generation would actually work before we even pass data to the metrics. This data generation is independent of the metrics. I have tried creating a dict of tensors from scratch, but it was leading to unexpected errors.

Yes, my mistake.

May I refine my comment then: I do not like that we are using the v1 data loader and TimeSeriesDataset here. It should be possible to produce valid inputs from scratch.

In particular, since we may not change metrics for v2, but the data loading will change substantially - the current testing strategy may lead to substantial changes to the test framework and possibly unknown issues with the metrics. If instead we can create dicts of tensors from scratch, this would be avoided.

Further, the "unexpected errors" you describe feel suspect. This should not be the case if we have understood the metrics API properly, so it now feels extra risky to consider this complete.

What exactly fails if you try to generate the same outputs from scratch?

@PranavBhatP (Contributor, Author) commented Aug 4, 2025

In particular, since we may not change metrics for v2, but the data loading will change substantially - the current testing strategy may lead to substantial changes to the test framework and possibly unknown issues with the metrics. If instead we can create dicts of tensors from scratch, this would be avoided.

Yes, what you say makes sense, but this PR's title also specifies testing in the sense of v1. It might be a pain to switch to the v2 version of the dataloader to test the metrics (even though they remain the same in v2). I will try to see if it's possible to do this with a dict and keep it independent of the v1 and v2 datasets and dataloaders.

Further, the "unexpected errors" you describe feel suspect. This should not be the case if we have understood the metrics API properly, so it now feels extra risky to consider this complete.

These errors are related to how the data is handled before it is passed to the metric. Normally, we need to preprocess all the data using the TimeSeriesDataset before passing it to the metric via the dataloader, so the errors are not related to the metric. If we are going to do this independently of the TimeSeriesDataset, we might have to replicate internals of the TimeSeriesDataset inside the fixtures. Is this a good idea? Otherwise, I don't think it has anything to do with the metrics.

@fkiraly (Collaborator) commented Aug 6, 2025

Agreed on the first part but not sure if I agree with the second:

These errors are related to how the data is handled before it is passed to the metric.

If so, then it is an error of the models API and not a problem with the metric. We do not need to replicate all behaviour of the TimeSeriesDataset exactly; we only need to cover the metrics contract with inputs/outputs etc.
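
For illustration, covering the metrics contract with tensors built from scratch (no TimeSeriesDataset or dataloader) could be as small as the sketch below; the shapes follow the point-metric convention and are assumptions, not code from the PR.

# Hedged sketch: exercise the metric contract with hand-built tensors.
import torch

from pytorch_forecasting.metrics import MAE

batch_size, prediction_length = 4, 10
y_pred = torch.randn(batch_size, prediction_length)  # point forecast
target = torch.randn(batch_size, prediction_length)  # ground truth

metric = MAE()
metric.update(y_pred, target)
value = metric.compute()

assert torch.isfinite(value).all()
assert metric.to_prediction(y_pred).shape == (batch_size, prediction_length)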

@PranavBhatP (Contributor, Author)

I've greatly simplified the data generation in the fixtures using dictionaries; the change is a simple refactor.

@PranavBhatP (Contributor, Author)

I think there is a place for integration testing - we could pick a model that works with all losses, and only vary the metric, in a separate test? Though I would not want to introduce feature or scope creep into this PR.

Should I do this in a separate PR?

@PranavBhatP PranavBhatP requested a review from fkiraly August 6, 2025 19:03
@fkiraly (Collaborator) commented Aug 6, 2025

Yes, I would say a separate PR.

Projects
Status: PR in progress
Development

Successfully merging this pull request may close these issues.

[ENH] improve test framework for metrics in v1.
3 participants