diff --git a/mesa/space.py b/mesa/space.py index 286f63c8fa3..6ed7c09697e 100644 --- a/mesa/space.py +++ b/mesa/space.py @@ -68,10 +68,9 @@ def accept_tuple_argument(wrapped_function: F) -> F: single-item list rather than forcing user to do it.""" def wrapper(grid_instance, positions) -> Any: - if isinstance(positions, tuple) and len(positions) == 2: - return wrapped_function(grid_instance, [positions]) - else: - return wrapped_function(grid_instance, positions) + if len(positions) == 2 and not isinstance(positions[0], tuple): + positions = [positions] + return wrapped_function(grid_instance, positions) return cast(F, wrapper) diff --git a/tests/test_space.py b/tests/test_space.py index 8691dc45f0a..f8f2cc9440c 100644 --- a/tests/test_space.py +++ b/tests/test_space.py @@ -327,6 +327,28 @@ def move_agent(self): assert self.space[initial_pos[0]][initial_pos[1]] is None assert self.space[final_pos[0]][final_pos[1]] == _agent + def test_iter_cell_list_contents(self): + """ + Test neighborhood retrieval + """ + cell_list_1 = list(self.space.iter_cell_list_contents(TEST_AGENTS_GRID[0])) + assert len(cell_list_1) == 1 + + cell_list_2 = list( + self.space.iter_cell_list_contents( + (TEST_AGENTS_GRID[0], TEST_AGENTS_GRID[1]) + ) + ) + assert len(cell_list_2) == 2 + + cell_list_3 = list(self.space.iter_cell_list_contents(tuple(TEST_AGENTS_GRID))) + assert len(cell_list_3) == 3 + + cell_list_4 = list( + self.space.iter_cell_list_contents((TEST_AGENTS_GRID[0], (0, 0))) + ) + assert len(cell_list_4) == 1 + class TestSingleNetworkGrid(unittest.TestCase): GRAPH_SIZE = 10