Skip to content

Commit

Permalink
KOL-3467 - Thresholded (#249)
Browse files Browse the repository at this point in the history
* Added ThresholdedMetrics class to support thresholds, updated validators and evaluator to support this new class.
  • Loading branch information
diegokolena authored Nov 6, 2023
1 parent ca407d6 commit c85ef45
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 2 deletions.
2 changes: 2 additions & 0 deletions kolena/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@
from .test_run import TestRun
from .test_run import test
from .define_workflow import define_workflow
from .thresholded import ThresholdedMetrics

__all__ = [
"DataObject",
"ThresholdedMetrics",
"Metadata",
"Image",
"ImagePair",
Expand Down
4 changes: 3 additions & 1 deletion kolena/workflow/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
from kolena.workflow._datatypes import DataObject
from kolena.workflow.annotation import _ANNOTATION_TYPES
from kolena.workflow.asset import _ASSET_TYPES
from kolena.workflow.thresholded import ThresholdedMetrics

_SUPPORTED_FIELD_TYPES = [*_SCALAR_TYPES, *_ANNOTATION_TYPES, *_ASSET_TYPES]

_SUPPORTED_FIELD_TYPES = [*_SCALAR_TYPES, *_ANNOTATION_TYPES, *_ASSET_TYPES, ThresholdedMetrics]


def assert_workflows_match(workflow_expected: str, workflow_provided: str) -> None:
Expand Down
6 changes: 5 additions & 1 deletion kolena/workflow/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class TestCaseMetrics(MetricsTestCase):
macro_F1: float
mAP: float
PerClass: List[PerClassMetrics]
ThresholdClass: List[ThresholdDataObjectClass]
```
Any `str`-type fields (e.g. `Class` in the above example) will be used as identifiers when displaying nested metrics
Expand Down Expand Up @@ -268,7 +269,10 @@ def _validate_metrics_test_sample_type(metrics_test_sample_type: Type[MetricsTes


def _validate_metrics_test_case_type(metrics_test_case_type: Type[DataObject]) -> None:
validate_scalar_data_object_type(metrics_test_case_type, supported_list_types=[MetricsTestCase])
validate_scalar_data_object_type(
metrics_test_case_type,
supported_list_types=[MetricsTestCase],
)

# validate that there is only one level of nesting
for field_name, field_type in get_data_object_field_types(metrics_test_case_type).items():
Expand Down
91 changes: 91 additions & 0 deletions kolena/workflow/thresholded.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2021-2023 Kolena Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABCMeta
from dataclasses import dataclass
from dataclasses import fields

from kolena.workflow._datatypes import TypedDataObject


class PreventThresholdOverrideMeta(ABCMeta, type):
def __new__(cls, name, bases, dct):
if "threshold" in dct.get("__annotations__", {}):
for base in bases:
if base.__name__ == "ThresholdedMetrics":
raise TypeError(f"Subclasses of {base.__name__} cannot override 'threshold'")
return super().__new__(cls, name, bases, dct)


@dataclass(frozen=True)
class ThresholdedMetrics(TypedDataObject, metaclass=PreventThresholdOverrideMeta):
"""
Represents metrics tied to a specific threshold.
`List[ThresholdedMetrics]` should be used as a field type within `MetricsTestSample` or
`MetricsTestCase` from the `kolena.workflow` module. This list is meant to hold metric values
associated with distinct thresholds. These metrics are expected to be uniform across `TestSample`
instances within a single test execution.
`ThresholdedMetrics` prohibits the use of dictionary objects as field values and guarantees that
the threshold values remain immutable once set. For application within a particular workflow,
subclassing is required to define relevant metrics fields.
Usage example:
```python
from kolena.workflow import MetricsTestSample
from kolena.workflow import ThresholdedMetrics
@dataclass(frozen=True)
class ClassThresholdedMetrics(ThresholdedMetrics):
precision: float
recall: float
f1: float
@dataclass(frozen=True)
class TestSampleMetrics(MetricsTestSample):
car: List[ClassThresholdedMetrics]
pedestrian: List[ClassThresholdedMetrics]
# Creating an instance of metrics
metric = TestSampleMetrics(
car=[
ClassThresholdedMetrics(threshold=0.3, precision=0.5, recall=0.8, f1=0.615),
ClassThresholdedMetrics(threshold=0.4, precision=0.6, recall=0.6, f1=0.6),
ClassThresholdedMetrics(threshold=0.5, precision=0.8, recall=0.4, f1=0.533),
# ...
],
pedestrian=[
ClassThresholdedMetrics(threshold=0.3, precision=0.6, recall=0.9, f1=0.72),
ClassThresholdedMetrics(threshold=0.4, precision=0.7, recall=0.7, f1=0.7),
ClassThresholdedMetrics(threshold=0.5, precision=0.8, recall=0.6, f1=0.686),
# ...
],
)
```
Raises:
TypeError: If any of the field values is a dictionary.
"""

threshold: float

def _data_type() -> str:
return "METRICS/THRESHOLDED"

def __post_init__(self) -> None:
for field in fields(self):
field_value = getattr(self, field.name)
if isinstance(field_value, dict):
raise TypeError(f"Field '{field.name}' should not be a dictionary")
35 changes: 35 additions & 0 deletions tests/unit/workflow/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from kolena.workflow.evaluator import MetricsTestCase
from kolena.workflow.evaluator import MetricsTestSample
from kolena.workflow.evaluator import MetricsTestSuite
from kolena.workflow.thresholded import ThresholdedMetrics


def test__validate__metrics_test_sample() -> None:
Expand Down Expand Up @@ -242,3 +243,37 @@ class NestedNested(MetricsTestCase):
@dataclasses.dataclass(frozen=True)
class Nested2DTester(MetricsTestCase):
a: List[NestedNested] # only one layer of nesting allowed


def test__validate__metrics_test_case__fail_overwrite_field() -> None:
with pytest.raises(TypeError):

@dataclasses.dataclass(frozen=True)
class MyThresholdedMetrics(ThresholdedMetrics):
threshold: str # overwrite type


def test__validate__metrics_test_sample__with_thresholded_metric_field() -> None:
@dataclasses.dataclass(frozen=True)
class MyThresholded(ThresholdedMetrics):
score: float

@dataclasses.dataclass(frozen=True)
class MyMetrics(MetricsTestSample):
thresholded_scores: List[MyThresholded]

sample = MyMetrics(
thresholded_scores=[
MyThresholded(threshold=0.1, score=0.5),
MyThresholded(threshold=0.5, score=0.6),
MyThresholded(threshold=0.9, score=0.7),
],
)

assert len(sample.thresholded_scores) == 3
assert sample.thresholded_scores[0].threshold == 0.1
assert sample.thresholded_scores[0].score == 0.5
assert sample.thresholded_scores[1].threshold == 0.5
assert sample.thresholded_scores[1].score == 0.6
assert sample.thresholded_scores[2].threshold == 0.9
assert sample.thresholded_scores[2].score == 0.7
80 changes: 80 additions & 0 deletions tests/unit/workflow/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from kolena.workflow.annotation import Polygon
from kolena.workflow.annotation import Polyline
from kolena.workflow.asset import ImageAsset
from kolena.workflow.thresholded import ThresholdedMetrics


def test__validate_data_object__invalid() -> None:
Expand Down Expand Up @@ -171,3 +172,82 @@ class NestedTester(DataObject):

with pytest.raises(ValueError):
validate_data_object_type(NestedTester)


def test__validate_field__thresholded() -> None:
@dataclasses.dataclass(frozen=True)
class MyThresholdedMetrics(ThresholdedMetrics):
a: List[str]
b: List[bool]
c: List[int]
d: List[float]
e: List[BoundingBox]
g: List[Polygon]
i: List[Keypoints]
j: List[Polyline]
k: List[Union[BoundingBox, BoundingBox]] # Redundant Union can be simplified
l: float
m: int
n: str

MyThresholdedMetrics(
threshold=1.0,
a=["1"],
b=[True, False],
c=[1],
d=[1.0],
e=[BoundingBox((1, 1), (2, 2))],
g=[Polygon(points=[(0, 0), (1, 1), (2, 2), (0, 0)])],
i=[Keypoints(points=[(10, 10), (11, 11), (12, 12)])],
j=[Polyline(points=[(0, 0), (1, 1), (2, 2)])],
k=[
BoundingBox((1, 1), (2, 2)),
BoundingBox(top_left=[1, 1], bottom_right=[10, 10]),
],
l=1.0,
m=1,
n="str",
)


def test__validate_field__thresholded__no_initialize_threshold_invalid() -> None:
@dataclasses.dataclass(frozen=True)
class MyThresholdedMetrics(ThresholdedMetrics):
a: List[str]

with pytest.raises(TypeError):
MyThresholdedMetrics(a=["1"])


def test__validate_field__thresholded__avoid_reserved_field_name() -> None:
with pytest.raises(TypeError):

@dataclasses.dataclass(frozen=True)
class MyThresholdedMetrics(ThresholdedMetrics):
threshold: str
a: float

MyThresholdedMetrics(threshold="1", a=1.0)


def test__validate_field__thresholded__invalid_dict_field() -> None:
@dataclasses.dataclass(frozen=True)
class MyThresholdedMetrics(ThresholdedMetrics):
a: Dict[str, str]

with pytest.raises(TypeError):
MyThresholdedMetrics(threshold=1, a={"key": "value"})


def test__validate_field__thresholded__invalid_nested_field() -> None:
@dataclasses.dataclass(frozen=True)
class Nested(DataObject):
a: float

@dataclasses.dataclass(frozen=True)
class MyThresholdedMetrics(ThresholdedMetrics):
a: Nested

with pytest.raises(TypeError):
n = Nested(a=1.0)
MyThresholdedMetrics(a=n)

0 comments on commit c85ef45

Please sign in to comment.