diff --git a/src/main/scala/dotvisualizer/FirrtlDiagrammer.scala b/src/main/scala/dotvisualizer/FirrtlDiagrammer.scala index 20de6df..962af25 100644 --- a/src/main/scala/dotvisualizer/FirrtlDiagrammer.scala +++ b/src/main/scala/dotvisualizer/FirrtlDiagrammer.scala @@ -4,9 +4,7 @@ package dotvisualizer import java.io.{File, PrintWriter} -import chisel3.experimental -import chisel3.experimental.{ChiselAnnotation, RunFirrtlTransform} -import chisel3.internal.InstanceId +import chisel3.experimental.ChiselAnnotation import dotvisualizer.transforms.{MakeDiagramGroup, ModuleLevelDiagrammer} import firrtl._ import firrtl.annotations._ @@ -16,10 +14,6 @@ import scala.concurrent.duration.Duration import scala.concurrent.{Await, Future, TimeoutException, blocking} import scala.sys.process._ -//TODO: MONICA: Implement depth -//TODO: Make input and output separate colors -//TODO: Make modules at different levels separate colors - //scalastyle:off magic.number //scalastyle:off regex @@ -40,6 +34,10 @@ case class SetOpenProgram(openProgram: String) extends OptionAnnotation case class DotTimeOut(seconds: Int) extends OptionAnnotation +case class RankDirAnnotation(rankDir: String) extends OptionAnnotation + +case object UseRankAnnotation extends OptionAnnotation + object FirrtlDiagrammer { var dotTimeOut = 7 @@ -184,7 +182,7 @@ object FirrtlDiagrammer { transform.execute(circuitState) } - val fileName = s"${targetDir}${circuitState.circuit.main}_hierarchy.dot" + val fileName = s"$targetDir${circuitState.circuit.main}_hierarchy.dot" val openProgram = controlAnnotations.collectFirst { case SetOpenProgram(program) => program }.getOrElse("open") @@ -217,6 +215,14 @@ object FirrtlDiagrammer { .action { (_, c) => c.copy(justTopLevel = true) } .text("use this to only see the top level view") + opt[String]('o', "rank-dir") + .action { (x, c) => c.copy(rankDir = x) } + .text("use to set ranking direction, default is LR, TB is good alternative") + + opt[Unit]('j', "rank-elements") + .action { (_, c) => c.copy(useRanking = true) } + .text("tries to rank elements by depth from inputs") + opt[Int]('s', "dot-timeout-seconds") .action { (x, c) => c.copy(dotTimeOut = x) } .text("use this to only see the top level view") @@ -238,7 +244,9 @@ case class Config( openProgram: String = Config.getOpenForOs, targetDir: String = "", justTopLevel: Boolean = false, - dotTimeOut: Int = 7 + dotTimeOut: Int = 7, + useRanking: Boolean = false, + rankDir: String = "LR" ) { def toAnnotations: Seq[Annotation] = { val dir = { @@ -254,9 +262,11 @@ case class Config( SetRenderProgram(renderProgram), SetOpenProgram(openProgram), TargetDirAnnotation(dir), - DotTimeOut(dotTimeOut) + DotTimeOut(dotTimeOut), + RankDirAnnotation(rankDir) ) ++ - (if(startModuleName.nonEmpty) Seq(StartModule(startModuleName)) else Seq.empty) + (if(startModuleName.nonEmpty) Seq(StartModule(startModuleName)) else Seq.empty) ++ + (if(useRanking) Seq(UseRankAnnotation) else Seq.empty) } } diff --git a/src/main/scala/dotvisualizer/dotnodes/ModuleNode.scala b/src/main/scala/dotvisualizer/dotnodes/ModuleNode.scala index d7d4e8c..8745f1a 100755 --- a/src/main/scala/dotvisualizer/dotnodes/ModuleNode.scala +++ b/src/main/scala/dotvisualizer/dotnodes/ModuleNode.scala @@ -2,9 +2,8 @@ package dotvisualizer.dotnodes -import java.io.{File, PrintWriter} - import dotvisualizer.transforms.MakeOneDiagram +import firrtl.graph.DiGraph import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -16,9 +15,11 @@ case class ModuleNode( subModuleDepth: Int = 0 ) extends DotNode { - val inputs: ArrayBuffer[DotNode] = new ArrayBuffer() - val outputs: ArrayBuffer[DotNode] = new ArrayBuffer() + var renderWithRank: Boolean = false + val namedNodes: mutable.HashMap[String, DotNode] = new mutable.HashMap() + val subModuleNames: mutable.HashSet[String] = new mutable.HashSet[String]() + val connections: mutable.HashMap[String, String] = new mutable.HashMap() private val analogConnections = new mutable.HashMap[String, ArrayBuffer[String]]() { override def default(key: String): ArrayBuffer[String] = { @@ -31,22 +32,87 @@ case class ModuleNode( val backgroundColorIndex: Int = subModuleDepth % MakeOneDiagram.subModuleColorsByDepth.length val backgroundColor: String = MakeOneDiagram.subModuleColorsByDepth(backgroundColorIndex) + //scalastyle:off method.length cyclomatic.complexity + def constructRankDirectives: String = { + val inputNames = children.collect { case p: PortNode if p.isInput => p }.map(_.absoluteName) + val outputPorts = children.collect { case p: PortNode if ! p.isInput => p }.map(_.absoluteName) + + val diGraph = { + val linkedHashMap = new mutable.LinkedHashMap[String, mutable.LinkedHashSet[String]] { + override def default(key: String): mutable.LinkedHashSet[String] = { + this(key) = new mutable.LinkedHashSet[String] + this(key) + } + } + + val connectionTargetNames = connections.values.map(_.split(":").head).toSet + + connections.foreach { case (rhs, lhs) => + val source = lhs.split(":").head + val target = rhs.split(":").head + + if(target.nonEmpty && connectionTargetNames.contains(target)) { + linkedHashMap(source) += target + linkedHashMap(target) + } + } + DiGraph(linkedHashMap) + } + + val sources = diGraph.findSources.filter(inputNames.contains).toSeq + + def getRankedNodes: mutable.ArrayBuffer[Seq[String]] = { + val alreadyVisited = new mutable.HashSet[String]() + val rankNodes = new mutable.ArrayBuffer[Seq[String]]() + + def walkByRank(nodes: Seq[String], rankNumber: Int = 0): Unit = { + rankNodes.append(nodes) + + alreadyVisited ++= nodes + + val nextNodes = nodes.flatMap { node => + diGraph.getEdges(node) + }.filterNot(alreadyVisited.contains).distinct + + if(nextNodes.nonEmpty) { + walkByRank(nextNodes, rankNumber + 1) + } + } + + walkByRank(sources) + rankNodes + } + + val rankedNodes = getRankedNodes + + val rankInfo = rankedNodes.map { + nodesAtRank => s"""{ rank=same; ${nodesAtRank.mkString(" ")} };""" + }.mkString("", "\n ", "") + + rankInfo + "\n " + s"""{ rank=same; ${outputPorts.mkString(" ")} };""" + } + + /** + * Renders this node + * @return + */ def render: String = { def expandBiConnects(target: String, sources: ArrayBuffer[String]): String = { sources.map { vv => s"""$vv -> $target [dir = "both"]""" }.mkString("\n") } + val rankInfo = if(renderWithRank) constructRankDirectives else "" + val s = s""" |subgraph $absoluteName { | label="$name" | URL="${url_string.getOrElse("")}" | bgcolor="$backgroundColor" - | ${inputs.map(_.render).mkString("\n")} - | ${outputs.map(_.render).mkString("\n")} | ${children.map(_.render).mkString("\n")} | - | ${connections.map { case (k, v) => s"$v -> $k"}.mkString("\n")} - | ${analogConnections.map { case (k, v) => expandBiConnects(k, v) }.mkString("\n")} + | ${connections.map { case (k, v) => s"$v -> $k"}.mkString("", "\n ", "")} + | ${analogConnections.map { case (k, v) => expandBiConnects(k, v) }.mkString("", "\n ", "")} + | $rankInfo |} """.stripMargin s @@ -81,59 +147,7 @@ case class ModuleNode( //scalastyle:off method.name def += (childNode: DotNode): Unit = { - namedNodes(childNode.name) = childNode + namedNodes(childNode.absoluteName) = childNode children += childNode } } - -import scala.sys.process._ - -//noinspection ScalaStyle -object ModuleNode { - //noinspection ScalaStyle - def main(args: Array[String]): Unit = { - val topModule = ModuleNode("top", parentOpt = None) - - val fox = LiteralNode("fox", BigInt(1), Some(topModule)) - val dog = LiteralNode("dog", BigInt(5), Some(topModule)) - - val mux1 = MuxNode("mux1", Some(topModule)) - val mux2 = MuxNode("mux2", Some(topModule)) - - topModule.inputs += PortNode("in1", Some(topModule)) - topModule.inputs += PortNode("in2", Some(topModule)) - - val subModule = ModuleNode("child", Some(topModule)) - - subModule.inputs += PortNode("cluster_in1", Some(topModule)) - subModule.inputs += PortNode("in2", Some(topModule)) - - topModule.children += fox - topModule.children += dog - topModule.children += mux1 - topModule.children += mux2 - topModule.children += subModule - - topModule.localConnections(mux1.in1) = fox.absoluteName - topModule.localConnections(mux2.in1) = dog.absoluteName - - topModule.localConnections(s"cluster_top_cluster_in1") = topModule.inputs(0).absoluteName - - val writer = new PrintWriter(new File("module1.dot")) - writer.println(s"digraph structs {") - writer.println(s"graph [splines=ortho]") - writer.println(s"node [shape=plaintext]") - writer.println(topModule.render) - - - writer.println(s"}") - - writer.close() - - "fdp -Tpng -O module1.dot".!! - "open module1.dot.png".!! - } -} - - - diff --git a/src/main/scala/dotvisualizer/transforms/MakeOneDiagram.scala b/src/main/scala/dotvisualizer/transforms/MakeOneDiagram.scala index c7da81c..8dfa197 100644 --- a/src/main/scala/dotvisualizer/transforms/MakeOneDiagram.scala +++ b/src/main/scala/dotvisualizer/transforms/MakeOneDiagram.scala @@ -4,12 +4,11 @@ package dotvisualizer.transforms import java.io.PrintWriter +import dotvisualizer._ import dotvisualizer.dotnodes._ -import dotvisualizer.{FirrtlDiagrammer, Scope, StartModule} import firrtl.PrimOps._ import firrtl.ir._ -import firrtl.{CircuitForm, CircuitState, LowForm} -import firrtl.{Transform, WDefInstance, WRef, WSubField, WSubIndex} +import firrtl.{CircuitForm, CircuitState, LowForm, Transform, WDefInstance, WRef, WSubField, WSubIndex} import scala.collection.mutable @@ -39,6 +38,10 @@ class MakeOneDiagram extends Transform { var linesPrintedSinceFlush = 0 var totalLinesPrinted = 0 + val useRanking = state.annotations.collectFirst { case UseRankAnnotation => UseRankAnnotation}.isDefined + + val rankDir = state.annotations.collectFirst { case RankDirAnnotation(dir) => dir}.getOrElse("LR") + val printFileName = s"$targetDir$startModuleName.dot" println(s"creating dot file $printFileName") val printFile = new PrintWriter(new java.io.File(printFileName)) @@ -250,7 +253,7 @@ class MakeOneDiagram extends Transform { case port if port.direction == dir => val portNode = PortNode( port.name, Some(moduleNode), - if(moduleNode.parentOpt.isEmpty) 0 else 1, + rank = if(port.direction == firrtl.ir.Input) 0 else 1000, port.direction == firrtl.ir.Input ) nameToNode(getFirrtlName(port.name)) = portNode @@ -330,6 +333,8 @@ class MakeOneDiagram extends Transform { processModule(newPrefix, subModule, subModuleNode, scope.descend, subModuleDepth + 1) + moduleNode.subModuleNames += subModuleNode.absoluteName + case DefNode(_, name, expression) if scope.doComponents() => val fName = getFirrtlName(name) val nodeNode = NodeNode(name, Some(moduleNode)) @@ -367,14 +372,15 @@ class MakeOneDiagram extends Transform { findModule(startModuleName, c) match { case topModule: DefModule => pl(s"digraph ${topModule.name} {") - pl("""stylesheet = "styles.css"""") - pl("rankdir=\"LR\"") - // pl(s"graph [splines=ortho];") + pl(s"""stylesheet = "styles.css"""") + pl(s"""rankdir="$rankDir" """) + //TODO: make this an option -- pl(s"graph [splines=ortho];") val topModuleNode = ModuleNode(startModuleName, parentOpt = None) + if(useRanking) topModuleNode.renderWithRank = true processModule("", topModule, topModuleNode, Scope(0, 1)) - // processModule("", topModule, topModuleNode, getScope(topModule.name)) + pl(topModuleNode.render) - //pl("\"Modules Only View Here\" [URL=\"TopLevel.dot.svg\" shape=\"rectangle\"]; \n") + pl("}") case _ => println(s"could not find top module $startModuleName") diff --git a/src/test/scala/dotvisualizer/FirExample.scala b/src/test/scala/dotvisualizer/FirExample.scala index 0b22625..5924138 100644 --- a/src/test/scala/dotvisualizer/FirExample.scala +++ b/src/test/scala/dotvisualizer/FirExample.scala @@ -24,7 +24,9 @@ class FirExampleSpec extends FreeSpec with Matchers { """This is an example of an FIR circuit which has a lot of elements in a single module""" in { val circuit = chisel3.Driver.elaborate(() => new MyManyDynamicElementVecFir(10)) val firrtl = chisel3.Driver.emit(circuit) - val config = Config(targetDir = "test_run_dir/fir_example/", firrtlSource = firrtl) + val config = Config( + targetDir = "test_run_dir/fir_example/", firrtlSource = firrtl, rankDir = "TB", useRanking = true + ) FirrtlDiagrammer.run(config) } } \ No newline at end of file diff --git a/src/test/scala/dotvisualizer/GCD.scala b/src/test/scala/dotvisualizer/GCD.scala index 39f9b4f..1f16927 100644 --- a/src/test/scala/dotvisualizer/GCD.scala +++ b/src/test/scala/dotvisualizer/GCD.scala @@ -30,7 +30,7 @@ class GCDTester extends FreeSpec with Matchers { "GCD circuit to visualize" in { val circuit = chisel3.Driver.elaborate(() => new GCD) val firrtl = chisel3.Driver.emit(circuit) - val config = Config(targetDir = "test_run_dir/gcd/", firrtlSource = firrtl) + val config = Config(targetDir = "test_run_dir/gcd/", firrtlSource = firrtl, useRanking = true) FirrtlDiagrammer.run(config) } }