Skip to content

Commit

Permalink
Merge pull request #33 from AllenNeuralDynamics/dep-update-aind-behav…
Browse files Browse the repository at this point in the history
…ior-services

Update aind-behavior-services to 0.9
  • Loading branch information
bruno-f-cruz authored Jan 10, 2025
2 parents 974f29f + 1de040e commit 417268d
Show file tree
Hide file tree
Showing 6 changed files with 391 additions and 520 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"pydantic>=2.7, <3.0",
"gitpython",
"semver",
"aind_behavior_services>=0.8.0, <0.9.0",
"aind_behavior_services<=0.9",
"aind-slims-api<0.2"
]

Expand Down
42 changes: 2 additions & 40 deletions src/aind_behavior_experiment_launcher/data_mapper/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import xml.etree.ElementTree as ET
from importlib import metadata
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, get_args
from typing import Dict, List, TypeVar, Union

import pydantic
from aind_behavior_services import (
AindBehaviorRigModel,
)
from aind_behavior_services.rig import CameraController, CameraTypes
from aind_behavior_services.utils import get_fields_of_type

logger = logging.getLogger(__name__)

Expand All @@ -36,45 +37,6 @@ def get_cameras(


ISearchable = Union[pydantic.BaseModel, Dict, List]
_ISearchableTypeChecker = tuple(get_args(ISearchable)) # pre-compute for performance


def get_fields_of_type(
searchable: ISearchable,
target_type: Type[T],
*args,
recursive: bool = True,
stop_recursion_on_type: bool = True,
**kwargs,
) -> List[Tuple[Optional[str], T]]:
_iterable: Iterable
_is_type: bool
result: List[Tuple[Optional[str], T]] = []

if isinstance(searchable, dict):
_iterable = searchable.items()
elif isinstance(searchable, list):
_iterable = list(zip([None for _ in range(len(searchable))], searchable))
elif isinstance(searchable, pydantic.BaseModel):
_iterable = {k: getattr(searchable, k) for k in searchable.model_fields.keys()}.items()
else:
raise ValueError(f"Unsupported model type: {type(searchable)}")

for name, field in _iterable:
_is_type = False
if isinstance(field, target_type):
result.append((name, field))
_is_type = True
if recursive and isinstance(field, _ISearchableTypeChecker) and not (stop_recursion_on_type and _is_type):
result.extend(
get_fields_of_type(
field,
target_type,
recursive=recursive,
stop_recursion_on_type=stop_recursion_on_type,
)
)
return result


def _sanity_snapshot_keys(snapshot: Dict[str, str]) -> Dict[str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def transfer(self) -> None:
self.force_restart(kill_if_running=False)
except subprocess.CalledProcessError as e:
logger.error("Failed to start watchdog service. %s", e)
raise RuntimeError("Failed to start watchdog service.")
raise RuntimeError("Failed to start watchdog service.") from e
else:
if not self.is_running():
logger.error("Failed to start watchdog service.")
Expand Down Expand Up @@ -260,10 +260,10 @@ def _get_project_names(
) -> list[str]:
response = requests.get(end_point, timeout=timeout)
if response.ok:
content = json.loads(response.content)
return json.loads(response.content)["data"]
else:
response.raise_for_status()
return content["data"]
raise HTTPError(f"Failed to fetch project names from endpoint. {response.content}")

def is_running(self) -> bool:
output = subprocess.check_output(
Expand All @@ -290,7 +290,7 @@ def dump_manifest_config(self, path: Optional[os.PathLike] = None, make_dir: boo

path = (Path(path) if path else Path(watch_config.flag_dir) / f"manifest_{manifest_config.name}.yaml").resolve()
if "manifest" not in path.name:
logger.warning("Prefix " "manifest_" " not found in file name. Appending it.")
logger.warning("Prefix manifest_ not found in file name. Appending it.")
path = path.with_name(f"manifest_{path.name}.yaml")

if make_dir and not path.parent.exists():
Expand Down
2 changes: 1 addition & 1 deletion src/aind_behavior_experiment_launcher/ui_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def prompt_pick_file_from_list(
if zero_label is not None:
self._print(f"0: {zero_label}")
for i, file in enumerate(available_files):
self._print(f"{i+1}: {os.path.split(file)[1]}")
self._print(f"{i + 1}: {os.path.split(file)[1]}")
choice = int(input("Choice: "))
if choice < 0 or choice >= len(available_files) + 1:
raise ValueError
Expand Down
57 changes: 1 addition & 56 deletions tests/test_data_mapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import unittest

# from . import TESTS_ASSETS, REPO_ROOT
from pathlib import Path
from typing import Dict, List, Optional
from unittest.mock import patch
Expand All @@ -9,12 +7,10 @@

from aind_behavior_experiment_launcher.data_mapper.helpers import (
_sanity_snapshot_keys,
get_fields_of_type,
snapshot_bonsai_environment,
snapshot_python_environment,
)

from . import TESTS_ASSETS
from tests import TESTS_ASSETS


class MockModel(BaseModel):
Expand All @@ -26,57 +22,6 @@ class MockModel(BaseModel):
sub_model: Optional["MockModel"] = None


class TestGetFieldsOfType(unittest.TestCase):
def test_get_fields_of_type_dict(self):
data = {"field1": 1, "field2": "test", "field3": [1, 2, 3], "field4": {"key1": 1, "key2": 2}, "field5": None}
result = get_fields_of_type(data, int, recursive=False)
expected = [("field1", 1)]
self.assertEqual(result, expected)

result = get_fields_of_type(data, int, recursive=True)
expected = [("field1", 1), (None, 1), (None, 2), (None, 3), ("key1", 1), ("key2", 2)]
self.assertEqual(result, expected)

def test_get_fields_of_type_list(self):
data = [1, "test", [1, 2, 3], {"key1": 1, "key2": 2}, None]
result = get_fields_of_type(data, int, recursive=False)
expected = [(None, 1)]
self.assertEqual(result, expected)

result = get_fields_of_type(data, int, recursive=True)
expected = [(None, 1), (None, 1), (None, 2), (None, 3), ("key1", 1), ("key2", 2)]
self.assertEqual(result, expected)

def test_get_fields_of_type_pydantic_model(self):
model = MockModel(field1=1, field2="test", field3=[1, 2, 3], field4={"key1": 1, "key2": 2})
result = get_fields_of_type(model, int, recursive=False)
expected = [("field1", 1)]
self.assertEqual(result, expected)

result = get_fields_of_type(model, int, recursive=True)
expected = [("field1", 1), (None, 1), (None, 2), (None, 3), ("key1", 1), ("key2", 2)]
self.assertEqual(result, expected)

def test_get_fields_of_type_stop_recursion(self):
sub_model = MockModel(field1=1, field2="test", field3=[1, 3, 36], field4={"key1": 2, "key2": 3})
model = MockModel(field1=1, field2="test", field3=[1, 2, 3], field4={"key1": 1, "key2": 2}, sub_model=sub_model)
data = {
"field1": 1,
"field2": "test",
"field3": [1, 2, {"nested_field": 3}],
"field4": {"key1": 1, "key2": 2},
"field5": None,
"field6": model,
}
result = get_fields_of_type(data, MockModel, recursive=True, stop_recursion_on_type=True)
expected = [("field6", model)]
self.assertEqual(result, expected)

result = get_fields_of_type(data, MockModel, recursive=True, stop_recursion_on_type=False)
expected = [("field6", model), ("sub_model", sub_model)]
self.assertEqual(result, expected)


class TestHelpers(unittest.TestCase):
@patch("importlib.metadata.distributions")
def test_snapshot_python_environment(self, mock_distributions):
Expand Down
Loading

0 comments on commit 417268d

Please sign in to comment.