Skip to content

Commit

Permalink
Prevent circular import
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMeissnerDS committed Jan 14, 2025
1 parent c7f3f8f commit d4a41f1
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 49 deletions.
49 changes: 1 addition & 48 deletions bluecast/config/base_classes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import inspect
from abc import ABC, abstractmethod
from functools import wraps
from typing import Any, Dict, Literal, Union, get_args, get_origin, get_type_hints
from typing import Any, Dict, Literal, Union

import pandas as pd

Expand Down Expand Up @@ -36,48 +34,3 @@ def retrieve_results_as_df(self) -> pd.DataFrame:
Retrieve results from the ExperimentTracker class
"""
pass


def check_types_init(init_method):
@wraps(init_method)
def wrapper(self, *args, **kwargs):
sig = inspect.signature(init_method)
type_hints = get_type_hints(init_method)

bound_arguments = sig.bind(self, *args, **kwargs)
bound_arguments.apply_defaults()

for name, value in bound_arguments.arguments.items():
if name == "self":
continue

expected_type = type_hints.get(name)
if expected_type is None:
continue

# A small helper function to handle Union/Optional:
if not _matches_type(value, expected_type):
raise TypeError(
f"Argument '{name}' must be of type '{expected_type}', "
f"but got value '{value}' of type '{type(value)}'."
)

return init_method(self, *args, **kwargs)

return wrapper


def _matches_type(value, expected_type) -> bool:
"""Return True if 'value' matches the 'expected_type' annotation."""
origin = get_origin(expected_type)
args = get_args(expected_type)

if origin is None:
# expected_type is a regular (non-parameterized) type like int or float
return isinstance(value, expected_type)
elif origin is Union:
# e.g. Union[str, int]
return any(_matches_type(value, t) for t in args)
else:
# fallback to a direct isinstance check
return isinstance(value, expected_type)
48 changes: 48 additions & 0 deletions bluecast/config/config_validations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import inspect
from functools import wraps
from typing import Union, get_args, get_origin, get_type_hints


def check_types_init(init_method):
@wraps(init_method)
def wrapper(self, *args, **kwargs):
sig = inspect.signature(init_method)
type_hints = get_type_hints(init_method)

bound_arguments = sig.bind(self, *args, **kwargs)
bound_arguments.apply_defaults()

for name, value in bound_arguments.arguments.items():
if name == "self":
continue

expected_type = type_hints.get(name)
if expected_type is None:
continue

# A small helper function to handle Union/Optional:
if not _matches_type(value, expected_type):
raise TypeError(
f"Argument '{name}' must be of type '{expected_type}', "
f"but got value '{value}' of type '{type(value)}'."
)

return init_method(self, *args, **kwargs)

return wrapper


def _matches_type(value, expected_type) -> bool:
"""Return True if 'value' matches the 'expected_type' annotation."""
origin = get_origin(expected_type)
args = get_args(expected_type)

if origin is None:
# expected_type is a regular (non-parameterized) type like int or float
return isinstance(value, expected_type)
elif origin is Union:
# e.g. Union[str, int]
return any(_matches_type(value, t) for t in args)
else:
# fallback to a direct isinstance check
return isinstance(value, expected_type)
2 changes: 1 addition & 1 deletion bluecast/config/training_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from typing import Dict, List, Literal, Optional, Tuple

from bluecast.config.base_classes import check_types_init
from bluecast.config.config_validations import check_types_init


class TrainingConfig:
Expand Down

0 comments on commit d4a41f1

Please sign in to comment.