Skip to content

Commit

Permalink
Add region_aware stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
peterrrock2 committed Jan 9, 2024
1 parent 627e6c6 commit d5fa077
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 30 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,7 @@ junit.xml

# Pytest cache
.pytest_cache/
Dockerfile
Dockerfile

# Extra Documentation Stuff
release_notes.md
10 changes: 9 additions & 1 deletion gerrychain/proposals/tree_proposals.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import partial
from inspect import signature
from ..random import random

from ..tree import (
Expand All @@ -9,7 +10,9 @@


def recom(
partition, pop_col, pop_target, epsilon, node_repeats=1, method=bipartition_tree
partition, pop_col, pop_target, epsilon, node_repeats=1,
weight_dict = None,
method=bipartition_tree
):
"""ReCom proposal.
Expand Down Expand Up @@ -45,6 +48,11 @@ def recom(
partition.parts[parts_to_merge[0]] | partition.parts[parts_to_merge[1]]
)

# Try to add the region aware in if the method accepts the weight dictionary
if 'weight_dict' in signature(method).parameters:
method = partial(method, weight_dict=weight_dict)


flips = recursive_tree_part(
subgraph.graph,
parts_to_merge,
Expand Down
92 changes: 64 additions & 28 deletions gerrychain/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from networkx.algorithms import tree

from functools import partial
from inspect import signature
from .random import random
from collections import deque, namedtuple
from typing import Any, Callable, Dict, List, Optional, Set, Union, Sequence
Expand All @@ -15,19 +16,28 @@ def successors(h: nx.Graph, root: Any) -> Dict:
return {a: b for a, b in nx.bfs_successors(h, root)}


def random_spanning_tree(graph: nx.Graph) -> nx.Graph:
""" Builds a spanning tree chosen by Kruskal's method using random weights.
:param graph: FrozenGraph
Important Note:
The key is specifically labelled "random_weight" instead of the previously
used "weight". Turns out that networkx uses the "weight" keyword for other
operations, like when computing the laplacian or the adjacency matrix.
This meant that the laplacian would change for the graph step to step,
something that we do not intend!!
def random_spanning_tree(graph: nx.Graph, weight_dict: Dict) -> nx.Graph:
"""
Builds a spanning tree chosen by Kruskal's method using random weights.
:param graph: The input graph to build the spanning tree from. Should be a Networkx Graph.
:type graph: nx.Graph
:param weight_dict: Dictionary of weights to add to the random weights used in region-aware variants.
:type weight_dict: Dict
:return: The maximal spanning tree represented as a Networkx Graph.
:rtype: nx.Graph
"""
for edge in graph.edge_indices:
graph.edges[edge]["random_weight"] = random.random()
if weight_dict is None:
weight_dict = dict()

for edge in graph.edges():
weight = random.random()
for key, value in weight_dict.items():
if graph.nodes[edge[0]][key] == graph.nodes[edge[1]][key] and \
graph.nodes[edge[0]][key] is not None:
weight += value

graph.edges[edge]["random_weight"] = weight

spanning_tree = tree.maximum_spanning_tree(
graph, algorithm="kruskal", weight="random_weight"
Expand Down Expand Up @@ -179,35 +189,61 @@ def bipartition_tree(
node_repeats: int = 1,
spanning_tree: Optional[nx.Graph] = None,
spanning_tree_fn: Callable = random_spanning_tree,
weight_dict: Dict = None,
balance_edge_fn: Callable = find_balanced_edge_cuts_memoization,
choice: Callable = random.choice,
max_attempts: Optional[int] = 10000

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

It looks like you did not remove the old max_attempts default argument of None.

max_attempts: Optional[int] = None
) -> Set:
"""This function finds a balanced 2 partition of a graph by drawing a
"""
This function finds a balanced 2 partition of a graph by drawing a
spanning tree and finding an edge to cut that leaves at most an epsilon
imbalance between the populations of the parts. If a root fails, new roots
are tried until node_repeats in which case a new tree is drawn.
Builds up a connected subgraph with a connected complement whose population
is ``epsilon * pop_target`` away from ``pop_target``.
Returns a subset of nodes of ``graph`` (whose induced subgraph is connected).
The other part of the partition is the complement of this subset.
:param graph: The graph to partition
:param pop_col: The node attribute holding the population of each node
:param pop_target: The target population for the returned subset of nodes
:param epsilon: The allowable deviation from ``pop_target`` (as a percentage of
``pop_target``) for the subgraph's population
:param node_repeats: A parameter for the algorithm: how many different choices
of root to use before drawing a new spanning tree.
:param spanning_tree: The spanning tree for the algorithm to use (used when the
algorithm chooses a new root and for testing)
:param spanning_tree_fn: The random spanning tree algorithm to use if a spanning
tree is not provided
:param choice: :func:`random.choice`. Can be substituted for testing.
:param max_atempts: The max number of attempts that should be made to bipartition.
:param graph: The graph to partition.
:type graph: nx.Graph
:param pop_col: The node attribute holding the population of each node.
:type pop_col: str
:param pop_target: The target population for the returned subset of nodes.
:type pop_target: Union[int, float]
:param epsilon: The allowable deviation from ``pop_target`` (as a percentage of
``pop_target``) for the subgraph's population.
:type epsilon: float
:param node_repeats: A parameter for the algorithm: how many different choices
of root to use before drawing a new spanning tree. Defaults to 1.
:type node_repeats: int
:param spanning_tree: The spanning tree for the algorithm to use (used when the
algorithm chooses a new root and for testing).
:type spanning_tree: Optional[nx.Graph]
:param spanning_tree_fn: The random spanning tree algorithm to use if a spanning
tree is not provided. Defaults to :func:`random_spanning_tree`.
:type spanning_tree_fn: Callable
:param weight_dict: A dictionary of weights for the spanning tree algorithm.
Defaults to None.
:type weight_dict: Dict, optional

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

Consistent use of Optional[Dict]?

:param balance_edge_fn: The function to find balanced edge cuts. Defaults to
:func:`find_balanced_edge_cuts_memoization`.
:type balance_edge_fn: Callable, optional

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

Consistent use of Optional

:param choice: The function to make a random choice. Can be substituted for testing.
Defaults to :func:`random.choice`.
:type choice: Callable

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

Optional.

:param max_attempts: The maximum number of attempts that should be made to bipartition.
Defaults to 1000.
:type max_attempts: Optional[int]
:return: A subset of nodes of ``graph`` (whose induced subgraph is connected). The other
part of the partition is the complement of this subset.
:rtype: Set
:raises RuntimeError: If a possible cut cannot be found after the maximum number of attempts.
"""
# Try to add the region-aware in if the spanning_tree_fn accepts a weight dictionary
if 'weight_dict' in signature(spanning_tree_fn).parameters:
spanning_tree_fn = partial(spanning_tree_fn, weight_dict=weight_dict)

populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices}

possible_cuts = []
Expand Down

1 comment on commit d5fa077

@cdonnay
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Read through the whole thing!

Please sign in to comment.