Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

100 unify cost api #126

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions motile/costs/appear.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@ class Appear(Cost):

Args:
weight:
The weight to apply to the cost of each starting track.
The weight to apply to the cost attribute of each starting track.
Defaults to 1.

attribute:
The name of the attribute to use to look up cost. Default is
``None``, which means that a constant cost is used.

constant:
A constant cost for each node that starts a track.
A constant cost for each node that starts a track. Defaults to 0.

ignore_attribute:
The name of an optional node attribute that, if it is set and
evaluates to ``True``, will not set the appear cost for that node.
Defaults to None
"""

def __init__(
Expand Down
29 changes: 24 additions & 5 deletions motile/costs/disappear.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,32 @@
This is cost is not applied to nodes in the last frame of the graph.

Args:
constant (float):
A constant cost for each node that ends a track.
weight:
The weight to apply to the cost of each ending track. Defaults to 1.

attribute:
The name of the attribute to use to look up cost. Default is
``None``, which means that a constant cost is used.

constant:
A constant cost for each node that starts a track. Defaults to 0.

ignore_attribute:
The name of an optional node attribute that, if it is set and
evaluates to ``True``, will not set the disappear cost for that
node.
evaluates to ``True``, will not set the disappear cost for that node.
Defaults to None.
"""

def __init__(self, constant: float, ignore_attribute: str | None = None) -> None:
def __init__(
self,
weight: float = 1,
attribute: str | None = None,
constant: float = 0,
ignore_attribute: str | None = None,
):
self.weight = Weight(weight)
self.constant = Weight(constant)
self.attribute = attribute
self.ignore_attribute = ignore_attribute

def apply(self, solver: Solver) -> None:
Expand All @@ -39,4 +54,8 @@
continue
if G.nodes[node][G.frame_attribute] == G.get_frames()[1] - 1:
continue
if self.attribute is not None:
solver.add_variable_cost(

Check warning on line 58 in motile/costs/disappear.py

View check run for this annotation

Codecov / codecov/patch

motile/costs/disappear.py#L58

Added line #L58 was not covered by tests
index, G.nodes[node][self.attribute], self.weight
)
solver.add_variable_cost(index, 1.0, self.constant)
15 changes: 8 additions & 7 deletions motile/costs/edge_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ class EdgeSelection(Cost):

Args:
weight:
The weight to apply to the cost given by the ``cost`` attribute of
each edge.
The weight to apply to the cost given by the provided attribute of
each edge. Default is ``1.0``.

attribute:
The name of the edge attribute to use to look up cost. Default is
``'cost'``.
None, which means only a constant cost is used.

constant:
A constant cost for each selected edge. Default is ``0.0``.
"""

def __init__(
self, weight: float, attribute: str = "cost", constant: float = 0.0
self, weight: float = 1, attribute: str | None = None, constant: float = 0.0
) -> None:
self.weight = Weight(weight)
self.constant = Weight(constant)
Expand All @@ -37,7 +37,8 @@ def apply(self, solver: Solver) -> None:
edge_variables = solver.get_variables(EdgeSelected)

for edge, index in edge_variables.items():
solver.add_variable_cost(
index, solver.graph.edges[edge][self.attribute], self.weight
)
if self.attribute is not None:
solver.add_variable_cost(
index, solver.graph.edges[edge][self.attribute], self.weight
)
solver.add_variable_cost(index, 1.0, self.constant)
17 changes: 9 additions & 8 deletions motile/costs/node_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ class NodeSelection(Cost):

Args:
weight:
The weight to apply to the cost given by the ``cost`` attribute of
each node.
The weight to apply to the cost given by the provided attribute of
each node. Default is ``1.0``

attribute:
The name of the node attribute to use to look up cost. Default is
``'cost'``.
The name of the node attribute to use to look up cost, or None if a constant
cost is desired. Default is ``None``.

constant:
A constant cost for each selected node. Default is ``0.0``.
"""

def __init__(
self, weight: float, attribute: str = "cost", constant: float = 0.0
self, weight: float = 1, attribute: str | None = None, constant: float = 0.0
) -> None:
self.weight = Weight(weight)
self.constant = Weight(constant)
Expand All @@ -37,7 +37,8 @@ def apply(self, solver: Solver) -> None:
node_variables = solver.get_variables(NodeSelected)

for node, index in node_variables.items():
solver.add_variable_cost(
index, solver.graph.nodes[node][self.attribute], self.weight
)
if self.attribute is not None:
solver.add_variable_cost(
index, solver.graph.nodes[node][self.attribute], self.weight
)
solver.add_variable_cost(index, 1.0, self.constant)
4 changes: 2 additions & 2 deletions motile/costs/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ class Split(Cost):

Args:
weight:
The weight to apply to the cost of each split.
The weight to apply to the cost of each split. Default is ``1``.

attribute:
The name of the attribute to use to look up the cost. Default is
``None``, which means that a constant cost is used.

constant:
A constant cost for each node that has more than one selected
child.
child. Default is ``0``.
"""

def __init__(
Expand Down
Loading