Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add radius argument to NetworkGrid.get_neighbors() #1973

Merged
merged 1 commit into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions mesa/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,9 +1539,11 @@ def get_neighborhood(
neighborhood = sorted(neighbors_with_distance.keys())
return neighborhood

def get_neighbors(self, node_id: int, include_center: bool = False) -> list[Agent]:
"""Get all agents in adjacent nodes."""
neighborhood = self.get_neighborhood(node_id, include_center)
def get_neighbors(
self, node_id: int, include_center: bool = False, radius: int = 1
) -> list[Agent]:
"""Get all agents in adjacent nodes (within a certain radius)."""
neighborhood = self.get_neighborhood(node_id, include_center, radius)
return self.get_cell_list_contents(neighborhood)

def move_agent(self, agent: Agent, node_id: int) -> None:
Expand Down
44 changes: 43 additions & 1 deletion tests/test_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,12 +867,54 @@ def test_agent_positions(self):
a = self.agents[i]
assert a.pos == pos

def test_get_neighbors(self):
def test_get_neighborhood(self):
assert len(self.space.get_neighborhood(0, include_center=True)) == 3
assert len(self.space.get_neighborhood(0, include_center=False)) == 2
assert len(self.space.get_neighborhood(2, include_center=True, radius=3)) == 7
assert len(self.space.get_neighborhood(2, include_center=False, radius=3)) == 6

def test_get_neighbors(self):
"""
Test the get_neighbors method with varying radius and include_center values. Note there are agents on node 0, 1 and 5.
"""
# Test with default radius (1) and include_center = False
neighbors_default = self.space.get_neighbors(0, include_center=False)
self.assertEqual(
len(neighbors_default),
1,
"Should have 1 neighbors with default radius and exclude center",
)

# Test with default radius (1) and include_center = True
neighbors_include_center = self.space.get_neighbors(0, include_center=True)
self.assertEqual(
len(neighbors_include_center),
2,
"Should have 2 neighbors (including center) with default radius",
)

# Test with radius = 2 and include_center = False
neighbors_radius_2 = self.space.get_neighbors(0, include_center=False, radius=5)
expected_count_radius_2 = 2
self.assertEqual(
len(neighbors_radius_2),
expected_count_radius_2,
f"Should have {expected_count_radius_2} neighbors with radius 2 and exclude center",
)

# Test with radius = 2 and include_center = True
neighbors_radius_2_include_center = self.space.get_neighbors(
0, include_center=True, radius=5
)
expected_count_radius_2_include_center = (
3 # Adjust this based on your network structure
)
self.assertEqual(
len(neighbors_radius_2_include_center),
expected_count_radius_2_include_center,
f"Should have {expected_count_radius_2_include_center} neighbors (including center) with radius 2",
)

def test_move_agent(self):
initial_pos = 1
agent_number = 1
Expand Down