Skip to content

Commit

Permalink
Updated error message and using a proper TypeError exception when an …
Browse files Browse the repository at this point in the history
…invalid MultiGraph is passed in (#1925)

* Updated error message and using a proper TypeError exception when an invalid MultiGraph is passed in.
* Added test to verify.

This PR replaces #1914 since that one is not passing the style check and the author has not responded.

Authors:
  - Rick Ratzel (https://github.com/rlratzel)

Approvers:
  - Brad Rees (https://github.com/BradReesWork)

URL: #1925
  • Loading branch information
rlratzel authored Nov 5, 2021
1 parent 61e8bad commit ef286cc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
12 changes: 5 additions & 7 deletions python/cugraph/cugraph/structure/graph_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class Graph:
Parameters
----------
m_graph : cuGraph.Graph object or None
Initialize the graph from another Multigraph object
m_graph : cuGraph.MultiGraph object or None
Initialize the graph from a cugraph.MultiGraph object
directed : boolean
Indicated is the graph is directed.
Default is False - Undirected
Expand All @@ -58,7 +58,7 @@ def __init__(self, m_graph=None, directed=False):
self._Impl = None
self.graph_properties = Graph.Properties(directed)
if m_graph is not None:
if m_graph.is_multigraph():
if isinstance(m_graph, MultiGraph):
elist = m_graph.view_edge_list()
if m_graph.is_weighted():
weights = "weights"
Expand All @@ -69,10 +69,8 @@ def __init__(self, m_graph=None, directed=False):
destination="dst",
edge_attr=weights)
else:
msg = (
"Graph can only be initialized using MultiGraph "
)
raise Exception(msg)
raise TypeError("m_graph can only be an instance of a "
f"cugraph.MultiGraph, got {type(m_graph)}")

def __getattr__(self, name):
if self._Impl is None:
Expand Down
20 changes: 20 additions & 0 deletions python/cugraph/cugraph/tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,23 @@ def test_neighbors(graph_file):
cu_neighbors.sort()
nx_neighbors.sort()
assert cu_neighbors == nx_neighbors


def test_graph_init_with_multigraph():
"""
Ensures only a valid MultiGraph instance can be used to initialize a Graph
by checking if either the correct exception is raised or no exception at
all.
"""
nxMG = nx.MultiGraph()
with pytest.raises(TypeError):
cugraph.Graph(m_graph=nxMG)

gdf = cudf.DataFrame({"src": [0, 1, 2], "dst": [1, 2, 3]})
cMG = cugraph.MultiGraph()
cMG.from_cudf_edgelist(gdf, source="src", destination="dst")
cugraph.Graph(m_graph=cMG)

cDiMG = cugraph.MultiDiGraph() # deprecated, but should still work
cDiMG.from_cudf_edgelist(gdf, source="src", destination="dst")
cugraph.Graph(m_graph=cDiMG)

0 comments on commit ef286cc

Please sign in to comment.