Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New option to allow different layout ranking #13

Merged
merged 3 commits into from
Mar 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions src/main/scala/dotvisualizer/FirrtlDiagrammer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ package dotvisualizer

import java.io.{File, PrintWriter}

import chisel3.experimental
import chisel3.experimental.{ChiselAnnotation, RunFirrtlTransform}
import chisel3.internal.InstanceId
import chisel3.experimental.ChiselAnnotation
import dotvisualizer.transforms.{MakeDiagramGroup, ModuleLevelDiagrammer}
import firrtl._
import firrtl.annotations._
Expand All @@ -16,10 +14,6 @@ import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future, TimeoutException, blocking}
import scala.sys.process._

//TODO: MONICA: Implement depth
//TODO: Make input and output separate colors
//TODO: Make modules at different levels separate colors

//scalastyle:off magic.number
//scalastyle:off regex

Expand All @@ -40,6 +34,10 @@ case class SetOpenProgram(openProgram: String) extends OptionAnnotation

case class DotTimeOut(seconds: Int) extends OptionAnnotation

case class RankDirAnnotation(rankDir: String) extends OptionAnnotation

case object UseRankAnnotation extends OptionAnnotation

object FirrtlDiagrammer {

var dotTimeOut = 7
Expand Down Expand Up @@ -184,7 +182,7 @@ object FirrtlDiagrammer {
transform.execute(circuitState)
}

val fileName = s"${targetDir}${circuitState.circuit.main}_hierarchy.dot"
val fileName = s"$targetDir${circuitState.circuit.main}_hierarchy.dot"
val openProgram = controlAnnotations.collectFirst {
case SetOpenProgram(program) => program
}.getOrElse("open")
Expand Down Expand Up @@ -217,6 +215,14 @@ object FirrtlDiagrammer {
.action { (_, c) => c.copy(justTopLevel = true) }
.text("use this to only see the top level view")

opt[String]('o', "rank-dir")
.action { (x, c) => c.copy(rankDir = x) }
.text("use to set ranking direction, default is LR, TB is good alternative")

opt[Unit]('j', "rank-elements")
.action { (_, c) => c.copy(useRanking = true) }
.text("tries to rank elements by depth from inputs")

opt[Int]('s', "dot-timeout-seconds")
.action { (x, c) => c.copy(dotTimeOut = x) }
.text("use this to only see the top level view")
Expand All @@ -238,7 +244,9 @@ case class Config(
openProgram: String = Config.getOpenForOs,
targetDir: String = "",
justTopLevel: Boolean = false,
dotTimeOut: Int = 7
dotTimeOut: Int = 7,
useRanking: Boolean = false,
rankDir: String = "LR"
) {
def toAnnotations: Seq[Annotation] = {
val dir = {
Expand All @@ -254,9 +262,11 @@ case class Config(
SetRenderProgram(renderProgram),
SetOpenProgram(openProgram),
TargetDirAnnotation(dir),
DotTimeOut(dotTimeOut)
DotTimeOut(dotTimeOut),
RankDirAnnotation(rankDir)
) ++
(if(startModuleName.nonEmpty) Seq(StartModule(startModuleName)) else Seq.empty)
(if(startModuleName.nonEmpty) Seq(StartModule(startModuleName)) else Seq.empty) ++
(if(useRanking) Seq(UseRankAnnotation) else Seq.empty)
}
}

Expand Down
136 changes: 75 additions & 61 deletions src/main/scala/dotvisualizer/dotnodes/ModuleNode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

package dotvisualizer.dotnodes

import java.io.{File, PrintWriter}

import dotvisualizer.transforms.MakeOneDiagram
import firrtl.graph.DiGraph

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
Expand All @@ -16,9 +15,11 @@ case class ModuleNode(
subModuleDepth: Int = 0
) extends DotNode {

val inputs: ArrayBuffer[DotNode] = new ArrayBuffer()
val outputs: ArrayBuffer[DotNode] = new ArrayBuffer()
var renderWithRank: Boolean = false

val namedNodes: mutable.HashMap[String, DotNode] = new mutable.HashMap()
val subModuleNames: mutable.HashSet[String] = new mutable.HashSet[String]()

val connections: mutable.HashMap[String, String] = new mutable.HashMap()
private val analogConnections = new mutable.HashMap[String, ArrayBuffer[String]]() {
override def default(key: String): ArrayBuffer[String] = {
Expand All @@ -31,22 +32,87 @@ case class ModuleNode(
val backgroundColorIndex: Int = subModuleDepth % MakeOneDiagram.subModuleColorsByDepth.length
val backgroundColor: String = MakeOneDiagram.subModuleColorsByDepth(backgroundColorIndex)

//scalastyle:off method.length cyclomatic.complexity
def constructRankDirectives: String = {
val inputNames = children.collect { case p: PortNode if p.isInput => p }.map(_.absoluteName)
val outputPorts = children.collect { case p: PortNode if ! p.isInput => p }.map(_.absoluteName)

val diGraph = {
val linkedHashMap = new mutable.LinkedHashMap[String, mutable.LinkedHashSet[String]] {
override def default(key: String): mutable.LinkedHashSet[String] = {
this(key) = new mutable.LinkedHashSet[String]
this(key)
}
}

val connectionTargetNames = connections.values.map(_.split(":").head).toSet

connections.foreach { case (rhs, lhs) =>
val source = lhs.split(":").head
val target = rhs.split(":").head

if(target.nonEmpty && connectionTargetNames.contains(target)) {
linkedHashMap(source) += target
linkedHashMap(target)
}
}
DiGraph(linkedHashMap)
}

val sources = diGraph.findSources.filter(inputNames.contains).toSeq

def getRankedNodes: mutable.ArrayBuffer[Seq[String]] = {
val alreadyVisited = new mutable.HashSet[String]()
val rankNodes = new mutable.ArrayBuffer[Seq[String]]()

def walkByRank(nodes: Seq[String], rankNumber: Int = 0): Unit = {
rankNodes.append(nodes)

alreadyVisited ++= nodes

val nextNodes = nodes.flatMap { node =>
diGraph.getEdges(node)
}.filterNot(alreadyVisited.contains).distinct

if(nextNodes.nonEmpty) {
walkByRank(nextNodes, rankNumber + 1)
}
}

walkByRank(sources)
rankNodes
}

val rankedNodes = getRankedNodes

val rankInfo = rankedNodes.map {
nodesAtRank => s"""{ rank=same; ${nodesAtRank.mkString(" ")} };"""
}.mkString("", "\n ", "")

rankInfo + "\n " + s"""{ rank=same; ${outputPorts.mkString(" ")} };"""
}

/**
* Renders this node
* @return
*/
def render: String = {
def expandBiConnects(target: String, sources: ArrayBuffer[String]): String = {
sources.map { vv => s"""$vv -> $target [dir = "both"]""" }.mkString("\n")
}

val rankInfo = if(renderWithRank) constructRankDirectives else ""

val s = s"""
|subgraph $absoluteName {
| label="$name"
| URL="${url_string.getOrElse("")}"
| bgcolor="$backgroundColor"
| ${inputs.map(_.render).mkString("\n")}
| ${outputs.map(_.render).mkString("\n")}
| ${children.map(_.render).mkString("\n")}
|
| ${connections.map { case (k, v) => s"$v -> $k"}.mkString("\n")}
| ${analogConnections.map { case (k, v) => expandBiConnects(k, v) }.mkString("\n")}
| ${connections.map { case (k, v) => s"$v -> $k"}.mkString("", "\n ", "")}
| ${analogConnections.map { case (k, v) => expandBiConnects(k, v) }.mkString("", "\n ", "")}
| $rankInfo
|}
""".stripMargin
s
Expand Down Expand Up @@ -81,59 +147,7 @@ case class ModuleNode(

//scalastyle:off method.name
def += (childNode: DotNode): Unit = {
namedNodes(childNode.name) = childNode
namedNodes(childNode.absoluteName) = childNode
children += childNode
}
}

import scala.sys.process._

//noinspection ScalaStyle
object ModuleNode {
//noinspection ScalaStyle
def main(args: Array[String]): Unit = {
val topModule = ModuleNode("top", parentOpt = None)

val fox = LiteralNode("fox", BigInt(1), Some(topModule))
val dog = LiteralNode("dog", BigInt(5), Some(topModule))

val mux1 = MuxNode("mux1", Some(topModule))
val mux2 = MuxNode("mux2", Some(topModule))

topModule.inputs += PortNode("in1", Some(topModule))
topModule.inputs += PortNode("in2", Some(topModule))

val subModule = ModuleNode("child", Some(topModule))

subModule.inputs += PortNode("cluster_in1", Some(topModule))
subModule.inputs += PortNode("in2", Some(topModule))

topModule.children += fox
topModule.children += dog
topModule.children += mux1
topModule.children += mux2
topModule.children += subModule

topModule.localConnections(mux1.in1) = fox.absoluteName
topModule.localConnections(mux2.in1) = dog.absoluteName

topModule.localConnections(s"cluster_top_cluster_in1") = topModule.inputs(0).absoluteName

val writer = new PrintWriter(new File("module1.dot"))
writer.println(s"digraph structs {")
writer.println(s"graph [splines=ortho]")
writer.println(s"node [shape=plaintext]")
writer.println(topModule.render)


writer.println(s"}")

writer.close()

"fdp -Tpng -O module1.dot".!!
"open module1.dot.png".!!
}
}



24 changes: 15 additions & 9 deletions src/main/scala/dotvisualizer/transforms/MakeOneDiagram.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@ package dotvisualizer.transforms

import java.io.PrintWriter

import dotvisualizer._
import dotvisualizer.dotnodes._
import dotvisualizer.{FirrtlDiagrammer, Scope, StartModule}
import firrtl.PrimOps._
import firrtl.ir._
import firrtl.{CircuitForm, CircuitState, LowForm}
import firrtl.{Transform, WDefInstance, WRef, WSubField, WSubIndex}
import firrtl.{CircuitForm, CircuitState, LowForm, Transform, WDefInstance, WRef, WSubField, WSubIndex}

import scala.collection.mutable

Expand Down Expand Up @@ -39,6 +38,10 @@ class MakeOneDiagram extends Transform {
var linesPrintedSinceFlush = 0
var totalLinesPrinted = 0

val useRanking = state.annotations.collectFirst { case UseRankAnnotation => UseRankAnnotation}.isDefined

val rankDir = state.annotations.collectFirst { case RankDirAnnotation(dir) => dir}.getOrElse("LR")

val printFileName = s"$targetDir$startModuleName.dot"
println(s"creating dot file $printFileName")
val printFile = new PrintWriter(new java.io.File(printFileName))
Expand Down Expand Up @@ -250,7 +253,7 @@ class MakeOneDiagram extends Transform {
case port if port.direction == dir =>
val portNode = PortNode(
port.name, Some(moduleNode),
if(moduleNode.parentOpt.isEmpty) 0 else 1,
rank = if(port.direction == firrtl.ir.Input) 0 else 1000,
port.direction == firrtl.ir.Input
)
nameToNode(getFirrtlName(port.name)) = portNode
Expand Down Expand Up @@ -330,6 +333,8 @@ class MakeOneDiagram extends Transform {

processModule(newPrefix, subModule, subModuleNode, scope.descend, subModuleDepth + 1)

moduleNode.subModuleNames += subModuleNode.absoluteName

case DefNode(_, name, expression) if scope.doComponents() =>
val fName = getFirrtlName(name)
val nodeNode = NodeNode(name, Some(moduleNode))
Expand Down Expand Up @@ -367,14 +372,15 @@ class MakeOneDiagram extends Transform {
findModule(startModuleName, c) match {
case topModule: DefModule =>
pl(s"digraph ${topModule.name} {")
pl("""stylesheet = "styles.css"""")
pl("rankdir=\"LR\"")
// pl(s"graph [splines=ortho];")
pl(s"""stylesheet = "styles.css"""")
pl(s"""rankdir="$rankDir" """)
//TODO: make this an option -- pl(s"graph [splines=ortho];")
val topModuleNode = ModuleNode(startModuleName, parentOpt = None)
if(useRanking) topModuleNode.renderWithRank = true
processModule("", topModule, topModuleNode, Scope(0, 1))
// processModule("", topModule, topModuleNode, getScope(topModule.name))

pl(topModuleNode.render)
//pl("\"Modules Only View Here\" [URL=\"TopLevel.dot.svg\" shape=\"rectangle\"]; \n")

pl("}")
case _ =>
println(s"could not find top module $startModuleName")
Expand Down
4 changes: 3 additions & 1 deletion src/test/scala/dotvisualizer/FirExample.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ class FirExampleSpec extends FreeSpec with Matchers {
"""This is an example of an FIR circuit which has a lot of elements in a single module""" in {
val circuit = chisel3.Driver.elaborate(() => new MyManyDynamicElementVecFir(10))
val firrtl = chisel3.Driver.emit(circuit)
val config = Config(targetDir = "test_run_dir/fir_example/", firrtlSource = firrtl)
val config = Config(
targetDir = "test_run_dir/fir_example/", firrtlSource = firrtl, rankDir = "TB", useRanking = true
)
FirrtlDiagrammer.run(config)
}
}
2 changes: 1 addition & 1 deletion src/test/scala/dotvisualizer/GCD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class GCDTester extends FreeSpec with Matchers {
"GCD circuit to visualize" in {
val circuit = chisel3.Driver.elaborate(() => new GCD)
val firrtl = chisel3.Driver.emit(circuit)
val config = Config(targetDir = "test_run_dir/gcd/", firrtlSource = firrtl)
val config = Config(targetDir = "test_run_dir/gcd/", firrtlSource = firrtl, useRanking = true)
FirrtlDiagrammer.run(config)
}
}
Expand Down