Skip to content

Commit

Permalink
Complete convert_from_nx of Graph and DiGraph for updated dispatch …
Browse files Browse the repository at this point in the history
…coming out in nx 3.2

This implements a property graph model where edges and nodes may (or may not) have missing data.
Multigraphs are not yet handled.
  • Loading branch information
eriknw committed Jul 20, 2023
1 parent 8b24e4c commit c4ee98c
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 50 deletions.
36 changes: 30 additions & 6 deletions python/cugraph-nx/cugraph_nx/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
from __future__ import annotations

from collections.abc import Hashable, Iterator
from typing import TYPE_CHECKING
from copy import deepcopy
from typing import TYPE_CHECKING, TypeVar

import cugraph_nx as cnx

if TYPE_CHECKING:
import cupy as cp

NodeType = Hashable
NodeType = TypeVar("NodeType", bound=Hashable)
AttrType = TypeVar("AttrType", bound=Hashable)


__all__ = ["Graph"]
Expand All @@ -30,7 +32,10 @@ class Graph:
indptr: cp.ndarray
row_indices: cp.ndarray
col_indices: cp.ndarray
edge_values: cp.ndarray | None
edge_values: dict[AttrType, cp.ndarray]
edge_masks: dict[AttrType, cp.ndarray]
node_values: dict[AttrType, cp.ndarray]
node_masks: dict[AttrType, cp.ndarray]
key_to_id: dict[NodeType, int] | None
_id_to_key: dict[int, NodeType] | None
_N: int
Expand All @@ -46,6 +51,9 @@ def __init__(
row_indices,
col_indices,
edge_values,
edge_masks,
node_values,
node_masks,
*,
key_to_id,
id_to_key=None,
Expand All @@ -56,6 +64,9 @@ def __init__(
self.row_indices = row_indices
self.col_indices = col_indices
self.edge_values = edge_values
self.edge_masks = edge_masks
self.node_values = node_values
self.node_masks = node_masks
self.key_to_id = key_to_id
self._id_to_key = id_to_key
self._N = indptr.size - 1
Expand Down Expand Up @@ -111,23 +122,36 @@ def to_directed(self, as_view=False) -> cnx.DiGraph:
row_indices = self.row_indices
col_indices = self.col_indices
edge_values = self.edge_values
edge_masks = self.edge_masks
node_values = self.node_values
node_masks = self.node_masks
key_to_id = self.key_to_id
id_to_key = None if key_to_id is None else self._id_to_key
if not as_view:
indptr = indptr.copy()
row_indices = row_indices.copy()
col_indices = col_indices.copy()
if edge_values is not None:
edge_values = edge_values.copy()
edge_values = {key: val.copy() for key, val in edge_values.items()}
edge_masks = {key: val.copy() for key, val in edge_masks.items()}
node_values = {key: val.copy() for key, val in node_values.items()}
node_masks = {key: val.copy() for key, val in node_masks.items()}
if key_to_id is not None:
key_to_id = key_to_id.copy()
if id_to_key is not None:
id_to_key = id_to_key.copy()
return cnx.DiGraph(
rv = cnx.DiGraph(
indptr,
row_indices,
col_indices,
edge_values,
edge_masks,
node_values,
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
)
if as_view:
rv.graph = self.graph
else:
rv.graph.update(deepcopy(self.graph))
return rv
223 changes: 195 additions & 28 deletions python/cugraph-nx/cugraph_nx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.
from __future__ import annotations

import collections
import itertools

import cupy as cp
Expand All @@ -20,6 +21,7 @@

import cugraph_nx as cnx


__all__ = [
"from_networkx",
"to_networkx",
Expand All @@ -28,22 +30,101 @@
"to_undirected_graph",
]

concat = itertools.chain.from_iterable

def from_networkx(
G, edge_attr=None, edge_default=1.0, *, is_directed=None, dtype=None
) -> cnx.Graph:
if not isinstance(G, nx.Graph):
raise TypeError(f"Expected networkx.Graph; got {type(G)}")
if G.is_multigraph():

def convert_from_nx(
graph: nx.Graph,
edge_attrs: dict | None = None,
node_attrs: dict | None = None,
preserve_edge_attrs: bool = False,
preserve_node_attrs: bool = False,
preserve_graph_attrs: bool = False,
name: str | None = None,
graph_name: str | None = None,
*,
# Custom arguments
is_directed: bool | None = None,
edge_dtypes: dict | None = None,
node_dtypes: dict | None = None,
):
# This uses `graph._adj` and `graph._node`, which are private attributes in NetworkX
if not isinstance(graph, nx.Graph):
raise TypeError(f"Expected networkx.Graph; got {type(graph)}")
if graph.is_multigraph():
raise NotImplementedError("MultiGraph support is not yet implemented")
if isinstance(edge_attr, (list, dict, set)):
raise NotImplementedError(
"Graph with multiple attributes is not yet supported; "
f"bad edge_attr: {edge_attr}"
has_missing_edge_data = set()
if graph.number_of_edges() == 0:
pass
elif preserve_edge_attrs:
# attrs = set().union(*concat(map(dict.values, graph._adj.values())))
attr_sets = set(map(frozenset, concat(map(dict.values, graph._adj.values()))))
attrs = set().union(*attr_sets)
edge_attrs = dict.fromkeys(attrs)
if len(attr_sets) > 1:
counts = collections.Counter(concat(attr_sets))
has_missing_edge_data = {
key for key, val in counts.items() if val != len(attr_sets)
}
elif edge_attrs is not None and None in edge_attrs.values():
# Required edge attributes have a default of None in `edge_attrs`
# Verify all edge attributes are present!
required = frozenset(
attr for attr, default in edge_attrs.items() if default is None
)
attr_sets = set(
map(required.intersection, concat(map(dict.values, graph._adj.values())))
)
get_values = edge_attr is not None
adj = G._adj # This is a NetworkX private attribute, but is much faster to use
if get_values and isinstance(adj, nx.classes.coreviews.FilterAdjacency):
# attr_set = set().union(*attr_sets)
if missing := required - set().union(*attr_sets):
# Required attributes missing completely
missing_attrs = ", ".join(sorted(missing))
raise TypeError(f"Missing required edge attribute: {missing_attrs}")
if len(attr_sets) != 1:
# Required attributes are missing _some_ data
counts = collections.Counter(concat(attr_sets))
bad_attrs = {key for key, val in counts.items() if val != len(attr_sets)}
missing_attrs = ", ".join(sorted(bad_attrs))
raise TypeError(
f"Some edges are missing required attribute: {missing_attrs}"
)

has_missing_node_data = set()
if graph.number_of_nodes() == 0:
pass
elif preserve_node_attrs:
# attrs = set().union(*graph._node.values())
attr_sets = set(map(frozenset, graph._node.values()))
attrs = set().union(*attr_sets)
node_attrs = dict.fromkeys(attrs)
if len(attr_sets) > 1:
counts = collections.Counter(concat(attr_sets))
has_missing_node_data = {
key for key, val in counts.items() if val != len(attr_sets)
}
elif node_attrs is not None and None in node_attrs.values():
# Required node attributes have a default of None in `node_attrs`
# Verify all node attributes are present!
required = frozenset(
attr for attr, default in node_attrs.items() if default is None
)
attr_sets = set(map(required.intersection, graph._node.values()))
if missing := required - set().union(*attr_sets):
# Required attributes missing completely
missing_attrs = ", ".join(sorted(missing))
raise TypeError(f"Missing required node attribute: {missing_attrs}")
if len(attr_sets) != 1:
# Required attributes are missing _some_ data
counts = collections.Counter(concat(attr_sets))
bad_attrs = {key for key, val in counts.items() if val != len(attr_sets)}
missing_attrs = ", ".join(sorted(bad_attrs))
raise TypeError(
f"Some nodes are missing required attribute: {missing_attrs}"
)

get_edge_values = edge_attrs is not None
adj = graph._adj # This is a NetworkX private attribute, but is much faster to use
if get_edge_values and isinstance(adj, nx.classes.coreviews.FilterAdjacency):
adj = {k: dict(v) for k, v in adj.items()}
N = len(adj)
key_to_id = dict(zip(adj, range(N)))
Expand All @@ -53,35 +134,121 @@ def from_networkx(
col_iter = map(key_to_id.__getitem__, col_iter)
else:
key_to_id = None
# TODO: do col_indices need to be sorted in each row?
# TODO: do col_indices need to be sorted in each row (if we use indptr as CSR)?
col_indices = cp.fromiter(col_iter, np.int32)
iter_values = (
edgedata.get(edge_attr, edge_default)
for rowdata in adj.values()
for edgedata in rowdata.values()
)
if not get_values:
values = None
elif dtype is None:
values = cp.array(list(iter_values))
else:
values = cp.fromiter(iter_values, dtype)

edge_values = {}
edge_masks = {}
if get_edge_values:
if edge_dtypes is None:
edge_dtypes = {}
for edge_attr, edge_default in edge_attrs.items():
dtype = edge_dtypes.get(edge_attr)
if edge_default is None and edge_attr in has_missing_edge_data:
vals = []
append = vals.append
iter_mask = (
append(
edgedata[edge_attr]
if (present := edge_attr in edgedata)
else False
)
or present
for rowdata in adj.values()
for edgedata in rowdata.values()
)
edge_masks[edge_attr] = cp.fromiter(iter_mask, bool)
edge_values[edge_attr] = cp.array(vals, dtype)
else:
iter_values = (
edgedata.get(edge_attr, edge_default)
for rowdata in adj.values()
for edgedata in rowdata.values()
)
if dtype is None:
edge_values[edge_attr] = cp.array(list(iter_values))
else:
edge_values[edge_attr] = cp.fromiter(iter_values, dtype)
if edge_default is None:
edge_masks[edge_attr] = cp.zeros(col_indices.size, bool)

# TODO: should we use indptr for CSR? Should we only use COO?
indptr = cp.cumsum(
cp.fromiter(itertools.chain([0], map(len, adj.values())), np.int32),
dtype=np.int32,
)
row_indices = cp.repeat(cp.arange(N, dtype=np.int32), list(map(len, adj.values())))
if G.is_directed() or is_directed:

get_node_values = node_attrs is not None
node_values = {}
node_masks = {}
nodes = graph._node
if get_node_values:
if node_dtypes is None:
node_dtypes = {}
for node_attr, node_default in node_attrs.items():
# Iterate over `adj` to ensure consistent order
dtype = node_dtypes.get(node_attr)
if node_default is None and node_attr in has_missing_node_data:
vals = []
append = vals.append
iter_mask = (
append(
nodedata[node_attr]
if (present := node_attr in (nodedata := nodes[node_id]))
else False
)
or present
for node_id in adj
)
node_masks[node_attr] = cp.fromiter(iter_mask, bool)
node_values[node_attr] = cp.array(vals, dtype)
else:
iter_values = (
nodes[node_id].get(node_attr, node_default) for node_id in adj
)
if dtype is None:
node_values[node_attr] = cp.array(list(iter_values))
else:
node_values[node_attr] = cp.fromiter(iter_values, dtype)
if node_default is None:
node_masks[node_attr] = cp.zeros(col_indices.size, bool)

if graph.is_directed() or is_directed:
klass = cnx.DiGraph
else:
klass = cnx.Graph
return klass(
rv = klass(
indptr,
row_indices,
col_indices,
values,
edge_values,
edge_masks,
node_values,
node_masks,
key_to_id=key_to_id,
)
if preserve_graph_attrs:
rv.graph.update(graph.graph) # deepcopy?
return rv


def from_networkx(
G: nx.Graph,
edge_attr=None,
edge_default=1.0,
*,
is_directed: bool | None = None,
dtype=None,
) -> cnx.Graph:
if edge_attr is not None:
edge_attrs = {edge_attr: edge_default}
edge_dtypes = {edge_attr: dtype}
else:
edge_attrs = edge_dtypes = None
return convert_from_nx(
G, edge_attrs=edge_attrs, is_directed=is_directed, edge_dtypes=edge_dtypes
)


def to_networkx(G) -> nx.Graph:
Expand Down
17 changes: 3 additions & 14 deletions python/cugraph-nx/cugraph_nx/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,17 @@
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

import cugraph_nx as cnx

from . import algorithms

if TYPE_CHECKING:
import networkx as nx


class Dispatcher:
is_strongly_connected = algorithms.is_strongly_connected

@staticmethod
def convert_from_nx(graph: nx.Graph, weight=None, *, name=None) -> cnx.Graph:
return cnx.from_networkx(graph, edge_attr=weight)

@staticmethod
def convert_to_nx(obj, *, name=None):
if isinstance(obj, cnx.Graph):
return cnx.to_networkx(obj)
return obj
# Required conversions
convert_from_nx = cnx.convert.convert_from_nx
convert_to_nx = cnx.convert.convert_to_nx

@staticmethod
def on_start_tests(items):
Expand Down
Loading

0 comments on commit c4ee98c

Please sign in to comment.