diff --git a/scripts/viz.py b/scripts/viz.py index 09d7d86dd33..09217d0c405 100755 --- a/scripts/viz.py +++ b/scripts/viz.py @@ -2,10 +2,17 @@ """Visualize an LBANN model's layer graph and save to file.""" import argparse +import random import re import graphviz -import google.protobuf.text_format from lbann import lbann_pb2, layers_pb2 +from lbann.proto import serialize + +# Pastel rainbow (slightly shuffled) from colorkit.co +palette = [ + '#ffffff', '#a0c4ff', '#ffadad', '#fdffb6', '#caffbf', '#9bf6ff', + '#bdb2ff', '#ffc6ff', '#ffd6a5' +] # Parse command-line arguments parser = argparse.ArgumentParser( @@ -17,14 +24,14 @@ parser.add_argument('output', action='store', nargs='?', - default='graph.pdf', + default='graph.dot', type=str, - help='output file (default: graph.pdf)') + help='output file (default: graph.dot)') parser.add_argument('--file-format', action='store', - default='pdf', + default='dot', type=str, - help='output file format (default: pdf)', + help='output file format (default: dot)', metavar='FORMAT') parser.add_argument('--label-format', action='store', @@ -39,6 +46,10 @@ type=str, help='Graphviz visualization scheme (default: dot)', metavar='ENGINE') +parser.add_argument('--color-cross-grid', + action='store_true', + default=False, + help='Highlight cross-grid edges') args = parser.parse_args() # Strip extension from filename @@ -51,9 +62,7 @@ label_format = re.sub(r' |-|_', '', args.label_format.lower()) # Read prototext file -proto = lbann_pb2.LbannPB() -with open(args.input, 'r') as f: - google.protobuf.text_format.Merge(f.read(), proto) +proto = serialize.generic_load(args.input) model = proto.model # Construct graphviz graph @@ -62,29 +71,36 @@ engine=args.graphviz_engine) graph.attr('node', shape='rect') +layer_to_grid_tag = {} + # Construct nodes in layer graph layer_types = (set(layers_pb2.Layer.DESCRIPTOR.fields_by_name.keys()) - set([ 'name', 'parents', 'children', 'datatype', 'data_layout', 'device_allocation', 'weights', 'freeze', 'hint_layer', 'top', 'bottom', - 'type', 'motif_layer' + 'type', 'motif_layer', 'parallel_strategy', 'grid_tag' ])) for l in model.layer: # Determine layer type - type = '' + ltype = '' for _type in layer_types: if l.HasField(_type): - type = getattr(l, _type).DESCRIPTOR.name + ltype = getattr(l, _type).DESCRIPTOR.name break + # If operator layer, use operator type + if ltype == 'OperatorLayer': + url = l.operator_layer.ops[0].parameters.type_url + ltype = url[url.rfind('.') + 1:] + # Construct node label label = '' if label_format == 'nameonly': label = l.name elif label_format == 'typeonly': - label = type + label = ltype elif label_format == 'typeandname': - label = '<{0}
{1}>'.format(type, l.name) + label = '<{0}
{1}>'.format(ltype, l.name) elif label_format == 'full': label = '<' for (index, line) in enumerate(str(l).strip().split('\n')): @@ -94,14 +110,36 @@ label += '>' # Add layer as layer graph node - graph.node(l.name, label=label) + tag = l.grid_tag.value + layer_to_grid_tag[l.name] = tag + attrs = {} + if tag != 0: + attrs = dict(style='filled', fillcolor=palette[tag % len(palette)]) + graph.node(l.name, label=label, **attrs) # Add parent/child relationships as layer graph edges edges = set() +cross_grid_edges = set() for l in model.layer: - edges.update([(p, l.name) for p in l.parents.split()]) - edges.update([(l.name, c) for c in l.children.split()]) + tag = layer_to_grid_tag[l.name] + for p in l.parents: + if tag != layer_to_grid_tag[p]: + cross_grid_edges.add((p, l.name)) + else: + edges.add((p, l.name)) + + for c in l.children: + if tag != layer_to_grid_tag[c]: + cross_grid_edges.add((l.name, c)) + else: + edges.add((l.name, c)) + graph.edges(edges) +if args.color_cross_grid: + for u, v in cross_grid_edges: + graph.edge(u, v, color='red') +else: + graph.edges(cross_grid_edges) # Save to file graph.render(filename=filename, cleanup=True, format=file_format)