diff --git a/scripts/viz.py b/scripts/viz.py
index 09d7d86dd33..09217d0c405 100755
--- a/scripts/viz.py
+++ b/scripts/viz.py
@@ -2,10 +2,17 @@
"""Visualize an LBANN model's layer graph and save to file."""
import argparse
+import random
import re
import graphviz
-import google.protobuf.text_format
from lbann import lbann_pb2, layers_pb2
+from lbann.proto import serialize
+
+# Pastel rainbow (slightly shuffled) from colorkit.co
+palette = [
+ '#ffffff', '#a0c4ff', '#ffadad', '#fdffb6', '#caffbf', '#9bf6ff',
+ '#bdb2ff', '#ffc6ff', '#ffd6a5'
+]
# Parse command-line arguments
parser = argparse.ArgumentParser(
@@ -17,14 +24,14 @@
parser.add_argument('output',
action='store',
nargs='?',
- default='graph.pdf',
+ default='graph.dot',
type=str,
- help='output file (default: graph.pdf)')
+ help='output file (default: graph.dot)')
parser.add_argument('--file-format',
action='store',
- default='pdf',
+ default='dot',
type=str,
- help='output file format (default: pdf)',
+ help='output file format (default: dot)',
metavar='FORMAT')
parser.add_argument('--label-format',
action='store',
@@ -39,6 +46,10 @@
type=str,
help='Graphviz visualization scheme (default: dot)',
metavar='ENGINE')
+parser.add_argument('--color-cross-grid',
+ action='store_true',
+ default=False,
+ help='Highlight cross-grid edges')
args = parser.parse_args()
# Strip extension from filename
@@ -51,9 +62,7 @@
label_format = re.sub(r' |-|_', '', args.label_format.lower())
# Read prototext file
-proto = lbann_pb2.LbannPB()
-with open(args.input, 'r') as f:
- google.protobuf.text_format.Merge(f.read(), proto)
+proto = serialize.generic_load(args.input)
model = proto.model
# Construct graphviz graph
@@ -62,29 +71,36 @@
engine=args.graphviz_engine)
graph.attr('node', shape='rect')
+layer_to_grid_tag = {}
+
# Construct nodes in layer graph
layer_types = (set(layers_pb2.Layer.DESCRIPTOR.fields_by_name.keys()) - set([
'name', 'parents', 'children', 'datatype', 'data_layout',
'device_allocation', 'weights', 'freeze', 'hint_layer', 'top', 'bottom',
- 'type', 'motif_layer'
+ 'type', 'motif_layer', 'parallel_strategy', 'grid_tag'
]))
for l in model.layer:
# Determine layer type
- type = ''
+ ltype = ''
for _type in layer_types:
if l.HasField(_type):
- type = getattr(l, _type).DESCRIPTOR.name
+ ltype = getattr(l, _type).DESCRIPTOR.name
break
+ # If operator layer, use operator type
+ if ltype == 'OperatorLayer':
+ url = l.operator_layer.ops[0].parameters.type_url
+ ltype = url[url.rfind('.') + 1:]
+
# Construct node label
label = ''
if label_format == 'nameonly':
label = l.name
elif label_format == 'typeonly':
- label = type
+ label = ltype
elif label_format == 'typeandname':
- label = '<{0}
{1}>'.format(type, l.name)
+ label = '<{0}
{1}>'.format(ltype, l.name)
elif label_format == 'full':
label = '<'
for (index, line) in enumerate(str(l).strip().split('\n')):
@@ -94,14 +110,36 @@
label += '>'
# Add layer as layer graph node
- graph.node(l.name, label=label)
+ tag = l.grid_tag.value
+ layer_to_grid_tag[l.name] = tag
+ attrs = {}
+ if tag != 0:
+ attrs = dict(style='filled', fillcolor=palette[tag % len(palette)])
+ graph.node(l.name, label=label, **attrs)
# Add parent/child relationships as layer graph edges
edges = set()
+cross_grid_edges = set()
for l in model.layer:
- edges.update([(p, l.name) for p in l.parents.split()])
- edges.update([(l.name, c) for c in l.children.split()])
+ tag = layer_to_grid_tag[l.name]
+ for p in l.parents:
+ if tag != layer_to_grid_tag[p]:
+ cross_grid_edges.add((p, l.name))
+ else:
+ edges.add((p, l.name))
+
+ for c in l.children:
+ if tag != layer_to_grid_tag[c]:
+ cross_grid_edges.add((l.name, c))
+ else:
+ edges.add((l.name, c))
+
graph.edges(edges)
+if args.color_cross_grid:
+ for u, v in cross_grid_edges:
+ graph.edge(u, v, color='red')
+else:
+ graph.edges(cross_grid_edges)
# Save to file
graph.render(filename=filename, cleanup=True, format=file_format)