Skip to content

Commit

Permalink
[feat] Allow mixing numeric types on a single Sequence (#1913)
Browse files Browse the repository at this point in the history
  • Loading branch information
alberttorosyan authored Jun 27, 2022
1 parent c7088f1 commit b6a6d0f
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 9 deletions.
2 changes: 1 addition & 1 deletion aim/sdk/sequences/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Metric(Sequence):
@classmethod
def allowed_dtypes(cls) -> Union[str, Tuple[str, ...]]:
# TODO remove 'float64': temporary fix for repos generated with aim < 3.0.7
return 'float', 'float64', 'int'
return 'float', 'float64', 'int', 'number'

@classmethod
def sequence_name(cls) -> str:
Expand Down
3 changes: 2 additions & 1 deletion aim/sdk/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def _update_sequence_info(self, ctx_id: int, name: str, val, step: int):
dtype = get_object_typename(val)

if seq_info.dtype is not None:
def update_trace_dtype(new_dtype):
def update_trace_dtype(old_dtype: str, new_dtype: str):
logger.warning(f'Updating sequence \'{name}\' data type from {old_dtype} to f{new_dtype}.')
self.meta_tree['traces_types', new_dtype, ctx_id, name] = 1
self.meta_run_tree['traces', ctx_id, name, 'dtype'] = new_dtype
seq_info.dtype = new_dtype
Expand Down
15 changes: 12 additions & 3 deletions aim/sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
import uuid
from contextlib import contextmanager
from typing import Union, Any, Tuple
from typing import Union, Any, Tuple, Optional, Callable

from aim.sdk.configs import get_aim_repo_name

Expand Down Expand Up @@ -78,12 +78,21 @@ def get_object_typename(obj) -> str:
any_list_regex = re.compile(r'list\([A-Za-z]{1}[A-Za-z0-9.]*\)')


def check_types_compatibility(dtype: str, base_dtype: str, update_base_dtype_fn=None) -> bool:
def check_types_compatibility(
dtype: str,
base_dtype: str,
update_base_dtype_fn: Optional[Callable[[str, str], None]] = None) -> bool:
if dtype == base_dtype:
return True
if base_dtype == 'number' and dtype in {'int', 'float'}:
return True
if {dtype, base_dtype} == {'int', 'float'}:
if update_base_dtype_fn is not None:
update_base_dtype_fn(base_dtype, 'number')
return True
if base_dtype == 'list' and any_list_regex.match(dtype):
if update_base_dtype_fn is not None:
update_base_dtype_fn(dtype)
update_base_dtype_fn(base_dtype, dtype)
return True
if dtype == 'list' and any_list_regex.match(base_dtype):
return True
Expand Down
22 changes: 18 additions & 4 deletions tests/sdk/test_run_track_type_checking.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from tests.base import TestBase

from aim.sdk.run import Run
from aim.sdk.sequences.metric import Metric
from aim.storage.context import Context


Expand All @@ -19,9 +20,9 @@ def test_incompatible_type_during_tracking(self):
run = Run(system_tracking_interval=None)
run.track(1., name='numbers', context={})
with self.assertRaises(ValueError) as cm:
run.track(1, name='numbers', context={})
run.track([1], name='numbers', context={})
exception = cm.exception
self.assertEqual('Cannot log value \'1\' on sequence \'numbers\'. Incompatible data types.', exception.args[0])
self.assertEqual('Cannot log value \'[1]\' on sequence \'numbers\'. Incompatible data types.', exception.args[0])

def test_incompatible_type_after_tracking_restart(self):
run = Run(system_tracking_interval=None)
Expand All @@ -32,9 +33,9 @@ def test_incompatible_type_after_tracking_restart(self):

new_run = Run(run_hash=run_hash, system_tracking_interval=None)
with self.assertRaises(ValueError) as cm:
new_run.track(1, name='numbers', context={})
new_run.track([1], name='numbers', context={})
exception = cm.exception
self.assertEqual('Cannot log value \'1\' on sequence \'numbers\'. Incompatible data types.', exception.args[0])
self.assertEqual('Cannot log value \'[1]\' on sequence \'numbers\'. Incompatible data types.', exception.args[0])

def test_type_compatibility_for_empty_list(self):
run = Run(system_tracking_interval=None)
Expand Down Expand Up @@ -67,3 +68,16 @@ def test_type_compatibility_for_empty_list(self):
self.assertEqual(
f'Cannot log value \'{[5]}\' on sequence \'{seq_name}\'. Incompatible data types.',
exception.args[0])

def test_int_float_compatibility(self):
run = Run(system_tracking_interval=None)

# float first
run.track(1., name='float numbers', context={})
run.track(1, name='float numbers', context={})
run.track(1., name='float numbers', context={})

# int first
run.track(1, name='int numbers', context={})
run.track(1., name='int numbers', context={})
run.track(1, name='int numbers', context={})

0 comments on commit b6a6d0f

Please sign in to comment.