diff --git a/.travis.yml b/.travis.yml
index 618f2f51..7287cbb3 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,6 +1,6 @@
language: scala
scala:
-- 2.11.8
+- 2.12.4
script: "./build.sh"
deploy:
- provider: bintray
diff --git a/bintray-release.json b/bintray-release.json
index 2aca41ad..f7db64a3 100644
--- a/bintray-release.json
+++ b/bintray-release.json
@@ -9,7 +9,7 @@
"released": "2017-11-23"
},
"files": [{
- "includePattern": "scala/target/scala-2.11/Quantomatic.jar",
+ "includePattern": "scala/target/scala-(.*)/Quantomatic.jar",
"uploadPattern": "release-0.6.1/Quantomatic.jar",
"matrixParams": {"override": 1}
}],
diff --git a/bintray.json b/bintray.json
index 8473b0ad..112b7a43 100644
--- a/bintray.json
+++ b/bintray.json
@@ -1,6 +1,6 @@
{
"package": {
- "name": "Quantomatic",
+ "name": "quantomatic",
"repo": "quantomatic",
"subject": "quantomatic"
},
@@ -8,7 +8,7 @@
"name": "bleeding-edge"
},
"files": [{
- "includePattern": "scala/target/scala-2.11/Quantomatic.jar",
+ "includePattern": "scala/target/scala-2.12/Quantomatic.jar",
"uploadPattern": "bleeding-edge/Quantomatic.jar",
"matrixParams": {"override": 1}
}],
diff --git a/docs/json_formats.txt b/docs/json_formats.txt
index 46427410..52d08fe7 100644
--- a/docs/json_formats.txt
+++ b/docs/json_formats.txt
@@ -134,6 +134,7 @@ VTYPE_DESC ::=
"style": {
"shape": "circle" | "rectangle" | "custom",
"custom_shape_path": JSON_STRING,
+ "stroke_width": JSON_INT,
"stroke_color": COLOR,
"fill_color": COLOR,
"label" : {
@@ -209,3 +210,40 @@ STEP ::=
}
+
+========================
+Simproc Batch Run
+========================
+
+File extension: .qsbr
+
+A simproc batch run takes a collection of simprocs and applies them to a list of graphs. The resulting derivations are then recorded and timestamped.
+Notes are optional notes added to the file by the job creator.
+"python" contains all the python source code of the simprocs in memory.
+"selected_simprocs" contains the names of just those simprocs used in this job.
+
+SIMPROC_BATCH_RUN ::=
+{
+ "python": {SIMPROC_NAME: PYTHON_STRING (, SIMPROC_NAME: PYTHON_STRING)*},
+ "selected_simprocs": Array[String],
+ "results": {DERIVATION_WITH_TIMINGS (, DERIVATION_WITH_TIMINGS)*},
+ "notes": String
+}
+
+DERIVATION_WITH_TIMINGS ::=
+{
+ "simproc" : STRING,
+ "derivation": DERIVATION,
+ "timings": {STEP_TIME(, STEP_TIME)*}
+}
+
+STEP_TIME ::=
+{
+ "step": STEP_NAME,
+ "time": TIMESTAMP
+}
+
+TIMESTAMP : NUMBER = time in seconds since process began
+
+
+
diff --git a/scala/build.sbt b/scala/build.sbt
index 35c2e624..293c0356 100644
--- a/scala/build.sbt
+++ b/scala/build.sbt
@@ -2,7 +2,7 @@ name := "quanto"
version := "1.0"
-scalaVersion := "2.11.8"
+scalaVersion := "2.12.6"
scalacOptions ++= Seq("-feature", "-language:implicitConversions")
@@ -14,19 +14,23 @@ fork := true
resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/"
-libraryDependencies += "com.typesafe.akka" %% "akka-actor" % "2.4.10" withSources() withJavadoc()
+libraryDependencies += "com.typesafe.akka" %% "akka-actor" % "2.5.12" withSources() withJavadoc()
-libraryDependencies += "com.fasterxml.jackson.core" % "jackson-core" % "2.1.2"
+libraryDependencies += "com.fasterxml.jackson.core" % "jackson-core" % "2.9.5"
-libraryDependencies += "com.fasterxml.jackson.module" % "jackson-module-scala" % "2.1.2"
+libraryDependencies += "com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.9.5"
-libraryDependencies += "org.scalatest" % "scalatest_2.11" % "2.2.1" % "test"
+//libraryDependencies += "org.scalatest" % "scalatest_2.11" % "2.2.1" % "test"
-libraryDependencies += "org.scala-lang.modules" % "scala-swing_2.11" % "1.0.1"
+//libraryDependencies += "org.scalactic" %% "scalactic" % "3.0.4"
+
+libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.5" % "test"
+
+libraryDependencies += "org.scala-lang.modules" %% "scala-swing" % "2.0.3"
//libraryDependencies += "org.scala-lang" % "scala-compiler" % scalaVersion.value
-libraryDependencies += "org.scala-lang.modules" %% "scala-parser-combinators" % "1.0.5"
+libraryDependencies += "org.scala-lang.modules" %% "scala-parser-combinators" % "1.1.0"
//EclipseKeys.withSource := true
@@ -36,7 +40,7 @@ seq(appbundle.settings: _*)
appbundle.mainClass := Some("quanto.gui.QuantoDerive")
-appbundle.javaVersion := "1.7+"
+appbundle.javaVersion := "1.8+"
appbundle.screenMenu := true
@@ -46,7 +50,7 @@ appbundle.normalizedName := "quantoderiveapp"
appbundle.organization := "org.quantomatic"
-appbundle.version := "0.2.0"
+appbundle.version := "0.3.0"
appbundle.icon := Some(file("../docs/graphics/quantoderive.icns"))
diff --git a/scala/dist/mk-linux-generic.sh b/scala/dist/mk-linux-generic.sh
index 41730b26..a29fcfa8 100755
--- a/scala/dist/mk-linux-generic.sh
+++ b/scala/dist/mk-linux-generic.sh
@@ -5,8 +5,14 @@
BUNDLE=target/QuantoDerive
+# Pre-build cleanup, so we know we're building in a consistent environment.
+sbt clean
+rm -r ../core/heaps/*
+rm -r $BUNDLE/*
+
# Rebuild the core heap
echo Rebuilding the core heap...
+mkdir -p ../core/heaps
(cd ../core; ../scala/dist/linux-dist/poly --use build_heap.ML)
@@ -22,7 +28,8 @@ sbt package
echo Including binaries...
cp -f dist/linux-dist/quanto-derive.sh $BUNDLE/
-cp -f dist/linux-dist/polybin dist/linux-dist/poly $BUNDLE/bin
+cp -f dist/linux-dist/polybin $BUNDLE/bin
+cp -f dist/linux-dist/poly $BUNDLE/bin
cp -f dist/linux-dist/libpolyml.so.4 $BUNDLE/bin
echo Including heap...
@@ -36,7 +43,8 @@ echo Including jars...
cp -f lib_managed/jars/*/*/akka-actor*.jar $BUNDLE/jars
cp -f lib_managed/jars/*/*/scala-library*.jar $BUNDLE/jars
cp -f lib_managed/jars/*/*/scala-swing*.jar $BUNDLE/jars
-cp -f lib_managed/jars/*/*/jackson-core*.jar $BUNDLE/jars
+cp -f lib_managed/bundles/*/*/scala-parser-combinators*.jar $BUNDLE/jars
+cp -f lib_managed/bundles/*/*/jackson-core*.jar $BUNDLE/jars
cp -f lib_managed/bundles/*/*/config*.jar $BUNDLE/jars
# grab local dependences
diff --git a/scala/lib/jedit-textArea.jar b/scala/lib/jedit-textArea.jar
index d3c087a1..0bcb9541 100644
Binary files a/scala/lib/jedit-textArea.jar and b/scala/lib/jedit-textArea.jar differ
diff --git a/scala/src/main/java/quanto/core/data/TexConstants.java b/scala/src/main/java/quanto/core/data/TexConstants.java
index da5a92c5..6f63d6f9 100644
--- a/scala/src/main/java/quanto/core/data/TexConstants.java
+++ b/scala/src/main/java/quanto/core/data/TexConstants.java
@@ -48,6 +48,8 @@ private static void initialize() {
c.put("Phi", "\u03a6");
c.put("Psi", "\u03a8");
c.put("Omega", "\u03a9");
+ c.put("True", "\u22A4");
+ c.put("False", "\u22A5");
constants = c;
}
diff --git a/scala/src/main/java/quanto/util/json/Json.scala b/scala/src/main/java/quanto/util/json/Json.scala
index 5d2306d9..a17836a2 100644
--- a/scala/src/main/java/quanto/util/json/Json.scala
+++ b/scala/src/main/java/quanto/util/json/Json.scala
@@ -277,7 +277,7 @@ object Json {
def this(out: java.io.OutputStream) = this(factory.createJsonGenerator(out, JsonEncoding.UTF8))
def this(out: java.io.Writer) = this(factory.createJsonGenerator(out))
- def this(f: java.io.File) = this(factory.createJsonGenerator(f, JsonEncoding.UTF8))
+ def this(f: java.io.File) = this(factory.createJsonGenerator(new java.io.FileOutputStream(f), JsonEncoding.UTF8))
private var _pp = false
def prettyPrint_=(b: Boolean) {
diff --git a/scala/src/main/resources/quanto/data/ZW.qtheory b/scala/src/main/resources/quanto/data/ZW.qtheory
new file mode 100644
index 00000000..3d6ced8d
--- /dev/null
+++ b/scala/src/main/resources/quanto/data/ZW.qtheory
@@ -0,0 +1,74 @@
+{
+ "name" : "ZW",
+ "core_name" : "zw",
+ "vertex_types" : {
+ "Z" : {
+ "value" : {
+ "type" : "string",
+ "latex_constants" : true,
+ "validate_with_core" : true
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 0.0, 0.0, 0.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 1.0, 1.0, 1.0 ],
+ "shape" : "circle"
+ },
+ "default_data" : {
+ "type" : "Z",
+ "label" : "",
+ "value" : {
+ "pretty" : ""
+ }
+ }
+ },
+ "W" : {
+ "value" : {
+ "type" : "string",
+ "latex_constants" : true,
+ "validate_with_core" : true
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 0.0, 0.0, 0.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 0.0, 0.0, 0.0 ],
+ "shape" : "circle"
+ },
+ "default_data" : {
+ "type" : "W",
+ "label" : "",
+ "value" : {
+ "pretty" : ""
+ }
+ }
+ }
+ },
+ "default_vertex_type" : "Z",
+ "default_edge_type" : "plain",
+ "edge_types" : {
+ "plain" : {
+ "value" : {
+ "type" : "empty",
+ "latex_constants" : false,
+ "validate_with_core" : false
+ },
+ "style" : {
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "stroke_width" : 1,
+ "label" : {
+ "position" : "auto",
+ "fg_color" : [ 0.0, 0.0, 0.0 ]
+ }
+ },
+ "default_data" : {
+ "type" : "plain"
+ }
+ }
+ }
+}
diff --git a/scala/src/main/resources/quanto/data/ZX.qtheory b/scala/src/main/resources/quanto/data/ZX.qtheory
new file mode 100644
index 00000000..eb92d16f
--- /dev/null
+++ b/scala/src/main/resources/quanto/data/ZX.qtheory
@@ -0,0 +1,110 @@
+{
+ "name" : "Red/green theory",
+ "core_name" : "red_green",
+ "vertex_types" : {
+ "X" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : true,
+ "type" : "angle_expr"
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 1.0, 1.0, 1.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 1.0, 0.0, 0.0 ],
+ "shape" : "circle"
+ },
+ "default_data" : {
+ "type" : "X",
+ "value" : ""
+ }
+ },
+ "Z" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : true,
+ "type" : "angle_expr"
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 0.0, 0.0, 0.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 0.0, 0.8, 0.0 ],
+ "shape" : "circle"
+ },
+ "default_data" : {
+ "type" : "Z",
+ "value" : ""
+ }
+ },
+ "hadamard" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : false,
+ "type" : "string"
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 0.0, 0.2, 0.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 1.0, 1.0, 0.0 ],
+ "shape" : "rectangle"
+ },
+ "default_data" : {
+ "type" : "hadamard",
+ "value" : ""
+ }
+ },
+ "var" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : false,
+ "type" : "string"
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 0.0, 0.0, 0.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 0.6, 1.0, 0.8 ],
+ "shape" : "rectangle"
+ },
+ "default_data" : {
+ "type" : "var",
+ "value" : ""
+ }
+ }
+ },
+ "edge_types" : {
+ "string" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : false,
+ "type" : "string"
+ },
+ "style" : {
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "stroke_width" : 1,
+ "label" : {
+ "position" : "center",
+ "fg_color" : [ 0.0, 0.0, 1.0 ],
+ "bg_color" : [ 0.8, 0.8, 1.0, 0.7 ]
+ }
+ },
+ "default_data" : {
+ "type" : "string",
+ "value" : ""
+ }
+ }
+ },
+ "default_vertex_type" : "Z",
+ "default_edge_type" : "string"
+}
\ No newline at end of file
diff --git a/scala/src/main/resources/quanto/data/ZXRails.qtheory b/scala/src/main/resources/quanto/data/ZXRails.qtheory
new file mode 100644
index 00000000..4e17c08e
--- /dev/null
+++ b/scala/src/main/resources/quanto/data/ZXRails.qtheory
@@ -0,0 +1,130 @@
+{
+ "name" : "Red/green with rails theory",
+ "core_name" : "red_green_rails",
+ "vertex_types" : {
+ "X" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : true,
+ "type" : "angle_expr"
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 1.0, 1.0, 1.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 1.0, 0.0, 0.0 ],
+ "shape" : "circle"
+ },
+ "default_data" : {
+ "type" : "X",
+ "value" : ""
+ }
+ },
+ "Z" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : true,
+ "type" : "angle_expr"
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 0.0, 0.0, 0.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 0.0, 0.8, 0.0 ],
+ "shape" : "circle"
+ },
+ "default_data" : {
+ "type" : "Z",
+ "value" : ""
+ }
+ },
+ "hadamard" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : false,
+ "type" : "string"
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 0.0, 0.2, 0.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 1.0, 1.0, 0.0 ],
+ "shape" : "rectangle"
+ },
+ "default_data" : {
+ "type" : "hadamard",
+ "value" : ""
+ }
+ },
+ "var" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : false,
+ "type" : "string"
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 0.0, 0.0, 0.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 0.6, 1.0, 0.8 ],
+ "shape" : "rectangle"
+ },
+ "default_data" : {
+ "type" : "var",
+ "value" : ""
+ }
+ }
+ },
+ "edge_types" : {
+ "string" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : false,
+ "type" : "string"
+ },
+ "style" : {
+ "stroke_color" : [ 0.5, 0.5, 0.5 ],
+ "stroke_width" : 1,
+ "label" : {
+ "position" : "center",
+ "fg_color" : [ 0.0, 0.0, 1.0 ],
+ "bg_color" : [ 0.8, 0.8, 1.0, 0.7 ]
+ }
+ },
+ "default_data" : {
+ "type" : "string",
+ "value" : ""
+ }
+ },
+ "rail" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : false,
+ "type" : "string"
+ },
+ "style" : {
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "stroke_width" : 1,
+ "label" : {
+ "position" : "center",
+ "fg_color" : [ 0.0, 0.0, 1.0 ],
+ "bg_color" : [ 0.8, 0.8, 1.0, 0.7 ]
+ }
+ },
+ "default_data" : {
+ "type" : "rail",
+ "value" : ""
+ }
+ }
+ },
+ "default_vertex_type" : "Z",
+ "default_edge_type" : "rail"
+}
\ No newline at end of file
diff --git a/scala/src/main/resources/quanto/data/composite.qtheory b/scala/src/main/resources/quanto/data/composite.qtheory
new file mode 100644
index 00000000..1f5ae40e
--- /dev/null
+++ b/scala/src/main/resources/quanto/data/composite.qtheory
@@ -0,0 +1,110 @@
+{
+ "name" : "ZX with boolean data",
+ "core_name" : "red_green",
+ "vertex_types" : {
+ "X" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : true,
+ "type" : "angle_expr, boolean"
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 1.0, 1.0, 1.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 1.0, 0.0, 0.0 ],
+ "shape" : "circle"
+ },
+ "default_data" : {
+ "type" : "X",
+ "value" : ""
+ }
+ },
+ "Z" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : true,
+ "type" : "angle_expr, boolean"
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 0.0, 0.0, 0.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 0.0, 0.8, 0.0 ],
+ "shape" : "circle"
+ },
+ "default_data" : {
+ "type" : "Z",
+ "value" : ""
+ }
+ },
+ "hadamard" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : false,
+ "type" : "string"
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 0.0, 0.2, 0.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 1.0, 1.0, 0.0 ],
+ "shape" : "rectangle"
+ },
+ "default_data" : {
+ "type" : "hadamard",
+ "value" : ""
+ }
+ },
+ "var" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : false,
+ "type" : "string"
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 0.0, 0.0, 0.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 0.6, 1.0, 0.8 ],
+ "shape" : "rectangle"
+ },
+ "default_data" : {
+ "type" : "var",
+ "value" : ""
+ }
+ }
+ },
+ "edge_types" : {
+ "string" : {
+ "value" : {
+ "path" : "$.value",
+ "latex_constants" : false,
+ "type" : "string"
+ },
+ "style" : {
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "stroke_width" : 1,
+ "label" : {
+ "position" : "center",
+ "fg_color" : [ 0.0, 0.0, 1.0 ],
+ "bg_color" : [ 0.8, 0.8, 1.0, 0.7 ]
+ }
+ },
+ "default_data" : {
+ "type" : "string",
+ "value" : ""
+ }
+ }
+ },
+ "default_vertex_type" : "Z",
+ "default_edge_type" : "string"
+}
diff --git a/scala/src/main/resources/quanto/data/plain.qtheory b/scala/src/main/resources/quanto/data/plain.qtheory
new file mode 100644
index 00000000..b699a4b3
--- /dev/null
+++ b/scala/src/main/resources/quanto/data/plain.qtheory
@@ -0,0 +1,51 @@
+{
+ "name" : "Plain",
+ "core_name" : "plain",
+ "vertex_types" : {
+ "var" : {
+ "value" : {
+ "type" : "string",
+ "latex_constants" : true,
+ "validate_with_core" : true
+ },
+ "style" : {
+ "label" : {
+ "position" : "inside",
+ "fg_color" : [ 0.0, 0.0, 0.0 ]
+ },
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "fill_color" : [ 1.0, 1.0, 1.0 ],
+ "shape" : "rectangle"
+ },
+ "default_data" : {
+ "type" : "var",
+ "label" : "",
+ "value" : {
+ "pretty" : ""
+ }
+ }
+ }
+ },
+ "default_vertex_type" : "var",
+ "default_edge_type" : "plain",
+ "edge_types" : {
+ "plain" : {
+ "value" : {
+ "type" : "empty",
+ "latex_constants" : false,
+ "validate_with_core" : false
+ },
+ "style" : {
+ "stroke_color" : [ 0.0, 0.0, 0.0 ],
+ "stroke_width" : 1,
+ "label" : {
+ "position" : "auto",
+ "fg_color" : [ 0.0, 0.0, 0.0 ]
+ }
+ },
+ "default_data" : {
+ "type" : "plain"
+ }
+ }
+ }
+}
diff --git a/scala/src/main/resources/quanto/gui/Mac_OS_X_keys.props b/scala/src/main/resources/quanto/gui/Mac_OS_X_keys.props
new file mode 100644
index 00000000..2eeb3908
--- /dev/null
+++ b/scala/src/main/resources/quanto/gui/Mac_OS_X_keys.props
@@ -0,0 +1,194 @@
+action-bar.shortcut=C+ENTER
+add-explicit-fold.shortcut=MC+a
+add-marker-shortcut.shortcut=C+t
+add-marker.shortcut=C+m
+backspace-word-std.shortcut2=A+k
+backspace-word.shortcut=C+BACK_SPACE
+backspace.shortcut=BACK_SPACE
+bottom-docking-area.shortcut=MC+DOWN
+center-caret.shortcut=AC+n
+clear-register.shortcut=C+r C+l
+close-all.shortcut=AMC+w
+close-buffer.shortcut=C+w
+close-docking-area.shortcut=AC+BACK_QUOTE
+closeall-bufferset.shortcut=AC+w
+closeall-except-active.shortcut=CS+w
+collapse-all-folds.shortcut=AS+BACK_SPACE
+collapse-fold.shortcut=A+BACK_SPACE
+complete-word.shortcut=C+b
+copy-append-string-register.shortcut=MC+c
+copy-append.shortcut=CS+c
+copy-string-register.shortcut=AC+c
+copy.shortcut2=C+INSERT
+copy.shortcut=C+c
+cut-append-string-register.shortcut=MC+x
+cut-append.shortcut=CS+x
+cut-string-register.shortcut=AC+x
+cut.shortcut2=S+DELETE
+cut.shortcut=C+x
+delete-end-line.shortcut=CS+DELETE
+delete-line.shortcut=C+d
+delete-paragraph.shortcut2=A+h
+delete-start-line.shortcut=CS+BACK_SPACE
+delete-word.shortcut=C+DELETE
+delete.shortcut2=A+d
+delete.shortcut=DELETE
+document-end.shortcut2=C+DOWN
+document-end.shortcut=C+END
+document-home.shortcut2=C+UP
+document-home.shortcut=C+HOME
+end.shortcut2=C+RIGHT
+end.shortcut=END
+exit.shortcut=C+q
+expand-abbrev.shortcut=C+SEMICOLON
+expand-fold.shortcut=AS+ENTER
+expand-folds.shortcut=AC+ENTER
+expand-one-level.shortcut=A+ENTER
+find-next.shortcut=C+g
+find-prev.shortcut=CS+g
+find-previous.shortcut=C+e g
+find.shortcut=C+f
+focus-buffer-switcher.shortcut=A+BACK_QUOTE
+global-close-buffer.shortcut=MC+w
+global-options.shortcut=C+F12
+goto-line.shortcut=A+g
+goto-marker.shortcut=C+y
+help.shortcut=F1
+home.shortcut2=C+LEFT
+home.shortcut=HOME
+hypersearch-word.shortcut=A+PERIOD
+hypersearch.shortcut2=MC+f
+hypersearch.shortcut=C+PERIOD
+ignore-case.shortcut=AC+i
+indent-lines.shortcut=A+i
+insert-literal.shortcut=MC+v
+insert-newline-indent.shortcut=ENTER
+insert-tab-indent.shortcut=TAB
+invert-selection.shortcut=MC+i
+last-action.shortcut=AC+SPACE
+last-macro.shortcut=C+m C+l
+latextools-compile.shortcut=F6
+left-docking-area.shortcut=MC+LEFT
+line-comment.shortcut=AC+SLASH
+line-end.shortcut=A+e
+line-home.shortcut=A+a
+new-file-in-mode.shortcut=CS+n
+new-file.shortcut=C+n
+next-bracket.shortcut=AC+CLOSE_BRACKET
+next-buffer.shortcut2=A+f
+next-buffer.shortcut=C+PAGE_DOWN
+next-char.shortcut=RIGHT
+next-fold.shortcut=A+DOWN
+next-line.shortcut=DOWN
+next-marker.shortcut=AC+PERIOD
+next-page.shortcut2=A+v
+next-page.shortcut=PAGE_DOWN
+next-paragraph.shortcut=M+DOWN
+next-textarea.shortcut=A+PAGE_DOWN
+next-word.shortcut=M+RIGHT
+open-file.shortcut=C+o
+open-path.shortcut=C+e C+o
+overwrite.shortcut=INSERT
+page-setup.shortcut=CS+p
+parent-fold.shortcut=MC+u
+paste-deleted.shortcut=AC+y
+paste-previous.shortcut=CS+v
+paste-string-register.shortcut=AC+v
+paste.shortcut2=S+INSERT
+paste.shortcut=C+v
+prev-bracket.shortcut=AC+OPEN_BRACKET
+prev-buffer.shortcut=C+PAGE_UP
+prev-char.shortcut2=A+b
+prev-char.shortcut=LEFT
+prev-fold.shortcut=A+UP
+prev-line.shortcut=UP
+prev-marker.shortcut=AC+COMMA
+prev-page.shortcut=PAGE_UP
+prev-paragraph.shortcut=M+UP
+prev-textarea.shortcut=A+PAGE_UP
+prev-word.shortcut=M+LEFT
+print.shortcut=C+p
+quick-search-word.shortcut=A+COMMA
+quick-search.shortcut=C+COMMA
+range-comment.shortcut=C+SLASH
+recent-buffer.shortcut=C+BACK_QUOTE
+record-macro.shortcut=AM+r
+record-temp-macro.shortcut=AM+m
+redo.shortcut=CS+z
+reload.shortcut=F5
+remove-trailing-ws.shortcut=MC+r
+replace-and-find-next.shortcut=CS+f
+replace-in-selection.shortcut=CS+r
+resplit.shortcut=C+4
+right-docking-area.shortcut=MC+RIGHT
+run-temp-macro.shortcut=AM+p
+save-all.shortcut=AC+s
+save-as.shortcut=CS+s
+save.shortcut=C+s
+scroll-and-center.shortcut=C+l
+scroll-down-page.shortcut=A+SLASH
+scroll-to-current-line.shortcut2=A+l
+scroll-to-current-line.shortcut=C+j
+scroll-up-line.shortcut=C+QUOTE
+scroll-up-page.shortcut=A+QUOTE
+search-in-directory.shortcut=CS+d
+search-in-open-buffers.shortcut=CS+b
+select-all.shortcut=C+a
+select-document-end.shortcut=CS+END
+select-document-home.shortcut=CS+HOME
+select-end.shortcut2=CS+RIGHT
+select-end.shortcut=S+END
+select-fold.shortcut=MC+s
+select-home.shortcut2=CS+LEFT
+select-home.shortcut=S+HOME
+select-line-range.shortcut=AC+l
+select-line.shortcut=MC+l
+select-marker.shortcut=C+u
+select-next-char.shortcut2=AS+l
+select-next-char.shortcut=S+RIGHT
+select-next-line.shortcut2=AS+k
+select-next-line.shortcut=S+DOWN
+select-next-page.shortcut2=AS+a
+select-next-page.shortcut=S+PAGE_DOWN
+select-next-paragraph.shortcut=MS+DOWN
+select-next-word.shortcut=MS+RIGHT
+select-none.shortcut=CS+a
+select-paragraph.shortcut=MC+p
+select-prev-char.shortcut2=AS+j
+select-prev-char.shortcut=S+LEFT
+select-prev-line.shortcut2=AS+i
+select-prev-line.shortcut=S+UP
+select-prev-page.shortcut2=AS+q
+select-prev-page.shortcut=S+PAGE_UP
+select-prev-paragraph.shortcut=MS+UP
+select-prev-word.shortcut=MS+LEFT
+shift-left.shortcut2=S+TAB
+shift-left.shortcut=C+OPEN_BRACKET
+shift-right.shortcut=C+CLOSE_BRACKET
+show-context-menu.shortcut=CONTEXT_MENU
+split-horizontal.shortcut=C+2
+split-vertical.shortcut=C+3
+stop-recording.shortcut=AM+s
+swap-marker.shortcut=C+k
+toggle-dock-areas.shortcut=F12
+toggle-full-screen.shortcut=F11
+toggle-line-numbers.shortcut=AC+t
+toggle-multi-select.shortcut=C+BACK_SLASH
+toggle-rect-select.shortcut=A+BACK_SLASH
+top-docking-area.shortcut=MC+UP
+undo.shortcut=C+z
+unsplit-current.shortcut=C+0
+unsplit.shortcut=C+1
+vertical-paste-string-register.shortcut=AM+v
+vertical-paste.shortcut=AC+p
+vfs.browser.delete.shortcut=DELETE
+vfs.browser.home.shortcut=~
+vfs.browser.new-directory.shortcut=INSERT
+vfs.browser.new-file.shortcut=C+n
+vfs.browser.next.shortcut=A+Right
+vfs.browser.previous.shortcut=A+Left
+vfs.browser.reload.shortcut=F5
+vfs.browser.rename.shortcut=F2
+vfs.browser.roots.shortcut=/
+vfs.browser.synchronize.shortcut=-
+vfs.browser.up.shortcut=A+Up
diff --git a/scala/src/main/resources/quanto/gui/add-edge.png b/scala/src/main/resources/quanto/gui/add-edge.png
new file mode 100644
index 00000000..886818b1
Binary files /dev/null and b/scala/src/main/resources/quanto/gui/add-edge.png differ
diff --git a/scala/src/main/resources/quanto/gui/blank.xml b/scala/src/main/resources/quanto/gui/blank.xml
new file mode 100644
index 00000000..ca2c8ee1
--- /dev/null
+++ b/scala/src/main/resources/quanto/gui/blank.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
diff --git a/scala/src/main/resources/quanto/gui/expand.png b/scala/src/main/resources/quanto/gui/expand.png
new file mode 100644
index 00000000..15c6bbe0
Binary files /dev/null and b/scala/src/main/resources/quanto/gui/expand.png differ
diff --git a/scala/src/main/resources/quanto/gui/focus.png b/scala/src/main/resources/quanto/gui/focus.png
new file mode 100644
index 00000000..42ae843f
Binary files /dev/null and b/scala/src/main/resources/quanto/gui/focus.png differ
diff --git a/scala/src/main/resources/quanto/gui/generic-change.svg b/scala/src/main/resources/quanto/gui/generic-change.svg
new file mode 100644
index 00000000..1a49be36
--- /dev/null
+++ b/scala/src/main/resources/quanto/gui/generic-change.svg
@@ -0,0 +1,113 @@
+
+
+
+
diff --git a/scala/src/main/resources/quanto/gui/jEdit_keys.props b/scala/src/main/resources/quanto/gui/jEdit_keys.props
new file mode 100644
index 00000000..7b736f29
--- /dev/null
+++ b/scala/src/main/resources/quanto/gui/jEdit_keys.props
@@ -0,0 +1,231 @@
+#{{{ Function keys
+help.shortcut=F1
+show-context-menu.shortcut=CONTEXT_MENU
+#}}}
+
+#{{{ C+X
+select-all.shortcut=C+a
+complete-word.shortcut=C+b
+copy.shortcut=C+c
+delete-line.shortcut=C+d
+# C+e is a prefix
+find.shortcut=C+f
+find-next.shortcut=F3
+# C+h is not usable on MacOS X
+indent-lines.shortcut=C+i
+join-lines.shortcut=C+j
+swap-marker.shortcut=C+k
+goto-line.shortcut=C+g
+# C+m is a prefix
+new-file.shortcut=C+n
+new-file-in-mode.shortcut=CS+n
+open-file.shortcut=C+o
+reload.shortcut=F5
+print.shortcut=C+p
+exit.shortcut=C+q
+# C+r is a prefix
+save.shortcut=C+s
+add-marker-shortcut.shortcut=C+t
+select-marker.shortcut=C+u
+paste.shortcut=C+v
+close-buffer.shortcut=C+w
+cut.shortcut=C+x
+goto-marker.shortcut=C+y
+undo.shortcut=C+z
+unsplit-current.shortcut=C+0
+unsplit.shortcut=C+1
+split-horizontal.shortcut=C+2
+split-vertical.shortcut=C+3
+resplit.shortcut=C+4
+#}}}
+
+#{{{ C+non-alpha
+delete-start-line.shortcut=CS+BACK_SPACE
+delete-end-line.shortcut=CS+DELETE
+prev-paragraph.shortcut=C+UP
+next-paragraph.shortcut=C+DOWN
+select-prev-paragraph.shortcut=CS+UP
+select-next-paragraph.shortcut=CS+DOWN
+backspace-word.shortcut=C+BACK_SPACE
+delete-word.shortcut=C+DELETE
+document-home.shortcut=C+HOME
+document-end.shortcut=C+END
+select-document-home.shortcut=CS+HOME
+select-document-end.shortcut=CS+END
+prev-word.shortcut=C+LEFT
+select-prev-word.shortcut=CS+LEFT
+next-word.shortcut=C+RIGHT
+select-next-word.shortcut=CS+RIGHT
+action-bar.shortcut=C+ENTER
+prev-buffer.shortcut=C+PAGE_UP
+next-buffer.shortcut=C+PAGE_DOWN
+last-action.shortcut=C+SPACE
+
+recent-buffer.shortcut=C+BACK_QUOTE
+select-block.shortcut=C+OPEN_BRACKET
+match-bracket.shortcut=C+CLOSE_BRACKET
+expand-abbrev.shortcut=C+SEMICOLON
+quick-search.shortcut=C+COMMA
+hypersearch.shortcut=C+PERIOD
+scroll-up-line.shortcut=C+QUOTE
+scroll-down-line.shortcut=C+SLASH
+toggle-multi-select.shortcut=C+BACK_SLASH
+#}}}
+
+#{{{ C+e C+X
+# Unused: f, h, q, y
+copy-append.shortcut=C+e C+a
+search-in-open-buffers.shortcut=C+e C+b
+range-comment.shortcut=C+e C+c
+search-in-directory.shortcut=C+e C+d
+replace-and-find-next.shortcut=C+e C+g
+ignore-case.shortcut=C+e C+i
+scroll-to-current-line.shortcut=C+e C+j
+line-comment.shortcut=C+e C+k
+select-line-range.shortcut=C+e C+l
+add-marker.shortcut=C+e C+m
+center-caret.shortcut=C+e C+n
+scroll-and-center.shortcut=C+l
+open-path.shortcut=C+e C+o
+vertical-paste.shortcut=C+e C+p
+replace-in-selection.shortcut=C+e C+r
+save-all.shortcut=C+e C+s
+toggle-line-numbers.shortcut=C+e C+t
+cut-append.shortcut=C+e C+u
+paste-previous.shortcut=C+e C+v
+close-all.shortcut=C+e C+w
+regexp.shortcut=C+e C+x
+paste-deleted.shortcut=C+e C+y
+redo.shortcut=C+e C+z
+#}}}
+
+#{{{ C+e C+non-alpha
+left-docking-area.shortcut=C+e C+LEFT
+top-docking-area.shortcut=C+e C+UP
+right-docking-area.shortcut=C+e C+RIGHT
+bottom-docking-area.shortcut=C+e C+DOWN
+toggle-full-screen.shortcut=F11
+toggle-dock-areas.shortcut=F12
+combined-options.shortcut=C+F12
+prev-marker.shortcut=C+e C+COMMA
+next-marker.shortcut=C+e C+PERIOD
+prev-bracket.shortcut=C+e C+OPEN_BRACKET
+next-bracket.shortcut=C+e C+CLOSE_BRACKET
+close-docking-area.shortcut=C+e C+BACK_QUOTE
+#}}}
+
+#{{{ C+e X
+# Unused: b e g h j k m o q t y z
+add-explicit-fold.shortcut=C+e a
+collapse-all-folds.shortcut=C+e c
+delete-paragraph.shortcut=C+e d
+format-paragraph.shortcut=C+e f
+find-previous.shortcut=C+e g
+invert-selection.shortcut=C+e i
+select-line.shortcut=C+e l
+narrow-to-fold.shortcut=C+e n n
+narrow-to-selection.shortcut=C+e n s
+select-paragraph.shortcut=C+e p
+remove-trailing-ws.shortcut=C+e r
+select-fold.shortcut=C+e s
+insert-literal.shortcut=C+e v
+select-word.shortcut=C+e w
+parent-fold.shortcut=C+e u
+expand-all-folds.shortcut=C+e x
+#}}}
+
+#{{{ C+e non-alpha
+expand-folds.shortcut=C+e ENTER
+#}}}
+
+#{{{ C+m C+X
+record-temp-macro.shortcut=C+m C+m
+run-temp-macro.shortcut=C+m C+p
+record-macro.shortcut=C+m C+r
+stop-recording.shortcut=C+m C+s
+last-macro.shortcut=C+m C+l
+#}}}
+
+#{{{ C+r C+X
+copy-append-string-register.shortcut=C+r C+a
+copy-string-register.shortcut=C+r C+c
+clear-register.shortcut=C+r C+l
+vertical-paste-string-register.shortcut=C+r C+p
+cut-append-string-register.shortcut=C+r C+u
+paste-string-register.shortcut=C+r C+v
+cut-string-register.shortcut=C+r C+x
+#}}}
+
+#{{{ A+non-alpha
+prev-fold.shortcut=A+UP
+next-fold.shortcut=A+DOWN
+shift-left.shortcut=A+LEFT
+shift-right.shortcut=A+RIGHT
+collapse-fold.shortcut=A+BACK_SPACE
+expand-fold.shortcut=AS+ENTER
+expand-one-level.shortcut=A+ENTER
+quick-search-word.shortcut=A+COMMA
+hypersearch-word.shortcut=A+PERIOD
+scroll-up-page.shortcut=A+QUOTE
+scroll-down-page.shortcut=A+SLASH
+prev-textarea.shortcut=A+PAGE_UP
+next-textarea.shortcut=A+PAGE_DOWN
+
+focus-buffer-switcher.shortcut=A+BACK_QUOTE
+toggle-rect-select.shortcut=A+BACK_SLASH
+#}}}
+
+#{{{ Other keys
+shift-left.shortcut2=S+TAB
+select-none.shortcut=ESCAPE
+backspace.shortcut=BACK_SPACE
+delete.shortcut=DELETE
+overwrite.shortcut=INSERT
+home.shortcut=HOME
+end.shortcut=END
+select-home.shortcut=S+HOME
+select-end.shortcut=S+END
+prev-page.shortcut=PAGE_UP
+next-page.shortcut=PAGE_DOWN
+select-prev-page.shortcut=S+PAGE_UP
+select-next-page.shortcut=S+PAGE_DOWN
+prev-char.shortcut=LEFT
+select-prev-char.shortcut=S+LEFT
+next-char.shortcut=RIGHT
+select-next-char.shortcut=S+RIGHT
+prev-line.shortcut=UP
+select-prev-line.shortcut=S+UP
+next-line.shortcut=DOWN
+select-next-line.shortcut=S+DOWN
+insert-newline-indent.shortcut=ENTER
+insert-newline.shortcut=S+ENTER
+insert-tab-indent.shortcut=TAB
+#}}}
+
+#{{{ Alternative shortcuts for frequently-used commands
+select-next-page.shortcut2=AS+a
+select-prev-line.shortcut2=AS+i
+select-prev-char.shortcut2=AS+j
+select-next-line.shortcut2=AS+k
+select-next-char.shortcut2=AS+l
+select-prev-page.shortcut2=AS+q
+select-end.shortcut2=AS+x
+select-home.shortcut2=AS+z
+copy.shortcut2=C+INSERT
+paste.shortcut2=S+INSERT
+cut.shortcut2=S+DELETE
+#}}}
+
+#{{{ VFS browser
+vfs.browser.delete.shortcut=DELETE
+vfs.browser.home.shortcut=~
+vfs.browser.new-directory.shortcut=INSERT
+vfs.browser.new-file.shortcut=C+n
+vfs.browser.reload.shortcut=F5
+vfs.browser.rename.shortcut=F2
+vfs.browser.roots.shortcut=/
+vfs.browser.synchronize.shortcut=-
+vfs.browser.up.shortcut=A+Up
+vfs.browser.previous.shortcut=A+Left
+vfs.browser.next.shortcut=A+Right
+#}}}
diff --git a/scala/src/main/resources/quanto/gui/jedit.props b/scala/src/main/resources/quanto/gui/jedit.props
new file mode 100644
index 00000000..2df4b86d
--- /dev/null
+++ b/scala/src/main/resources/quanto/gui/jedit.props
@@ -0,0 +1,536 @@
+###
+### jEdit global properties
+### :tabSize=4:indentSize=4:noTabs=false:
+### :folding=explicit:collapseFolds=1:
+### :encoding=UTF-8:
+###
+### Copyright (C) 1998, 2003 Slava Pestov
+###
+
+#{{{ Global settings
+
+# Swing look and feel
+#lookAndFeel=javax.swing.plaf.metal.MetalLookAndFeel
+
+view.antiAlias=subpixel
+
+#{{{ Metal control and menu font
+metal.primary.font=Dialog
+metal.primary.fontsize=12
+metal.primary.fontstyle=0
+#}}}
+
+#{{{ Metal system and user text font
+metal.secondary.font=Dialog
+metal.secondary.fontsize=12
+metal.secondary.fontstyle=0
+#}}}
+
+#{{{ HelpViewer and Tip of the Day font
+helpviewer.font=Dialog
+helpviewer.fontsize=12
+helpviewer.fontstyle=0
+#}}}
+
+
+# Decorate frames and dialogs using Swing L&F?
+decorate.frames=false
+decorate.dialogs=false
+
+# Draw multi-key shortcuts on screen menu bar? (OS X only)
+menu.multiShortcut=false
+
+# If true, welcome screen will be displayed.
+# Set to false after initial startup.
+firstTime=false
+
+# Autosave interval in seconds, 0=off
+autosave=30
+
+# Autosave untitled buffers
+autosaveUntitled=true
+
+
+# Maximum number of elements in a history list
+history=20
+
+# Maximum size (in characters) of a history list
+historyMaxSize=5000000
+
+# Number of recent files
+recentFiles=40
+
+# Restore open files on startup?
+restore=true
+
+# Restore even if file names specified on command line?
+restore.cli=true
+
+# Persistent markers
+persistentMarkers=true
+
+# Two-stage save (save to #filename#save# first, then filename)
+twoStageSave=true
+
+# Strip trailing EOL
+stripTrailingEOL=false
+
+# Complete words from all open buffers
+completeFromAllBuffers=false
+
+# Insert a word completion when the corresponding digit is pressed
+insertCompletionWithDigit=true
+
+# Need this so that new files have a trailing EOL
+buffer.trailingEOL=true
+
+# Caret saving toggle
+saveCaret=true
+
+# Backup on every save
+backupEverySave=false
+
+# Number of backups to make, 0=no backups
+backups=1
+
+# Backup directory
+backup.directory=
+
+# Backup filename prefix and suffix
+backup.prefix=
+backup.suffix=~
+
+# Min time between backups (0 = always backup)
+backup.minTime=0
+
+# Sort buffer list
+sortBuffers=true
+sortByName=true
+
+# Sort recent file list
+sortRecent=false
+
+# Default firewall properties
+firewall.enabled=false
+
+# Keep dialog on by default
+search.keepDialog.toggle=true
+
+# When this limit is reached a dialog appears to cancel the search
+hypersearch.maxWarningResults=1000
+
+# If the hypersearch query is longer than this value it will be truncated
+# on display in the results
+hypersearch.displayQueryLength=100
+
+# Style for highlighting matches in hypersearch results
+hypersearch.results.highlight=bgColor:#ccccff
+
+# Confirm dialogs
+confirmSaveAll=true
+
+# Check modification status on focus?
+
+# if false and autoReload == false, do nothing
+# if false and autoReload == true, autoreload silently
+# If true and autoReload == false, prompt for reloading
+# if true and autoReload == true, reload and notify user (message box)
+autoReloadDialog=true
+
+# If this is true, auto reload; if false, use 'reload' button in dialog
+autoReload=true
+
+# When to check file status on disk: 1=view focus. See GeneralOptionPane class for meanings of values.
+checkFileStatus=1
+
+# Encoding detectors
+encodingDetectors=BOM XML-PI html python buffer-local-property
+
+#}}}
+
+# The critical size, over it a buffer will prompt when setting the edit mode
+largeBufferSize=4000000
+longLineLimit=4000
+largefilemode=ask
+
+#{{{ Buffer settings
+# These can also be specified as buffer-local properties
+
+# Line separator
+# The OS default will be used if this is not specified
+# Unix=\n
+# Windows=\r\n
+# MacOS=\r
+#buffer.lineSeparator=\n
+
+# Encoding
+# The OS default will be used if this is not specified
+#buffer.encoding=ISO8859_1
+
+# Auto-detect GZIP, UTF16, UTF8 and XML encodings?
+buffer.encodingAutodetect=true
+
+# Tab width
+buffer.tabSize=4
+
+# Indent width
+buffer.indentSize=4
+
+# Automatic indentation
+buffer.autoIndent=full
+
+# Soft tabs
+buffer.noTabs=false
+
+#elastic tabstops
+buffer.elasticTabstops=false
+
+# Default edit mode
+buffer.defaultMode=text
+
+# Undo queue size
+buffer.undoCount=100
+
+# Wrap mode (none, soft, hard)
+buffer.wrap=none
+
+# Wrap column
+buffer.maxLineLen=80
+
+# Word break characters
+buffer.wordBreakChars=
+
+# Non-alphanumeric word characters
+buffer.noWordSep=_
+
+# Whether to separate "CamelCased" words
+buffer.camelCasedWords=false
+
+# Fold mode (explicit, indent, or none)
+buffer.folding=none
+
+# Folds with a level equal to or higher than this will be collapsed
+buffer.collapseFolds=0
+#}}}
+
+#{{{ View settings
+
+# Apply jEdit colors to Swing text fields, text areas, lists, tables and trees?
+textColors=false
+
+# Show toolbar?
+view.showToolbar=true
+
+# Show searchbar?
+view.showSearchbar=false
+
+# Show buffer switcher?
+view.showBufferSwitcher=true
+
+# Show full path in title bar?
+view.showFullPath=false
+
+# Font
+view.font=Lucida Sans Typewriter
+view.fontsize=11
+# 0=plain, 1=bold, 2=italic, 3=boldItalic
+view.fontstyle=0
+
+# Font Substitution
+view.enableFontSubst=false
+view.enableFontSubstSystemFonts=true
+
+# Background and foreground colors (for the text area)
+view.bgColor=#ffffff
+view.fgColor=#000000
+
+# Line highlighting
+view.lineHighlight=true
+view.lineHighlightColor=#ffffe0
+
+# Bracket highlighting
+view.structureHighlight=true
+view.structureHighlightColor=#000000
+
+# EOL markers
+view.eolMarkers=false
+view.eolMarkerColor=#ff6633
+view.eolMarkerChar=↩
+
+# Wrap guide
+view.wrapGuide=true
+view.wrapGuideColor=#8080ff
+
+# page breaks
+view.pageBreaks=false
+view.pageBreaksColor=#8080ff
+
+# Caret color
+view.caretColor=#ff0000
+
+# Selection color
+view.selectionColor=#ccccff
+view.multipleSelectionColor=#ccffcc
+
+# Caret blinking
+view.caretBlink=true
+
+# Block caret
+view.blockCaret=false
+
+# Electric borders
+view.electricBorders=3
+
+# Drag and drop of text
+view.dragAndDrop=true
+
+# Abbreviate paths using environment variables
+view.abbreviatePaths=true
+
+# Treat consecutive non-alphanumeric characters as one word
+view.joinNonWordChars=true
+
+# Middle mouse button pastes % register
+view.middleMousePaste=false
+
+# Pressing Ctrl while mouse actions makes them
+# as if selection mode were rectangular mode
+view.ctrlForRectangularSelection=true
+
+# Minimal view size that is considered "valid" when loading perspective
+view.minStartupWidth=200
+view.minStartupHeight=200
+
+Combined\ Options.width=1024
+Combined\ Options.height=768
+
+# The default bufferSet scope
+editpane.bufferset.default=view
+
+#{{{ Gutter
+
+# Gutter background color
+view.gutter.bgColor=#dbdbdb
+
+# Gutter foreground color
+view.gutter.fgColor=#000000
+
+# Gutter highlight color
+view.gutter.highlightColor=#990066
+
+# Gutter current line highlight color
+view.gutter.currentLineColor=#ff0033
+
+# Gutter bracket highlighting
+view.gutter.structureHighlight=true
+view.gutter.structureHighlightColor=#666699
+
+# Gutter marker highlight color
+view.gutter.markerColor=#ccffcc
+
+# Gutter border colors
+view.gutter.focusBorderColor=#990099
+view.gutter.noFocusBorderColor=#ffffff
+
+# Fold triangle color
+view.gutter.foldColor=#838383
+
+# Gutter font name
+view.gutter.font=Monospaced
+
+# Gutter font style
+# 0=plain, 1=bold, 2=italic, 3=boldItalic
+view.gutter.fontstyle=0
+
+# Gutter font size
+view.gutter.fontsize=10
+
+# Gutter border width
+view.gutter.borderWidth=2
+
+# Gutter displays line numbers
+view.gutter.lineNumbers=true
+
+# Line numbers are drawn with this alignment (left, center, right)
+view.gutter.numberAlignment=right
+
+# Gutter line numbers are highlighted at this interval
+view.gutter.highlightInterval=5
+
+# Gutter current line highlighting
+view.gutter.highlightCurrentLine=true
+
+# Marker highlight
+view.gutter.markerHighlight=true
+
+# Click behavior
+view.gutter.foldClick=toggle-fold
+view.gutter.SfoldClick=toggle-fold-fully
+view.gutter.CfoldClick=select-fold
+view.gutter.AfoldClick=narrow-fold
+view.gutter.structClick=match-struct
+view.gutter.CstructClick=select-struct
+view.gutter.AstructClick=narrow-struct
+
+# Show gutter?
+view.gutter.enabled=true
+
+# Gutter minimal number of digits to reserve for line numbers
+view.gutter.minDigitCount=2
+
+# Show selection area in gutter?
+view.gutter.selectionAreaEnabled=true
+# Gutter selection area background color
+# - leave out so by default it is the same as the gutter's
+#view.gutter.selectionAreaBgColor=#dbdbdb
+#}}}
+
+# Expand abbrevs when space bar pressed
+view.expandOnInput=false
+
+#{{{ Syntax styles
+view.style.comment1=color:#cc0000
+view.style.comment2=color:#ff8400
+view.style.comment3=color:#6600cc
+view.style.comment4=color:#cc6600
+view.style.digit=color:#ff0000
+view.style.foldLine.0=color:#000000 bgColor:#dafeda style:b
+view.style.foldLine.1=color:#000000 bgColor:#fff0cc style:b
+view.style.foldLine.2=color:#000000 bgColor:#e7e7ff style:b
+view.style.foldLine.3=color:#000000 bgColor:#ffe0f0 style:b
+view.style.function=color:#9966ff
+view.style.invalid=color:#ff0066 bgColor:#ffffcc
+view.style.keyword1=color:#006699 style:b
+view.style.keyword2=color:#009966 style:b
+view.style.keyword3=color:#0099ff style:b
+view.style.keyword4=color:#66ccff style:b
+view.style.label=color:#02b902
+view.style.literal1=color:#ff00cc
+view.style.literal2=color:#cc00cc
+view.style.literal3=color:#9900cc
+view.style.literal4=color:#6600cc
+view.style.markup=color:#0000ff
+view.style.operator=color:#000000 style:b
+#}}}
+
+# Docking and tool bar positioning
+view.docking.alternateLayout=false
+# "alternate" is actually closer to standard location IMHO...
+view.toolbar.alternateLayout=true
+
+#{{{ Status bar
+view.status.foreground=black
+view.status.background=white
+view.status.visible=true
+view.status.plainview.visible=false
+view.status.show-caret-status=true
+view.status.memory.foreground=#cccccc
+view.status.memory.background=#666699a
+#}}}
+
+#}}}
+
+#{{{ Printing settings
+
+# Font
+print.font=Monospaced
+print.fontsize=9
+print.fontstyle=0
+
+# Print header?
+print.header=true
+
+# Print footer?
+print.footer=true
+
+# Print line numbers?
+print.lineNumbers=true
+
+# Print in color, or black and white?
+print.color=false
+
+# Print tab size
+print.tabSize=2
+
+# Use old (JDK 1.3) printing API
+print.force13=false
+
+# Force use of glyph vectors to work around spacing problems
+print.glyphVector=false
+#}}}
+
+#{{{ File system browser settings
+vfs.browser.showMenubar=true
+vfs.browser.showToolbar=true
+vfs.browser.showIcons=true
+vfs.browser.showHiddenFiles=false
+vfs.browser.sortMixFilesAndDirs=false
+vfs.browser.sortIgnoreCase=true
+vfs.browser.doubleClickClose=false
+vfs.browser.currentBufferFilter=false
+vfs.browser.useDefaultIcons=true
+
+# Can be one of: buffer, home, favorites, last
+vfs.browser.defaultPath=buffer
+
+# File list coloring
+vfs.browser.colorize=true
+
+vfs.browser.colors.0.glob={CVS,#*,*~,\\.*}
+vfs.browser.colors.0.color=#a0a0a0
+vfs.browser.colors.1.glob=*.class
+vfs.browser.colors.1.color=#660066
+vfs.browser.colors.2.glob={build.xml,makefile*,*.hxml}
+vfs.browser.colors.2.color=#666600
+vfs.browser.colors.3.glob=*.{gif,jpg,jpeg,png,bmp,xpm,svg}
+vfs.browser.colors.3.color=#009933
+vfs.browser.colors.4.glob=*.{gz,jar,zip,tgz,z,war,ear}
+vfs.browser.colors.4.color=#990000
+vfs.browser.colors.5.glob=tags
+vfs.browser.colors.5.color=#003366
+vfs.browser.colors.6.glob=*.{sh,bat,cmd}
+vfs.browser.colors.6.color=#006666
+vfs.browser.colors.7.glob={CHANGELOG,CHANGES,INSTALL,LICENSE,NEWS,README,TODO}{,.txt}
+vfs.browser.colors.7.color=#330066
+vfs.browser.colors.8.glob=*.{props,properties}
+vfs.browser.colors.8.color=#666666
+
+browser.custom.context=
+#}}}
+
+#{{{ Plugin manager settings
+plugin-manager.downloadSource=false
+plugin-manager.installUser=true
+plugin-manager.showAll=true
+plugin-manager.mirror.id=NONE
+plugin-manager.deleteDownloads=true
+plugin-manager.hide-libraries.toggle=true
+plugin-blacklist.MacOS.jar=true
+#}}}
+
+#{{{ Hidden settings
+menu.spillover=20
+bufferSwitcher.maxRowCount=10
+showTooltips=true
+ioThreadCount=4
+server.brokenToFront=false
+search.dontSyncFilter=false
+#}}}
+
+#{{{ Miscellaneous settings
+# restore remote VFS files by default (new option)
+options.general.restore.remote=true
+optional.title-template={0} {1}
+mime2mode.text/html=html
+debug.beepOnOutput=false
+#}}}
+
+#{{{ Keymaps
+# the current keymap name
+keymap.current=jEdit
+# the default keymap name
+keymap.default=jEdit
+#}}}
+
+# if lang.usedefaultlocale is true, the lang.current is not used
+lang.usedefaultlocale=true
diff --git a/scala/src/main/resources/quanto/gui/normalise.png b/scala/src/main/resources/quanto/gui/normalise.png
new file mode 100644
index 00000000..6e6c6547
Binary files /dev/null and b/scala/src/main/resources/quanto/gui/normalise.png differ
diff --git a/scala/src/main/resources/quanto/gui/python.xml b/scala/src/main/resources/quanto/gui/python.xml
index 1e97faa3..7d36a94b 100644
--- a/scala/src/main/resources/quanto/gui/python.xml
+++ b/scala/src/main/resources/quanto/gui/python.xml
@@ -356,6 +356,26 @@
__truediv__
__version__
__xor__
+
+
+ new_graph_from_json
+ vertex_angle_is
+ EMPTY
+ JSON_REWRITE
+ REWRITE
+ REWRITE_METRIC
+ REWRITE_METRIC_TO
+ REWRITE_WEAK_METRIC
+ REWRITE_TARGETED
+ ANNEAL
+ LOG
+ REWRITE_TARGET_LIST
+ REPEAT
+ REDUCE
+ REDUCE_METRIC
+ REDUCE_METRIC_TO
+ REDUCE_WEAK_METRIC
+ REDUCE_TARGETED
diff --git a/scala/src/main/resources/quanto/gui/quantoderive.ico b/scala/src/main/resources/quanto/gui/quantoderive.ico
new file mode 100644
index 00000000..4dd67f50
Binary files /dev/null and b/scala/src/main/resources/quanto/gui/quantoderive.ico differ
diff --git a/scala/src/main/scala/quanto/cosy/BlockEnumeration.scala b/scala/src/main/scala/quanto/cosy/BlockEnumeration.scala
index 3517212a..26d1edcc 100644
--- a/scala/src/main/scala/quanto/cosy/BlockEnumeration.scala
+++ b/scala/src/main/scala/quanto/cosy/BlockEnumeration.scala
@@ -1,7 +1,9 @@
package quanto.cosy
+import quanto.cosy.BlockGenerators.QuickGraph
+import quanto.data.Names._
+import quanto.data.{VName, _}
import quanto.util.json._
-import quanto.data._
/**
* Enumerates diagrams by composing simple building blocks in a 2D fashion
@@ -10,12 +12,13 @@ class BlockEnumeration {
}
-case class Block(inputs: List[Int], outputs: List[Int], name: String, tensor: Tensor) {
+case class Block(inputs: List[Int], outputs: List[Int], name: String, tensor: Tensor, graph: Graph = new Graph()) {
lazy val toJson = JsonObject(
"inputs" -> inputs,
"outputs" -> outputs,
"name" -> name,
- "tensor" -> tensor.toJson
+ "tensor" -> tensor.toJson,
+ "graph" -> graph.toJson()
)
override def toString: String = this.name
@@ -26,11 +29,17 @@ object Block {
new Block((js / "inputs").asArray.map(x => x.intValue).toList,
(js / "outputs").asArray.map(x => x.intValue).toList,
(js / "name").stringValue,
- Tensor.fromJson((js / "tensor").asObject))
+ Tensor.fromJson((js / "tensor").asObject),
+ Graph.fromJson((js / "graph").asObject))
}
+
+ implicit def toTensor(b: Block) : Tensor = b.tensor
}
-class BlockRow(val blocks: List[Block], suggestTensor: Option[Tensor] = None) {
+
+case class BlockRow(blocks: List[Block], suggestTensor: Option[Tensor] = None, suggestGraph: Option[Graph] = None) {
+
+
implicit def optionTensor(t: Tensor): Option[Tensor] = {
Option(t)
}
@@ -42,6 +51,13 @@ class BlockRow(val blocks: List[Block], suggestTensor: Option[Tensor] = None) {
case None => blocks.foldLeft(Tensor.id(1))((a, b) => a x b.tensor)
}
+ lazy val graph: Graph = suggestGraph match {
+ case Some(g) => g
+ case None => blocks.foldLeft(new Graph()) { (g, b) =>
+ BlockRow.graphsSideBySide(g, b.graph)
+ }
+ }
+
lazy val toJson = JsonObject(
"blocks" -> JsonArray(blocks.map(b => b.toJson)),
"inputs" -> inputs,
@@ -56,13 +72,46 @@ object BlockRow {
def fromJson(js: JsonObject): BlockRow = {
new BlockRow((js / "blocks").asArray.map(j => Block.fromJson(j.asObject)).toList)
}
+
+
+ def graphsSideBySide(fixed: Graph, added: Graph): Graph = {
+
+
+ val startingInputs = fixed.verts.filter(vn => vn.prefix == "i-")
+ val startingOutputs = fixed.verts.filter(vn => vn.prefix == "o-")
+ val shift = math.max(startingInputs.size, startingOutputs.size)
+
+ val aShifted: Graph = added.verts.foldLeft(added)((g, vn) => g.updateVData(vn)(vd => {
+ val currentCoord = added.vdata(vn).coord
+ added.vdata(vn).withCoord(currentCoord._1 + shift, currentCoord._2)
+ }
+ ))
+
+ val renameMap = aShifted.verts.flatMap(vn => (vn.prefix, vn.suffix) match {
+ case ("i-", n) => Some(vn -> VName("i-" + (n + startingInputs.size)))
+ case ("o-", n) => Some(vn -> VName("o-" + (n + startingOutputs.size)))
+ case (a, b) => Some(vn -> VName("bl-" + shift + "-" + a + b))
+ case _ => None
+ }).toMap
+
+ val aShiftedRename = aShifted.rename(vrn = renameMap, ern = Map(), brn = Map())
+ fixed.appendGraph(aShiftedRename.renameAvoiding(fixed), false)
+ }
}
-class BlockStack(val rows: List[BlockRow]) extends Ordered[BlockStack] {
+case class BlockStack(rows: List[BlockRow],
+ suggestedTensor: Option[Tensor] = None,
+ suggestedGraph: Option[Graph] = None) extends Ordered[BlockStack] {
+
+ require(rows.flatMap(_.blocks).nonEmpty) // require at least one block
+
lazy val tensor: Tensor = if (rows.isEmpty) {
Tensor.id(1)
} else {
- rows.foldRight(Tensor.id(rows.last.tensor.width))((a, b) => a.tensor o b)
+ suggestedTensor match {
+ case Some(t) => t
+ case None => rows.foldRight(Tensor.id(rows.last.tensor.width))((a, b) => a.tensor o b)
+ }
}
lazy val toJson = JsonObject(
"rows" -> JsonArray(rows.map(b => b.toJson)),
@@ -79,197 +128,90 @@ class BlockStack(val rows: List[BlockRow]) extends Ordered[BlockStack] {
this.rows.length - that.asInstanceOf[BlockStack].rows.length
}
-}
-
-object BlockStack {
- def fromJson(js: JsonObject): BlockStack = {
- new BlockStack((js / "rows").asArray.map(j => BlockRow.fromJson(j.asObject)).toList)
- }
-}
-
-object BlockRowMaker {
- implicit def quickList(n: Int): List[Int] = {
- n match {
- case 0 => List()
- case 1 => List(0)
- case m => quickList(m - 1) ::: List(0)
+ //left-most row is on top!
+ lazy val graph: Graph = {
+ suggestedGraph match {
+ case Some(g) => g
+ case None =>
+ val blocks = rows.flatMap(row => row.blocks)
+ if (blocks.isEmpty) {
+ new Graph()
+ } else {
+ val gdata = blocks.head.graph.data
+ BlockStack.joinRowsInStack(
+ rows.reverse.zipWithIndex.foldRight(new Graph(data = gdata))((ri, g) =>
+ BlockStack.graphStackUnjoined(g, ri._1.graph, ri._2)))
+
+ }
}
- }
-
- val BellTeleportation: List[Block] = List(
- Block(List(0), List(0), " A ", Tensor.id(2)),
- Block(List(-1), List(-1), " B ", Tensor.id(2)),
- Block(List(0, 0), List(1), " m1", Tensor(Array(Array(1, 0, 0, 1)))),
- Block(List(0, 0), List(2), " m2", Tensor(Array(Array(1, 0, 0, -1)))),
- Block(List(0, 0), List(3), " m3", Tensor(Array(Array(0, 1, 1, 0)))),
- Block(List(0, 0), List(4), " m4", Tensor(Array(Array(0, 1, -1, 0)))),
- Block(List(-1, 1), List(-1), " c1", Tensor(Array(Array(1, 0), Array(0, 1)))),
- Block(List(-1, 2), List(-1), " c2", Tensor(Array(Array(1, 0), Array(0, -1)))),
- Block(List(-1, 3), List(-1), " c3", Tensor(Array(Array(0, 1), Array(1, 0)))),
- Block(List(-1, 4), List(-1), " c4", Tensor(Array(Array(0, 1), Array(-1, 0)))),
- Block(List(), List(0, -1), " p ", Tensor(Array(Array(1, 0, 0, 1))).transpose)
- ) :::
- swapQuantumClassical(List(0, -1), Tensor.id(2), List(1, 2, 3, 4)) :::
- makeClassicalIdentites(List(1, 2, 3, 4))
- val ZW: List[Block] = List(
- // BOTTOM TO TOP!
- Block(1, 1, " 1 ", Tensor.idWires(1)),
- Block(2, 2, " s ", Tensor.swap(List(1, 0))),
- Block(2, 2, "crs", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 1, 0), Array(0, 1, 0, 0), Array(0, 0, 0, -1)))),
- Block(0, 2, "cup", Tensor(Array(Array(1, 0, 0, 1))).transpose),
- Block(2, 0, "cap", Tensor(Array(Array(1, 0, 0, 1)))),
- Block(1, 1, " w ", Tensor(Array(Array(1, 0), Array(0, -1)))),
- Block(1, 2, "1w2", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 0, -1))).transpose),
- Block(2, 1, "2w1", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 0, -1)))),
- Block(1, 1, " b ", Tensor(Array(Array(0, 1), Array(1, 0)))),
- Block(1, 2, "1b2", Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 1))).transpose),
- Block(2, 1, "2b1", Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 1))))
- )
- def swapQuantumClassical(listQuantum: List[Int], quantumTensor: Tensor, listClassical: List[Int]): List[Block] = {
- (for (w1 <- listQuantum; w2 <- listClassical) yield {
- List(Block(List(w1, w2), List(w2, w1), w1 + "s" + w2, quantumTensor),
- Block(List(w2, w1), List(w1, w2), w2 + "s" + w1, quantumTensor))
- }).flatten
}
- def swapQuantumQuantum(listQuantum: List[Int], quantumTensor: Tensor): List[Block] = {
- (for (w1 <- listQuantum; w2 <- listQuantum) yield {
- List(Block(List(w1, w2), List(w2, w1), w1 + "s" + w2, quantumTensor),
- Block(List(w2, w1), List(w1, w2), w2 + "s" + w1, quantumTensor))
- }).flatten
+ // newly added row is on top!
+ // And our graphs go bottom to top
+ // Forces computation of tensor and graph (as that's what this is designed for)
+ def append(row: BlockRow): BlockStack = {
+ //TODO: Make this parallel?
+ val newTensor = if (rows.nonEmpty) {
+ require(rows.head.outputs == row.inputs)
+ row.tensor o tensor
+ } else {
+ row.tensor
+ }
+ val newRows = row :: rows
+ val newGraph = BlockStack.joinRowsInStack(BlockStack.graphStackUnjoined(graph, row.graph, rows.length))
+ new BlockStack(newRows, Some(newTensor), Some(newGraph))
}
- def makeClassicalIdentites(listClassical: List[Int]): List[Block] = {
- for (w <- listClassical) yield {
- Block(List(w), List(w), "w" + w + " ", Tensor.id(1))
- }
+}
+
+object BlockStack {
+ def fromJson(js: JsonObject): BlockStack = {
+ new BlockStack((js / "rows").asArray.map(j => BlockRow.fromJson(j.asObject)).toList)
}
- // Traditionally the number of angles is 3 (Clifford) or 9 (Clifford+T)
- def ZXQudit(dimension: Int, numAngles: Int): List[Block] = {
- require(dimension > 1)
- def swapIndex(i: Int): Int = {
- val left: Int = i / dimension
- val right = i % dimension
- right * dimension + left
+ def joinRowsInStack(graph: Graph): Graph = {
+ var g = QuickGraph(graph)
+ val InputPattern = raw"r-(\d+)-i-(\d+)".r
+ g.verts.foreach(vName => vName.s match {
+ case InputPattern(n, m) =>
+ // For 0.toString is coming out as "" not "0", but this shouldn't affect us
+ if (g.verts.contains(s"r-${Integer.parseInt(n) - 1}-o-$m")) {
+ g = g.joinIfNotAlready(s"r-${Integer.parseInt(n) - 1}-o-$m", s"r-$n-i-$m", Some("rail"))
+ }
+ case _ => g
}
-
- val H: Tensor = Hadamard(dimension)
-
- val greenFork = Tensor(dimension, dimension * dimension,
- (i, j) => if (j == i * (dimension + 1)) Complex.one else Complex.zero)
-
- // Go through the diagonal entries creating all the different spiders
- val greenBlocks = (1 until dimension).foldLeft(
- List(Block(1, 1, "g", Tensor(dimension, dimension, (i, j) => if (i == 0 && j == 0) Complex.one else Complex.zero)))
- )((lb, i) => lb.flatMap(b => (0 until numAngles).map(x =>
- Block(1, 1, b.name + "|" + x, b.tensor + Tensor(dimension, dimension, (j, k) =>
- if (j == i && k == i) ei(x * 2 * math.Pi / numAngles) else Complex.zero)
- ))))
-
- List(
- Block(1, 1, " 1 ", Tensor.id(dimension)),
- Block(2, 2, " s ", Tensor.permutation((0 until dimension * dimension).toList.map(x => swapIndex(x)))),
- Block(1, 1, " H ", H),
- Block(1, 1, " H'", H.dagger),
- Block(2, 1, "2g1", greenFork),
- Block(1, 2, "1g2", greenFork.dagger),
- Block(0, 1, "gu ", Tensor(dimension, 1, (i, j) => Complex.one).scaled(1.0 / math.sqrt(dimension))),
- Block(1, 2, "1r2", (H.dagger o greenFork o (H x H)).dagger),
- Block(2, 1, "2r1", H.dagger o greenFork o (H x H)),
- Block(0, 1, "ru ", Tensor(dimension, 1, (i, j) => if (i == 0 && j == 0) Complex.one else Complex.zero))
- ) ::: greenBlocks ::: greenBlocks.map(b =>
- Block(1, 1, "r" + b.name.tail, H.dagger o b.tensor o H)
)
+ g
}
- // Traditionally the number of angles is 3 (Clifford) or 9 (Clifford+T)
- def ZXQutrit(numAngles: Int = 9): List[Block] = {
- val H3 = Hadamard(3)
- List(
- Block(1, 1, " 1 ", Tensor.id(3)),
- Block(2, 2, " s ", Tensor.permutation(List(0, 3, 6, 1, 4, 7, 2, 5, 8))),
- Block(1, 1, " H ", H3),
- Block(1, 1, " H'", H3.dagger),
- Block(2, 1, "2g1", Tensor(Array(
- Array(1, 0, 0, 0, 0, 0, 0, 0, 0),
- Array(0, 0, 0, 0, 1, 0, 0, 0, 0),
- Array(0, 0, 0, 0, 0, 0, 0, 0, 1)
- ))),
- Block(1, 2, "1g2", Tensor(Array(
- Array(1, 0, 0, 0, 0, 0, 0, 0, 0),
- Array(0, 0, 0, 0, 1, 0, 0, 0, 0),
- Array(0, 0, 0, 0, 0, 0, 0, 0, 1)
- )).transpose),
- Block(0, 1, "gu ", Tensor(Array(Array(1, 1, 1))).scaled(1.0 / math.sqrt(3)).transpose),
- Block(1, 2, "1r2", (H3.dagger o Tensor(Array(
- Array(1, 0, 0, 0, 0, 0, 0, 0, 0),
- Array(0, 0, 0, 0, 1, 0, 0, 0, 0),
- Array(0, 0, 0, 0, 0, 0, 0, 0, 1)
- )) o (H3 x H3)).dagger),
- Block(2, 1, "2r1", H3.dagger o Tensor(Array(
- Array(1, 0, 0, 0, 0, 0, 0, 0, 0),
- Array(0, 0, 0, 0, 1, 0, 0, 0, 0),
- Array(0, 0, 0, 0, 0, 0, 0, 0, 1)
- )) o (H3 x H3)),
- Block(0, 1, "ru ", Tensor(Array(Array(1, 0, 0))).transpose)
- ) :::
- (for (i <- 0 until numAngles; j <- 0 until numAngles) yield {
- val gs = Tensor(Array(
- Array[Complex](1, 0, 0),
- Array[Complex](0, ei(i * 2 * math.Pi / numAngles), 0),
- Array[Complex](0, 0, ei(j * 2 * math.Pi / numAngles))
- ))
- List(Block(1, 1, "g|" + i.toString + "|" + j.toString, gs),
- Block(1, 1, "r|" + i.toString + "|" + j.toString, H3.dagger o gs o H3))
- }).flatten.toList
+ def graphStackUnjoined(fixed: Graph, adding: Graph, depth: Int): Graph = {
+ val renameMap = adding.verts.map(vn => vn -> VName(s"r-$depth-${vn.s}")).toMap
+ val aRenamed = adding.rename(vrn = renameMap, ern = Map(), brn = Map())
+ val aRenamedShifted = aRenamed.verts.foldLeft(aRenamed)((g, vn) => g.updateVData(vn)(vd => {
+ val currentCoord = aRenamed.vdata(vn).coord
+ aRenamed.vdata(vn).withCoord(currentCoord._1, currentCoord._2 + depth)
+ }
+ ))
+
+ if(fixed.data.theory != adding.data.theory){
+ fixed.copy(data = fixed.data.copy(theory = fixed.data.theory.mixin(adding.data.theory, None)))
+ .appendGraph(aRenamedShifted.renameAvoiding(fixed), noOverlap = false)
+ }else{
+ fixed
+ .appendGraph(aRenamedShifted.renameAvoiding(fixed), noOverlap = false)
+ }
}
+}
- def Hadamard(dimension: Int): Tensor =
- Tensor(dimension, dimension, (i, j) => ei(2 * math.Pi * i * j / dimension)).scaled(1 / math.sqrt(dimension))
-
- private def ei(angle: Double) = Complex(math.cos(angle), math.sin(angle))
-
- def ZX(numAngles: Int = 8): List[Block] = List(
- Block(1, 1, " 1 ", Tensor.idWires(1)),
- Block(2, 2, " s ", Tensor.swap(List(1, 0))),
- Block(0, 2, "cup", Tensor(Array(Array(1, 0, 0, 1))).transpose),
- Block(2, 0, "cap", Tensor(Array(Array(1, 0, 0, 1))))) :::
- (for (i <- 0 until numAngles) yield {
- Block(1, 1, "gT" + i.toString, Tensor(Array(
- Array(Complex.one, Complex.zero),
- Array(Complex.zero, ei(2 * i * math.Pi / numAngles)))))
- }).toList :::
- (for (i <- 0 until numAngles) yield {
- Block(1, 1, "rT" + i.toString, new Tensor(Array(
- Array(1 + ei(2 * i * math.Pi / numAngles), 1 - ei(2 * i * math.Pi / numAngles)),
- Array(1 - ei(2 * i * math.Pi / numAngles), 1 + ei(2 * i * math.Pi / numAngles)))))
- }).toList :::
- List(
- Block(1, 1, " H ", Tensor(Array(Array(1, 1), Array(1, -1))).scaled(1.0 / math.sqrt(2))),
- Block(2, 1, "2g1", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 0, 1)))),
- Block(1, 2, "1g2", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 0, 1))).transpose),
- Block(0, 1, "gu ", Tensor(Array(Array(1, 1))).transpose),
- Block(1, 2, "1r2", Tensor(Array(Array(1, 0, 0, 1), Array(0, 1, 1, 0))).transpose),
- Block(2, 1, "2r1", Tensor(Array(Array(1, 0, 0, 1), Array(0, 1, 1, 0)))),
- Block(0, 1, "ru ", Tensor(Array(Array(1, 0))).transpose)
- )
+object BlockRowMaker {
- def Bian2Qubit: List[Block] = List(
- // Block(0, 0, " w ", Tensor.id(1).scaled(ei(math.Pi / 4))), Ignored for now.
- Block(1, 1, " 1 ", Tensor.id(2)),
- Block(2, 2, " Zc", Tensor.diagonal(Array(Complex.one, Complex.one, Complex.one, Complex.zero.-(Complex.one)))),
- Block(1, 1, " T ", Tensor.diagonal(Array(Complex.one, ei(math.Pi / 4)))),
- Block(1, 1, " H ", Tensor(Array(Array(1, 1), Array(1, -1))).scaled(1.0 / math.sqrt(2))),
- Block(1, 1, " S ", Tensor.diagonal(Array(Complex.one, ei(math.Pi / 2))))
- )
- def stackToGraph(stack: BlockStack, blockToGraph: Block => Graph): Graph = {
+ def predicateStackToGraph(stack: BlockStack, blockToGraph: Block => Graph): Graph = {
var g = (for ((row, index) <- stack.rows.zipWithIndex) yield {
- val rowGraph = rowToGraph(row, blockToGraph)
+ val rowGraph = predicateRowToGraph(row, blockToGraph)
val verts = rowGraph.verts.toList
val vRename = verts.map(v => v -> VName("r" + index + v.s)).toMap
val eRename = rowGraph.edges.map(e => e -> EName("r" + index + e.s)).toMap
@@ -278,13 +220,17 @@ object BlockRowMaker {
}).foldLeft(new Graph())((g, sg) => g.appendGraph(sg))
for ((row, index) <- stack.rows.init.zipWithIndex) {
for (j <- row.outputs.indices) {
- g = g.addEdge(EName("r" + index + "e" + j), UndirEdge(), VName("r" + index + "o" + j) -> VName("r" + (index + 1) + "i" + j))
+ g = g.addEdge(
+ EName("r" + index + "e" + j),
+ UndirEdge(),
+ VName("r" + index + "o" + j) -> VName("r" + (index + 1) + "i" + j)
+ )
}
}
g
}
- def rowToGraph(row: BlockRow, blockToGraph: Block => Graph): Graph = {
+ def predicateRowToGraph(row: BlockRow, blockToGraph: Block => Graph): Graph = {
var inputsCovered = 0
var outputsCovered = 0
val inputRegex = raw"i(\d+)".r
@@ -305,81 +251,23 @@ object BlockRowMaker {
}).foldLeft(new Graph())((g, bg) => g.appendGraph(bg))
}
-
- def Bian2QubitToGraph(block: Block): Graph = {
-
- // The graph produced must be 0-indexed on inputs and outputs, and of the form /i\d+/ and /o\d+/
-
- implicit def vname(str: String): VName = VName(str)
-
- implicit def vnamepair(p: (String, String)): (VName, VName) = VName(p._1) -> VName(p._2)
-
- implicit def ename(str: String): EName = EName(str)
-
- val rg = Theory.fromFile("red_green")
-
- var g = new Graph()
-
- var eCount = 0
-
- def join(v0: String, v1: String): Unit = {
- g = g.addEdge("e" + eCount, UndirEdge(), v0 -> v1)
- eCount += 1
- }
-
- def addVertex(name: String, data: VData): Unit = {
- g = g.addVertex(name, data)
- }
-
- for (i <- block.inputs.indices) {
- addVertex("i" + i, WireV())
- }
- for (i <- block.outputs.indices) {
- addVertex("o" + i, WireV())
- }
- block.name match {
- case " 1 " =>
- join("i0", "o0")
- case " T " =>
- addVertex("v0", NodeV(data = JsonObject("type" -> "X", "value" -> "pi/4"), theory = rg))
- join("i0", "v0")
- join("v0", "o0")
- case " H " =>
- addVertex("v0", NodeV(data = JsonObject("type" -> "hadamard", "value" -> "0"), theory = rg))
- join("i0", "v0")
- join("v0", "o0")
- case " S " =>
- addVertex("v0", NodeV(data = JsonObject("type" -> "X", "value" -> "pi/2"), theory = rg))
- join("i0", "v0")
- join("v0", "o0")
- case " Zc" =>
- addVertex("v0", NodeV(data = JsonObject("type" -> "X", "value" -> "0"), theory = rg))
- addVertex("v1", NodeV(data = JsonObject("type" -> "Z", "value" -> "0"), theory = rg))
- join("v0", "v1")
- join("i0", "v0")
- join("v0", "o0")
- join("i1", "v1")
- join("v1", "o1")
- }
-
- g
- }
-
def apply(maxBlocks: Int, allowedBlocks: List[Block], maxInOut: Option[Int] = None): List[BlockRow] = {
require(maxBlocks >= 0)
- (for (i <- 0 to maxBlocks) yield {
- makeRowsOfSize(i, allowedBlocks, maxInOut)
- }).flatten.toList
+ makeRowsUpToSize(maxBlocks, allowedBlocks, maxInOut)
}
- def makeRowsOfSize(size: Int,
- allowedBlocks: List[Block],
- maxInOut: Option[Int] = None): List[BlockRow] = {
+ def makeRowsUpToSize(size: Int,
+ allowedBlocks: List[Block],
+ maxInOut: Option[Int] = None): List[BlockRow] = {
val maybeTooLargeRows: List[BlockRow] = size match {
case 0 => List[BlockRow]()
case 1 => allowedBlocks.map(b => new BlockRow(List(b)))
- case n => for (base <- makeRowsOfSize(n - 1, allowedBlocks, maxInOut); block <- allowedBlocks) yield {
- new BlockRow(block :: base.blocks)
+ case n => {
+ val fewerBlocks = makeRowsUpToSize(n - 1, allowedBlocks, maxInOut)
+
+ (for (base <- fewerBlocks; block <- allowedBlocks) yield {
+ new BlockRow(block :: base.blocks)
+ }) ::: fewerBlocks
}
}
maxInOut match {
@@ -387,6 +275,8 @@ object BlockRowMaker {
case Some(max) => maybeTooLargeRows.filter(r => (r.inputs.length <= max) && (r.outputs.length <= max))
}
}
+
+
}
object BlockStackMaker {
diff --git a/scala/src/main/scala/quanto/cosy/BlockGenerators.scala b/scala/src/main/scala/quanto/cosy/BlockGenerators.scala
new file mode 100644
index 00000000..2743d8d3
--- /dev/null
+++ b/scala/src/main/scala/quanto/cosy/BlockGenerators.scala
@@ -0,0 +1,444 @@
+package quanto.cosy
+
+import quanto.data._
+import quanto.util.json.JsonObject
+import quanto.data.Names._
+
+import scala.util.matching.Regex
+
+object BlockGenerators {
+
+
+ implicit def quickList(n: Int): List[Int] = {
+ n match {
+ case 0 => List()
+ case 1 => List(0)
+ case m => quickList(m - 1) ::: List(0)
+ }
+ }
+
+
+ class QuickGraph(graph: Graph) {
+ val _g : Graph = graph
+ def node(nodeType: String, angle: String = "", xCoord : Double = 0, nodeName : String = "v-0") : QuickGraph = {
+ val name = _g.verts.freshWithSuggestion(VName(nodeName))
+ val data = NodeV(data = JsonObject("type" -> nodeType, "value" -> angle), theory = _g.data.theory).withCoord((xCoord, 0))
+ QuickGraph(_g.addVertex(name, data))
+ }
+
+ def bbox(name: String, vertices: Set[String]): QuickGraph = {
+ val bbname = _g.bboxes.freshWithSuggestion(BBName(name))
+ val bbdata = BBData(theory = _g.data.theory)
+ QuickGraph(_g.addBBox(bbname, bbdata, vertices.map(s => VName(s))))
+ }
+
+ def addInput(count : Int = 1) : QuickGraph = {
+ count match {
+ case 0 => this
+ case 1 =>
+ val name = _g.verts.freshWithSuggestion(VName("i-0"))
+ val data = WireV().withCoord(name.suffix,-0.5)
+ QuickGraph(_g.addVertex(name, data))
+ case n =>
+ val name = _g.verts.freshWithSuggestion(VName("i-0"))
+ val data = WireV().withCoord(name.suffix,-0.5)
+ QuickGraph(_g.addVertex(name, data)).addInput(count -1)
+ }
+ }
+ def addOutput(count: Int = 1) : QuickGraph = {
+ count match {
+ case 0 => this
+ case 1 =>
+ val name = _g.verts.freshWithSuggestion(VName("o-0"))
+ val data = WireV().withCoord(name.suffix,0.5)
+ QuickGraph(_g.addVertex(name, data))
+ case n =>
+ val name = _g.verts.freshWithSuggestion(VName("o-0"))
+ val data = WireV().withCoord(name.suffix,0.5)
+ QuickGraph(_g.addVertex(name, data)).addOutput(count -1)
+ }
+ }
+
+ def join(s1 : String, s2: String, edgeType : Option[String] = None) : QuickGraph = {
+ val name = _g.edges.freshWithSuggestion("e-0")
+ val eData = if(edgeType.isEmpty) {
+ _g.data.theory.edgeTypes(_g.data.theory.defaultEdgeType).defaultData
+ } else {
+ _g.data.theory.edgeTypes(edgeType.get).defaultData
+ }
+ val data = UndirEdge(eData, theory = _g.data.theory)
+ val v1 = VName(s1)
+ val v2 = VName(s2)
+ QuickGraph(_g.addEdge(name, data, v1 -> v2))
+ }
+
+ def join(s1: String, s2s: Set[String], edgeType: Option[String]) : QuickGraph = {
+ s2s.foldLeft(this)((g,v) => g.join(s1, v, edgeType))
+ }
+
+ def joinIfNotAlready(s1: String, s2: String, edgeType : Option[String] = None) : QuickGraph = {
+ val isJoined = _g.adjacentVerts(s1).contains(s2)
+ if(!isJoined){
+ this.join(s1, s2, edgeType)
+ }else{
+ this
+ }
+ }
+
+ def apply() : Graph = _g
+ }
+
+ object QuickGraph {
+ def apply(graph: Graph) = new QuickGraph(graph)
+ def apply(theory: Theory) = new QuickGraph(
+ new Graph(GData(data = new JsonObject(), annotation = new JsonObject(), theory = theory))
+ )
+
+ implicit def slow(qg: QuickGraph) : Graph = qg()
+
+ val boundaryRegex : Option[Regex] = Some(raw"""(i|o)-(\d+)""".r)
+ }
+
+ val ZXTheory : Theory = Theory.fromFile("ZX")
+ val ZXRails : Theory = Theory.fromFile("ZXRails")
+
+ def zxCNOT(size: Int = 2): Block = {
+ require(size >= 2)
+ // Size 2 is the standard
+ val tensorSize = Math.pow(2, size).toInt
+ val cut = tensorSize - 2
+ val swap = Tensor.swap((0 until size).map {
+ case 0 => 0
+ case 1 => size - 1
+ case n => n - 1
+ }.toList)
+ val penultimate = size - 1
+ val graph = (0 until size).foldLeft(QuickGraph(ZXRails).addInput(size).addOutput(size)) {
+ (g, i) =>
+ i match {
+ case 0 => g.node("Z", nodeName = "z", xCoord = i).join("i-" + i, "z").join("z", "o-" + i)
+ case `penultimate` =>
+ g.node("X", nodeName = "x", xCoord = i).join("i-" + i, "x").join("x", "o-" + i)
+ case _ => g.join("i-" + i, "o-" + i, Some("rail"))
+ }
+ }.join("z", "x", Some("string"))
+ Block(size, size, "CNOT" + size,
+ swap o
+ (
+ Tensor(Array(Array(1, 0, 0, 0), Array(0, 1, 0, 0), Array(0, 0, 0, 1), Array(0, 0, 1, 0)))
+ x Tensor.idWires(size - 2)
+ )
+ o swap.transpose
+ ,
+ graph
+ )
+ }
+ def zxTONC(size: Int = 2): Block = {
+ require(size >= 2)
+ // Size 2 is the standard
+ val swapList = (0 until size).map {
+ i => {
+ if (i == 0) size - 1
+ else if (i == size - 1) 0
+ else i
+ }
+ }.toList
+ val swap = Tensor.swap(swapList)
+ val penultimate = size - 1
+ Block(size, size, "TONC" + size, swap.transpose o zxCNOT(size) o swap,
+ (0 until size).foldLeft(QuickGraph(ZXRails).addInput(size).addOutput(size)) {
+ (g, i) =>
+ i match {
+ case `penultimate` => g.node("Z", nodeName = "z", xCoord = i).join("i-" + i, "z").join("z", "o-" + i)
+ case 0 => g.node("X", nodeName = "x", xCoord = i).join("i-" + i, "x").join("x", "o-" + i)
+ case _ => g.join("i-" + i, "o-" + i)
+ }
+ }.join("z", "x", Some("string"))
+ )
+ }
+
+
+ def zxCNOTs(maxWidth: Int = 2): List[Block] = (for (i <- 2 to maxWidth) yield zxCNOT(i)).toList
+ def zxTONCs(maxWidth: Int = 2): List[Block] = (for (i <- 2 to maxWidth) yield zxCNOT(i)).toList
+
+ val zxQubitHadamard : Block = Block(1, 1, " H ", Hadamard(2), QuickGraph(ZXRails).addInput().addOutput()
+ .node("hadamard", nodeName = "h").join("i-0", "h").join("h", "o-0"))
+
+ def zxQubitTwists(twoPiDivision: Int = 4): IndexedSeq[Block] = {
+ def tensor(angle: Int, nodeType: String): Tensor = {
+ nodeType match {
+ case "Z" => Tensor(Array(
+ Array(Complex.one, Complex.zero),
+ Array(Complex.zero, ei(2 * angle * math.Pi / twoPiDivision))))
+ case "X" =>
+ Tensor(Array(
+ Array(1 + ei(2 * angle * math.Pi / twoPiDivision), 1 - ei(2 * angle * math.Pi / twoPiDivision)),
+ Array(1 - ei(2 * angle * math.Pi / twoPiDivision), 1 + ei(2 * angle * math.Pi / twoPiDivision))))
+ }
+ }
+
+ def block(angle: Int, nodeType: String): Block = {
+ Block(1, 1, (2*angle) + nodeType + twoPiDivision, tensor(angle, nodeType),
+ QuickGraph(ZXRails)
+ .addInput().addOutput()
+ .node(nodeType = nodeType, nodeName = nodeType, angle = s"${2*angle} / $twoPiDivision")
+ .join("i-0", nodeType, Some("rail")).join(nodeType, "o-0", Some("rail")))
+ }
+
+ (0 until twoPiDivision).flatMap(i =>
+ List(block(i, "X"), block(i, "Z"))
+ )
+ }
+
+
+ val ZXClifford: List[Block] = List(
+ Block(1, 1, " 1 ", Tensor.idWires(1), QuickGraph(ZXRails).addInput().addOutput().join("i-0", "o-0", Some("rail"))),
+ zxQubitHadamard,
+ Block(2, 2, " s ", Tensor.swap(List(1, 0)),
+ QuickGraph(ZXRails).addInput(2).addOutput(2).join("i-0", "o-1", Some("rail")).join("i-1", "o-0", Some("rail"))),
+ Block(1, 1, "gpi", Tensor(Array(Array(1, 0), Array(0, -1))), QuickGraph(ZXRails).addInput().addOutput()
+ .node("Z", nodeName = "zpi", angle = raw"\pi").join("i-0", "zpi", Some("rail")).join("zpi", "o-0", Some("rail"))),
+ Block(1, 1, "rpi", Tensor(Array(Array(0, 1), Array(1, 0))), QuickGraph(ZXRails).addInput().addOutput()
+ .node("X", nodeName = "xpi", angle = raw"\pi").join("i-0", "xpi", Some("rail")).join("xpi", "o-0", Some("rail"))),
+ Block(1, 1, "gp2", Tensor(Array(Array(Complex(1, 0), Complex(0, 0)), Array(Complex(0, 0), Complex(0, 1)))),
+ QuickGraph(ZXRails).addInput().addOutput()
+ .node("Z", nodeName = "z", angle = raw"\pi / 2").join("i-0", "z", Some("rail")).join("z", "o-0", Some("rail"))),
+ Block(1, 1, "rp2", Tensor(Array(Array(Complex(1, 1), Complex(1, -1)), Array(Complex(1, -1), Complex(1, 1)))),
+ QuickGraph(ZXRails).addInput().addOutput()
+ .node("X", nodeName = "x", angle = raw"\pi / 2").join("i-0", "x", Some("rail")).join("x", "o-0", Some("rail"))),
+ Block(2, 2, "CNT", Tensor(Array(Array(1, 0, 0, 0), Array(0, 1, 0, 0), Array(0, 0, 0, 1), Array(0, 0, 1, 0))),
+ QuickGraph(ZXRails).addInput(2).addOutput(2).node("Z", nodeName = "z").node("X", xCoord = 1, nodeName = "x")
+ .join("i-0", "z", Some("rail")).join("i-1", "x", Some("rail"))
+ .join("o-0", "z", Some("rail")).join("o-1", "x", Some("rail"))
+ .join("z", "x", Some("string")))
+ )
+
+
+ val ZXCNOT: List[Block] = List(
+ Block(1, 1, " 1 ", Tensor.idWires(1), QuickGraph(ZXRails).addInput().addOutput().join("i-0", "o-0", Some("rail"))),
+ Block(2, 2, " s ", Tensor.swap(List(1, 0)),
+ QuickGraph(ZXRails).addInput(2).addOutput(2).join("i-0", "o-1", Some("rail")).join("i-1", "o-0", Some("rail"))),
+ Block(2, 2, "CNT", Tensor(Array(Array(1, 0, 0, 0), Array(0, 1, 0, 0), Array(0, 0, 0, 1), Array(0, 0, 1, 0))),
+ QuickGraph(ZXRails).addInput(2).addOutput(2).node("Z", nodeName = "z").node("X", xCoord = 1, nodeName = "x")
+ .join("i-0", "z", Some("rail")).join("i-1", "x", Some("rail"))
+ .join("o-0", "z", Some("rail")).join("o-1", "x", Some("rail"))
+ .join("z", "x", Some("string")))
+ )
+
+ val BellTeleportation: List[Block] = List(
+ Block(List(0), List(0), " A ", Tensor.id(2)),
+ Block(List(-1), List(-1), " B ", Tensor.id(2)),
+ Block(List(0, 0), List(1), " m1", Tensor(Array(Array(1, 0, 0, 1)))),
+ Block(List(0, 0), List(2), " m2", Tensor(Array(Array(1, 0, 0, -1)))),
+ Block(List(0, 0), List(3), " m3", Tensor(Array(Array(0, 1, 1, 0)))),
+ Block(List(0, 0), List(4), " m4", Tensor(Array(Array(0, 1, -1, 0)))),
+ Block(List(-1, 1), List(-1), " c1", Tensor(Array(Array(1, 0), Array(0, 1)))),
+ Block(List(-1, 2), List(-1), " c2", Tensor(Array(Array(1, 0), Array(0, -1)))),
+ Block(List(-1, 3), List(-1), " c3", Tensor(Array(Array(0, 1), Array(1, 0)))),
+ Block(List(-1, 4), List(-1), " c4", Tensor(Array(Array(0, 1), Array(-1, 0)))),
+ Block(List(), List(0, -1), " p ", Tensor(Array(Array(1, 0, 0, 1))).transpose)
+ ) :::
+ swapQuantumClassical(List(0, -1), Tensor.id(2), List(1, 2, 3, 4)) :::
+ makeClassicalIdentites(List(1, 2, 3, 4))
+ val ZW: List[Block] = List(
+ // BOTTOM TO TOP!
+ Block(1, 1, " 1 ", Tensor.idWires(1)),
+ Block(2, 2, " s ", Tensor.swap(List(1, 0))),
+ Block(2, 2, "crs", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 1, 0), Array(0, 1, 0, 0), Array(0, 0, 0, -1)))),
+ Block(0, 2, "cup", Tensor(Array(Array(1, 0, 0, 1))).transpose),
+ Block(2, 0, "cap", Tensor(Array(Array(1, 0, 0, 1)))),
+ Block(1, 1, " w ", Tensor(Array(Array(1, 0), Array(0, -1)))),
+ Block(1, 2, "1w2", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 0, -1))).transpose),
+ Block(2, 1, "2w1", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 0, -1)))),
+ Block(1, 1, " b ", Tensor(Array(Array(0, 1), Array(1, 0)))),
+ Block(1, 2, "1b2", Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 1))).transpose),
+ Block(2, 1, "2b1", Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 1))))
+ )
+
+ def swapQuantumClassical(listQuantum: List[Int], quantumTensor: Tensor, listClassical: List[Int]): List[Block] = {
+ (for (w1 <- listQuantum; w2 <- listClassical) yield {
+ List(Block(List(w1, w2), List(w2, w1), w1 + "s" + w2, quantumTensor),
+ Block(List(w2, w1), List(w1, w2), w2 + "s" + w1, quantumTensor))
+ }).flatten
+ }
+
+ def swapQuantumQuantum(listQuantum: List[Int], quantumTensor: Tensor): List[Block] = {
+ (for (w1 <- listQuantum; w2 <- listQuantum) yield {
+ List(Block(List(w1, w2), List(w2, w1), w1 + "s" + w2, quantumTensor),
+ Block(List(w2, w1), List(w1, w2), w2 + "s" + w1, quantumTensor))
+ }).flatten
+ }
+
+ def makeClassicalIdentites(listClassical: List[Int]): List[Block] = {
+ for (w <- listClassical) yield {
+ Block(List(w), List(w), "w" + w + " ", Tensor.id(1))
+ }
+ }
+
+ // Traditionally the number of angles is 3 (Clifford) or 9 (Clifford+T)
+ def ZXQudit(dimension: Int, numAngles: Int): List[Block] = {
+ require(dimension > 1)
+
+ def swapIndex(i: Int): Int = {
+ val left: Int = i / dimension
+ val right = i % dimension
+ right * dimension + left
+ }
+
+ val H: Tensor = Hadamard(dimension)
+
+ val greenFork = Tensor(dimension, dimension * dimension,
+ (i, j) => if (j == i * (dimension + 1)) Complex.one else Complex.zero)
+
+ // Go through the diagonal entries creating all the different spiders
+ val greenBlocks = (1 until dimension).foldLeft(
+ List(
+ Block(1, 1, "g", Tensor(dimension, dimension, (i, j) => if (i == 0 && j == 0) Complex.one else Complex.zero))
+ )
+ )((lb, i) => lb.flatMap(b => (0 until numAngles).map(x =>
+ Block(1, 1, b.name + "|" + x, b.tensor + Tensor(dimension, dimension, (j, k) =>
+ if (j == i && k == i) ei(x * 2 * math.Pi / numAngles) else Complex.zero)
+ ))))
+
+ List(
+ Block(1, 1, " 1 ", Tensor.id(dimension)),
+ Block(2, 2, " s ", Tensor.permutation((0 until dimension * dimension).toList.map(x => swapIndex(x)))),
+ Block(1, 1, " H ", H),
+ Block(1, 1, " H'", H.dagger),
+ Block(2, 1, "2g1", greenFork),
+ Block(1, 2, "1g2", greenFork.dagger),
+ Block(0, 1, "gu ", Tensor(dimension, 1, (_, _) => Complex.one).scaled(1.0 / math.sqrt(dimension))),
+ Block(1, 2, "1r2", (H.dagger o greenFork o (H x H)).dagger),
+ Block(2, 1, "2r1", H.dagger o greenFork o (H x H)),
+ Block(0, 1, "ru ", Tensor(dimension, 1, (i, j) => if (i == 0 && j == 0) Complex.one else Complex.zero))
+ ) ::: greenBlocks ::: greenBlocks.map(b =>
+ Block(1, 1, "r" + b.name.tail, H.dagger o b.tensor o H)
+ )
+ }
+
+ def Hadamard(dimension: Int): Tensor =
+ Tensor(dimension, dimension, (i, j) => ei(2 * math.Pi * i * j / dimension)).scaled(1 / math.sqrt(dimension))
+
+ // Traditionally the number of angles is 3 (Clifford) or 9 (Clifford+T)
+ def ZXQutrit(numAngles: Int = 9): List[Block] = {
+ val H3 = Hadamard(3)
+ List(
+ Block(1, 1, " 1 ", Tensor.id(3)),
+ Block(2, 2, " s ", Tensor.permutation(List(0, 3, 6, 1, 4, 7, 2, 5, 8))),
+ Block(1, 1, " H ", H3),
+ Block(1, 1, " H'", H3.dagger),
+ Block(2, 1, "2g1", Tensor(Array(
+ Array(1, 0, 0, 0, 0, 0, 0, 0, 0),
+ Array(0, 0, 0, 0, 1, 0, 0, 0, 0),
+ Array(0, 0, 0, 0, 0, 0, 0, 0, 1)
+ ))),
+ Block(1, 2, "1g2", Tensor(Array(
+ Array(1, 0, 0, 0, 0, 0, 0, 0, 0),
+ Array(0, 0, 0, 0, 1, 0, 0, 0, 0),
+ Array(0, 0, 0, 0, 0, 0, 0, 0, 1)
+ )).transpose),
+ Block(0, 1, "gu ", Tensor(Array(Array(1, 1, 1))).scaled(1.0 / math.sqrt(3)).transpose),
+ Block(1, 2, "1r2", (H3.dagger o Tensor(Array(
+ Array(1, 0, 0, 0, 0, 0, 0, 0, 0),
+ Array(0, 0, 0, 0, 1, 0, 0, 0, 0),
+ Array(0, 0, 0, 0, 0, 0, 0, 0, 1)
+ )) o (H3 x H3)).dagger),
+ Block(2, 1, "2r1", H3.dagger o Tensor(Array(
+ Array(1, 0, 0, 0, 0, 0, 0, 0, 0),
+ Array(0, 0, 0, 0, 1, 0, 0, 0, 0),
+ Array(0, 0, 0, 0, 0, 0, 0, 0, 1)
+ )) o (H3 x H3)),
+ Block(0, 1, "ru ", Tensor(Array(Array(1, 0, 0))).transpose)
+ ) :::
+ (for (i <- 0 until numAngles; j <- 0 until numAngles) yield {
+ val gs = Tensor(Array(
+ Array[Complex](1, 0, 0),
+ Array[Complex](0, ei(i * 2 * math.Pi / numAngles), 0),
+ Array[Complex](0, 0, ei(j * 2 * math.Pi / numAngles))
+ ))
+ List(Block(1, 1, "g|" + i.toString + "|" + j.toString, gs),
+ Block(1, 1, "r|" + i.toString + "|" + j.toString, H3.dagger o gs o H3))
+ }).flatten.toList
+ }
+
+ def ZXGates(numAngles: Int = 8, CNOTWidth: Int = 2): List[Block] = List(
+ Block(1, 1, " 1 ", Tensor.idWires(1), QuickGraph(ZXRails).addInput().addOutput().join("i-0", "o-0")),
+ Block(2, 2, " s ", Tensor.swap(List(1, 0)))
+ //, zxQubitHadamard
+ ) :::
+ zxQubitTwists(numAngles).toList :::
+ zxCNOTs(CNOTWidth) :::
+ zxTONCs(CNOTWidth)
+
+ private def ei(angle: Double) = Complex(math.cos(angle), math.sin(angle))
+
+ def Bian2Qubit: List[Block] = List(
+ // Block(0, 0, " w ", Tensor.id(1).scaled(ei(math.Pi / 4))), Ignored for now.
+ Block(1, 1, " 1 ", Tensor.id(2)),
+ Block(2, 2, " Zc", Tensor.diagonal(Array(Complex.one, Complex.one, Complex.one, Complex.zero.-(Complex.one)))),
+ Block(1, 1, " T ", Tensor.diagonal(Array(Complex.one, ei(math.Pi / 4)))),
+ Block(1, 1, " H ", Tensor(Array(Array(1, 1), Array(1, -1))).scaled(1.0 / math.sqrt(2))),
+ Block(1, 1, " S ", Tensor.diagonal(Array(Complex.one, ei(math.Pi / 2))))
+ )
+
+
+ def Bian2QubitToGraph(block: Block): Graph = {
+
+ // The graph produced must be 0-indexed on inputs and outputs, and of the form /i\d+/ and /o\d+/
+
+ implicit def vname(str: String): VName = VName(str)
+
+ implicit def vnamepair(p: (String, String)): (VName, VName) = VName(p._1) -> VName(p._2)
+
+ implicit def ename(str: String): EName = EName(str)
+
+ val rg = Theory.fromFile("red_green")
+
+ var g = new Graph()
+
+ var eCount = 0
+
+ def join(v0: String, v1: String): Unit = {
+ g = g.addEdge(g.edges.fresh, UndirEdge(), vnamepair(v0,v1))
+ eCount += 1
+ }
+
+ def addVertex(name: String, data: VData): Unit = {
+ g = g.addVertex(vname(name), data)
+ }
+
+ for (i <- block.inputs.indices) {
+ addVertex("i" + i, WireV())
+ }
+ for (i <- block.outputs.indices) {
+ addVertex("o" + i, WireV())
+ }
+ block.name match {
+ case " 1 " =>
+ join("i0", "o0")
+ case " T " =>
+ addVertex("v0", NodeV(data = JsonObject("type" -> "X", "value" -> "pi/4"), theory = rg))
+ join("i0", "v0")
+ join("v0", "o0")
+ case " H " =>
+ addVertex("v0", NodeV(data = JsonObject("type" -> "hadamard", "value" -> "0"), theory = rg))
+ join("i0", "v0")
+ join("v0", "o0")
+ case " S " =>
+ addVertex("v0", NodeV(data = JsonObject("type" -> "X", "value" -> "pi/2"), theory = rg))
+ join("i0", "v0")
+ join("v0", "o0")
+ case " Zc" =>
+ addVertex("v0", NodeV(data = JsonObject("type" -> "X", "value" -> "0"), theory = rg))
+ addVertex("v1", NodeV(data = JsonObject("type" -> "Z", "value" -> "0"), theory = rg))
+ join("v0", "v1")
+ join("i0", "v0")
+ join("v0", "o0")
+ join("i1", "v1")
+ join("v1", "o1")
+ }
+
+ g
+ }
+
+}
diff --git a/scala/src/main/scala/quanto/cosy/CoSyRun.scala b/scala/src/main/scala/quanto/cosy/CoSyRun.scala
new file mode 100644
index 00000000..5518a29c
--- /dev/null
+++ b/scala/src/main/scala/quanto/cosy/CoSyRun.scala
@@ -0,0 +1,426 @@
+package quanto.cosy
+
+import java.io.File
+import java.util.Calendar
+
+import quanto.cosy.Interpreter.{ZXAngleData, interpretZXSpider}
+import quanto.data.Theory.{ValueType, VertexDesc}
+import quanto.data._
+import quanto.rewrite.{Match, Matcher}
+import quanto.util.FileHelper._
+import quanto.util.json.{Json, JsonObject}
+import quanto.util.{FileHelper, Rational, UserAlerts}
+
+import scala.concurrent.duration.Duration
+import scala.util.matching.Regex
+
+/**
+ * This class performs the actual batch conjecture synthesis
+ */
+abstract class CoSyRun[S, T](
+ rulesDir: File,
+ theory: Theory,
+ duration: Duration,
+ outputDir: Option[File]
+ ) {
+
+ val Generator: Iterator[S]
+ var reductionRules: List[Rule] = List()
+ var equivClasses: Map[T, Graph] = Map()
+
+ // Turn your generator into a graph
+ def makeGraph(gen: S): Graph
+
+ // Turn your generator into a tensor
+ def makeTensor(gen: S): T
+
+ // If you want to check for isomorphisms specify a regex to match the boundaries here
+ val matchBorders: Option[Regex]
+
+ def checkIsomorphic(graph1: Graph, graph2: Graph): Boolean =
+ GraphAnalysis.checkIsomorphic(theory, Some(matchBorders.getOrElse("".r)))(graph1,graph2)
+
+ // Positive if left bigger than right:
+ def compareGraph(left: Graph, right: Graph): Int
+
+ // See if two tensors should be considered equivalent
+ // e.g. isRoughly or isRoughlyUpToScalar
+ def compareTensor(a: T, b: T): Boolean
+
+ def findClassesCloseTo(tensor: T): Map[T, Graph] = equivClasses.filter(tv => compareTensor(tv._1, tensor))
+
+ // How to store the values in values.txt
+ def makeString(a: S, b: T): String
+
+ // what to do with graphs that aren't hit by reduction rules
+ // e.g. for circuits put that circuit into the pile to be added to for next iteration
+ def doWithUnmatched(a: S): Unit
+
+ // Core loop.
+ // Comes with option of time restriction.
+ def begin(): List[Rule] = {
+ def now(): Long = Calendar.getInstance().getTimeInMillis
+
+ val timeStart = now()
+ while (Duration(now() - timeStart, "millis") < duration && Generator.hasNext) {
+ // Get a graph
+ val next: S = Generator.next()
+ val nextGraph = makeGraph(next)
+
+
+ /*
+// Print out each graph made
+ if (outputDir.nonEmpty) {
+ FileHelper.printToFile(
+ outputDir.get.toURI.resolve("./" + nextGraph.hashCode + ".qgraph"),
+ Graph.toJson(nextGraph).toString,
+ append = true)
+ }
+*/
+
+ val matchesReductionRule = reductionRules.exists(rule => Matcher.findMatches(rule.lhs, nextGraph).nonEmpty)
+
+ // Check if it can be reduced by known rules
+ if (!matchesReductionRule) {
+ val interpretation = makeTensor(next)
+ val nearbyClasses = findClassesCloseTo(interpretation)
+
+ nearbyClasses.size match {
+ case 0 =>
+ // doesn't fit into any existing class
+ doWithUnmatched(next)
+ equivClasses = equivClasses + (interpretation -> nextGraph)
+ updateValuesFile(next, interpretation)
+ case 1 =>
+ // Something with that tensor exists
+ val equivClass = nearbyClasses.head
+ val existing: Graph = equivClass._2
+
+
+ // Don't create a rule between isomorphic (constrained at the boundary) graphs
+ val isomorphic = if (matchBorders.nonEmpty) {
+ checkIsomorphic(nextGraph, existing)
+ } else {
+ false
+ }
+
+
+ if (!isomorphic) {
+ doWithUnmatched(next)
+ if (compareGraph(existing, nextGraph) > 0) {
+ // update class with smaller graph
+ equivClasses = equivClasses + (equivClass._1 -> nextGraph)
+ createRule(existing, nextGraph)
+ } else {
+ // new graph is at least as big as current
+ // keep status quo, create rule
+ createRule(nextGraph, existing)
+ }
+ }
+
+ case n =>
+ // Somehow in the approximation radius of two classes
+ // For now, add to both.
+ for (equivClass <- nearbyClasses) {
+ // Something with that tensor exists
+ val existing: Graph = equivClass._2
+
+
+ val isomorphic = if (matchBorders.nonEmpty) {
+ checkIsomorphic(nextGraph, existing)
+ } else {
+ false
+ }
+
+
+ if (!isomorphic) {
+ doWithUnmatched(next)
+ createRule(nextGraph, existing)
+ // update class with smaller graph
+ if (compareGraph(existing, nextGraph) > 0) {
+ equivClasses = equivClasses + (equivClass._1 -> nextGraph)
+ }
+ } else {
+ // Don't create a rule between isomorphic (constrained) graphs
+ }
+ }
+ }
+ }
+ }
+ reductionRules
+ }
+
+ private def updateValuesFile(next: S, interpretation: T): Unit = {
+ if (outputDir.nonEmpty) {
+ FileHelper.printToFile(
+ outputDir.get.toURI.resolve("./values.txt"),
+ makeString(next, interpretation),
+ append = true)
+ }
+ }
+
+ def createRule(lhs: Graph, rhs: Graph): Rule = {
+ val name = s"${lhs.hashCode}_${rhs.hashCode}.qrule"
+ val r = new Rule(lhs, rhs, derivation = Some("CoSy"), description = RuleDesc(name))
+ if (outputDir.nonEmpty) {
+ printJson(outputDir.get.toURI.resolve("./" + name).getPath, Rule.toJson(r, lhs.data.theory))
+ }
+ loadRule(r)
+ r
+ }
+
+ def loadRule(rule: Rule): Unit = {
+ def reduceRules(rules: List[Rule]) : List[Rule] = {
+ RuleSynthesis.greedyReduceRules(compareGraph, Some((theory, matchBorders)))(rules).filter(
+ rule => !checkIsomorphic(rule.lhs, rule.rhs)
+ )
+ }
+ // Please don't put bbox rules into here unless you really mean them to be here and they reduce left->right
+ if (rule.lhs.bboxes.nonEmpty) {
+ reductionRules = rule :: reductionRules
+ } else {
+ // No bboxes, act normally
+ if (compareGraph(rule.lhs, rule.rhs) > 0) {
+ reductionRules = rule :: reductionRules
+ reductionRules = reduceRules(reductionRules)
+ } else if (compareGraph(rule.rhs, rule.lhs) > 0) {
+ reductionRules = rule.inverse :: reductionRules
+ reductionRules = reduceRules(reductionRules)
+ } else {
+ // Not a reduction rule, so leave it out
+ }
+ }
+ }
+
+ FileHelper.readAllOfType(rulesDir.getAbsolutePath, ".*qrule", Rule.fromJson(_, theory)).foreach(loadRule)
+}
+
+object CoSyRuns {
+
+
+ class CoSyCircuit(rulesDir: File,
+ theory: Theory,
+ duration: Duration,
+ outputDir: Option[File],
+ numBoundaries: Int
+ ) extends CoSyRun[BlockStack, Tensor](rulesDir, theory, duration, outputDir) {
+
+
+ // Include the empty diagram
+ // equivClasses = equivClasses + (Tensor(Array(Array(1))) -> new Graph())
+
+ override val Generator: Iterator[BlockStack] = new Iterator[BlockStack] {
+
+ override def hasNext: Boolean = unusedStacks.hasNext || unusedRows.hasNext || nextRoundOfStacks.nonEmpty
+
+ override def next(): BlockStack = {
+ if (!unusedRows.hasNext) {
+ if (!unusedStacks.hasNext) {
+ unusedStacks = nextRoundOfStacks.toIterator
+ nextRoundOfStacks = List()
+ UserAlerts.alert("Finished next row")
+ }
+ unusedRows = rows.toIterator
+ currentStack = unusedStacks.next()
+ }
+ currentStack.append(unusedRows.next())
+ }
+ }
+ override val matchBorders = Some(raw"""(i|o)-(\d+)""".r)
+ val blocks: List[Block] = BlockGenerators.ZXGates(4, numBoundaries)
+ UserAlerts.alert(s"Created ${blocks.length} blocks")
+ val rows: List[BlockRow] = BlockRowMaker.makeRowsUpToSize(numBoundaries, blocks, Some(numBoundaries))
+ .filter(br => br.inputs.size == numBoundaries && br.outputs.size == numBoundaries)
+ UserAlerts.alert(s"Created ${rows.length} rows")
+ var unusedRows: Iterator[BlockRow] = rows.toIterator
+ var unusedStacks: Iterator[BlockStack] = rows.map(r => BlockStack(List(r))).toIterator
+ var nextRoundOfStacks: List[BlockStack] = List()
+ var currentStack: BlockStack = unusedStacks.next()
+
+ override def doWithUnmatched(a: BlockStack): Unit = {
+ nextRoundOfStacks = a :: nextRoundOfStacks
+ }
+
+
+ override def compareTensor(a: Tensor, b: Tensor): Boolean = a.isRoughlyUpToScalar(b)
+
+ override def compareGraph(left: Graph, right: Graph) : Int = GraphAnalysis.zxCircuitCompare(left, right)
+
+ override def makeTensor(gen: BlockStack): Tensor = gen.tensor
+
+ override def makeString(a: BlockStack, b: Tensor): String = s"$a: ${b.toJson},"
+
+ override def makeGraph(gen: BlockStack): Graph = {
+ val g = gen.graph.minimise
+ val IOPattern = raw"r-\d+-(i|o)-(\d+)".r
+ val renameMap: Map[VName, VName] = g.verts.map(vn => vn -> (vn.s match {
+ case IOPattern(io, n) => VName(io + "-" + n)
+ case _ => vn
+ })).toMap
+ g.rename(renameMap)
+ }
+
+ }
+
+ class CoSyZX(rulesDir: File,
+ theory: Theory,
+ duration: Duration,
+ outputDir: Option[File],
+ numAngles: Int,
+ numBoundaries: List[Int],
+ numVertices: Int,
+ scalars: Boolean
+ ) extends CoSyRun[AdjMat, Tensor](rulesDir, theory, duration, outputDir) {
+
+
+ override val Generator: Iterator[AdjMat] = {
+ val identitiesFirst = ColbournReadEnum.enumerate(1, 1, numBoundaries.max, 0).
+ filter(a => numBoundaries.contains(a.numBoundaries))
+
+ val CR = ColbournReadEnum.enumerate(numAngles, numAngles, numBoundaries.max, numVertices)
+
+ UserAlerts.alert(s"CoSy: Finished Colbourn-Read (${CR.size})")
+
+ val CRScalars = if(scalars) CR else CR.filter(adj => !GraphAnalysis.containsScalars(adj))
+
+ UserAlerts.alert(s"CoSy: Filtered out scalars (${CRScalars.size})")
+
+ val CRScalarsSorted = CRScalars.sortBy(_.size)
+
+ UserAlerts.alert("CoSy: Sorted AdjMats")
+
+ val combined = identitiesFirst.iterator ++
+ CRScalarsSorted.iterator.filter(a => numBoundaries.contains(a.numBoundaries))
+
+ combined
+ }
+
+ private val gdata = (for (i <- 0 until numAngles) yield {
+ NodeV(data = JsonObject("type" -> "Z", "value" -> angleMap(i).toString), theory = theory)
+ }).toVector
+ private val rdata = (for (i <- 0 until numAngles) yield {
+ NodeV(data = JsonObject("type" -> "X", "value" -> angleMap(i).toString), theory = theory)
+ }).toVector
+
+ override def compareTensor(a: Tensor, b: Tensor): Boolean = if (!scalars) {
+ a.isRoughlyUpToScalar(b)
+ } else {
+ a.isRoughly(b)
+ }
+
+ override val matchBorders = None
+
+ override def doWithUnmatched(a: AdjMat): Unit = {
+ // Don't need to do anything, since Colbourn-Read handles generating adj-mats
+ }
+
+ override def makeTensor(gen: AdjMat): Tensor = {
+ val asGraph = makeGraph(gen)
+ Interpreter.interpretZXGraph(asGraph, asGraph.verts.filter(asGraph.isTerminalWire).toList.sortBy(_.s), List())
+ }
+
+ override def makeGraph(gen: AdjMat): Graph = Graph.fromAdjMat(gen, rdata, gdata)
+
+ override def makeString(a: AdjMat, b: Tensor): String = {
+ s"adj${a.hash}: ${b.toJson},"
+ }
+
+ override def compareGraph(left: Graph, right: Graph) : Int = GraphAnalysis.zxGraphCompare(left, right)
+
+ private def angleMap = (x: Int) => PhaseExpression(new Rational(2 * x, numAngles), ValueType.AngleExpr)
+
+
+ }
+
+ class CoSyZXBool(rulesDir: File,
+ theory: Theory,
+ duration: Duration,
+ outputDir: Option[File],
+ numAngles: Int,
+ numBoundaries: List[Int],
+ numVertices: Int,
+ scalars: Boolean
+ ) extends CoSyRun[AdjMat, Tensor](rulesDir, theory, duration, outputDir) {
+
+
+ override val Generator: Iterator[AdjMat] =
+ (ColbournReadEnum.enumerate(1, 1, numBoundaries.max, 0).iterator ++
+ ColbournReadEnum.enumerate(2*numAngles, 2*numAngles, numBoundaries.max, numVertices).iterator).
+ filter(a => numBoundaries.contains(a.numBoundaries))
+
+
+
+ private val gdata = (for (i <- 0 until 2*numAngles) yield {
+ NodeV(data = JsonObject("type" -> "Z", "value" -> angleMap(i).toString), theory = theory)
+ }).toVector
+ private val rdata = (for (i <- 0 until 2*numAngles) yield {
+ NodeV(data = JsonObject("type" -> "X", "value" -> angleMap(i).toString), theory = theory)
+ }).toVector
+
+ override def compareTensor(a: Tensor, b: Tensor): Boolean = if (!scalars) {
+ a.isRoughlyUpToScalar(b)
+ } else {
+ a.isRoughly(b)
+ }
+
+ override val matchBorders = None
+
+ private implicit def stringToPhase(s: String): PhaseExpression = {
+ CompositeExpression.parseKnowingTypes(s, Vector(ValueType.AngleExpr, ValueType.Boolean))(1)
+ }
+
+ override def doWithUnmatched(a: AdjMat): Unit = {
+ // Don't need to do anything, since Colbourn-Read handles generating adj-mats
+ }
+
+ override def makeTensor(gen: AdjMat): Tensor = {
+ val asGraph = makeGraph(gen)
+
+ def spiderInterpreter(vdata: NodeV, inputs: Int, outputs: Int): Tensor = {
+
+ val zxData: ZXAngleData = {
+ val isGreen = vdata.typ == "Z"
+ val angle = CompositeExpression.parseKnowingTypes(vdata.value, Vector(ValueType.AngleExpr, ValueType.Boolean))
+ if (angle(1).constant == Rational(1, 1)) {
+ ZXAngleData(isGreen, angle(0))
+ } else {
+ ZXAngleData(isGreen, PhaseExpression(new Rational(0, 1), ValueType.AngleExpr))
+ }
+ }
+
+ interpretZXSpider(zxData, inputs, outputs)
+ }
+
+
+ Interpreter.interpretSpiderGraph(spiderInterpreter)(asGraph, asGraph.verts.filter(asGraph.isTerminalWire).toList.sortBy(_.s), List())
+ }
+
+ override def makeGraph(gen: AdjMat): Graph = Graph.fromAdjMat(gen, rdata, gdata)
+
+ override def makeString(a: AdjMat, b: Tensor): String = {
+ s"adj${a.hash}: ${b.toJson},"
+ }
+
+
+ //TODO: Graph comparison for ZXBool
+ override def compareGraph(left: Graph, right: Graph) : Int = RuleSynthesis.basicGraphComparison(left, right)
+
+ private def angleMap(x: Int): CompositeExpression =
+ if (x < numAngles) {
+ CompositeExpression(Vector(ValueType.AngleExpr, ValueType.Boolean),
+
+ Vector(PhaseExpression(new Rational(2 * x, numAngles), ValueType.AngleExpr),
+ PhaseExpression(new Rational(1, 1), ValueType.Boolean)))
+
+ } else {
+ CompositeExpression(Vector(ValueType.AngleExpr, ValueType.Boolean),
+
+ Vector(PhaseExpression(new Rational(2 * (x-numAngles), numAngles), ValueType.AngleExpr),
+ PhaseExpression(new Rational(0, 1), ValueType.Boolean)))
+
+ }
+
+
+ }
+}
+
diff --git a/scala/src/main/scala/quanto/cosy/ColbournReadEnum.scala b/scala/src/main/scala/quanto/cosy/ColbournReadEnum.scala
index dd25db3d..520c9bc2 100644
--- a/scala/src/main/scala/quanto/cosy/ColbournReadEnum.scala
+++ b/scala/src/main/scala/quanto/cosy/ColbournReadEnum.scala
@@ -34,21 +34,20 @@ case class AdjMat(numRedTypes: Int,
mat: Vector[Vector[Boolean]] = Vector())
extends Ordered[AdjMat] {
lazy val size: Int = mat.length
- lazy val numRed : Int = red.sum
- lazy val numGreen : Int = green.sum
- lazy val hash : String = makeHash()
- def toJson : JsonObject = JsonObject("hash" -> hash)
+ lazy val numRed: Int = red.sum
+ lazy val numGreen: Int = green.sum
+ lazy val hash: String = makeHash()
lazy val vertexColoursAndTypes: List[(VertexColour.EnumVal, Int)] = {
var _vertexColoursAndTypes: List[(VertexColour.EnumVal, Int)] = List()
var colCount = 0
var angleTypeCount = 0
- for (i <- 0 until numBoundaries) {
+ for (_ <- 0 until numBoundaries) {
_vertexColoursAndTypes = (VertexColour.Boundary, 0) :: _vertexColoursAndTypes
colCount += 1
}
for (j <- red) {
- for (i <- 0 until j) {
+ for (_ <- 0 until j) {
_vertexColoursAndTypes = (VertexColour.Red, angleTypeCount) :: _vertexColoursAndTypes
colCount += 1
}
@@ -57,7 +56,7 @@ case class AdjMat(numRedTypes: Int,
angleTypeCount = 0
for (j <- green) {
- for (i <- 0 until j) {
+ for (_ <- 0 until j) {
_vertexColoursAndTypes = (VertexColour.Green, angleTypeCount) :: _vertexColoursAndTypes
colCount += 1
}
@@ -66,6 +65,8 @@ case class AdjMat(numRedTypes: Int,
_vertexColoursAndTypes.reverse
}
+ def toJson: JsonObject = JsonObject("hash" -> hash)
+
// advance to the next type of vertex added by the addVertex method. The order is boundaries,
// then each red type, then each green type.
def nextType: Option[AdjMat] = {
@@ -76,7 +77,7 @@ case class AdjMat(numRedTypes: Int,
// This method grows the adjacency matrix by adding a new boundary, red node, or green node, with the given
// vector of edges.
- def addVertex(connection: Vector[Boolean]) : AdjMat = {
+ def addVertex(connection: Vector[Boolean]): AdjMat = {
if (red.isEmpty && green.isEmpty) { // new vertex is a boundary
copy(numBoundaries = numBoundaries + 1, mat = growMatrix(connection))
} else if (red.nonEmpty && green.isEmpty) { // new vertex is a red node
@@ -110,15 +111,6 @@ case class AdjMat(numRedTypes: Int,
// a matrix is canonical if it is lexicographically smaller than any vertex permutation
def isCanonical(permuteBoundary: Boolean = false): Boolean = validPerms(permuteBoundary).forall { p => compareWithPerm(p) <= 0 }
- // compare this matrix with itself, but with the rows and columns permuted according to "perm"
- def compareWithPerm(perm: Vector[Int]): Int = {
- for (i <- 0 until size)
- for (j <- 0 to i)
- if (mat(i)(j) < mat(perm(i))(perm(j))) return -1
- else if (mat(i)(j) > mat(perm(i))(perm(j))) return 1
- 0
- }
-
// return all the vertex-permutations which respect type and keep boundary fixed
def validPerms(permuteBoundary: Boolean): Vector[Vector[Int]] = {
var idx = numBoundaries
@@ -186,6 +178,15 @@ case class AdjMat(numRedTypes: Int,
false
}
+ // compare this matrix with itself, but with the rows and columns permuted according to "perm"
+ def compareWithPerm(perm: Vector[Int]): Int = {
+ for (i <- 0 until size)
+ for (j <- 0 to i)
+ if (mat(i)(j) < mat(perm(i))(perm(j))) return -1
+ else if (mat(i)(j) > mat(perm(i))(perm(j))) return 1
+ 0
+ }
+
// returns true if all boundaries are connected to something
def isComplete: Boolean = (0 until numBoundaries).forall(i => mat(i).contains(true))
@@ -248,7 +249,7 @@ object AdjMat {
case _ => Vector(Vector())
}
- def fromHash(hash: String): AdjMat = {
+ def fromHash(hash: String): AdjMat = {
// "boundaries.red1-red2.green1-green2.matBase36"
val dotChunk = hash.split("\\.")
val numBoundaries = dotChunk(0).toInt
@@ -256,10 +257,10 @@ object AdjMat {
val green: Vector[Int] = dotChunk(2).split("-").map(a => a.toInt).toVector
val size = numBoundaries + red.sum + green.sum
val longMatStringUnpadded = java.lang.Long.toString(java.lang.Long.parseLong(dotChunk(3), 36), 2)
- val longMatString = (1 to size * size - longMatStringUnpadded.length).foldLeft("") {(a,b) => "0" + a} +
+ val longMatString = (1 to size * size - longMatStringUnpadded.length).foldLeft("") { (a, _) => "0" + a } +
longMatStringUnpadded
val longMatVec = longMatString.map(x => x == '1').toVector
- val mat = if(size > 0) longMatVec.grouped(size).toVector else List().toVector
+ val mat = if (size > 0) longMatVec.grouped(size).toVector else List().toVector
new AdjMat(red.length, green.length, numBoundaries, red, green, mat)
}
}
diff --git a/scala/src/main/scala/quanto/cosy/EQCAnalysis.scala b/scala/src/main/scala/quanto/cosy/EQCAnalysis.scala
index 24028e90..5804cf9f 100644
--- a/scala/src/main/scala/quanto/cosy/EQCAnalysis.scala
+++ b/scala/src/main/scala/quanto/cosy/EQCAnalysis.scala
@@ -1,7 +1,5 @@
package quanto.cosy
-import quanto.data._
-import quanto.util.Rational
/**
* Analyse the properties of equivalence classes
diff --git a/scala/src/main/scala/quanto/cosy/EquivClassBatchRunner.scala b/scala/src/main/scala/quanto/cosy/EquivClassBatchRunner.scala
index e8e4c9ce..2f83a2e3 100644
--- a/scala/src/main/scala/quanto/cosy/EquivClassBatchRunner.scala
+++ b/scala/src/main/scala/quanto/cosy/EquivClassBatchRunner.scala
@@ -8,6 +8,7 @@ import quanto.util.json.{JsonArray, JsonObject}
/**
* Created by hector on 20/06/2017.
* A wrapper object for the most common qsynth run
+ * Superseded by the CoSyRun system
*/
object EquivClassBatchRunner {
@@ -15,7 +16,7 @@ object EquivClassBatchRunner {
def apply(numAngles: Int = 8, boundaries: Int = 3, vertices: Int, outputFileName: String = "default.qrun"): Unit = {
val rg = Theory.fromFile("red_green")
- var results = EquivClassRunAdjMat(
+ val results = EquivClassRunAdjMat(
numAngles = numAngles,
tolerance = EquivClassRunAdjMat.defaultTolerance,
rulesList = List(),
@@ -26,7 +27,7 @@ object EquivClassBatchRunner {
s"ColbournRead $numAngles $numAngles $boundaries $vertices")
new File(outputPath).mkdirs()
- var testFile = new File(outputPath + "/" + outputFileName)
+ val testFile = new File(outputPath + "/" + outputFileName)
quanto.util.FileHelper.printToFile(testFile, append = false)(
p => p.println(results.toJSON.toString())
)
@@ -43,7 +44,7 @@ object TensorBatchRunner {
val rg = Theory.fromFile("red_green")
val diagramStream = ColbournReadEnum.enumerate(numAngles, numAngles, boundaries, vertices)
- var results = EquivClassRunAdjMat(
+ val results = EquivClassRunAdjMat(
numAngles = numAngles,
tolerance = EquivClassRunAdjMat.defaultTolerance,
rulesList = List(),
diff --git a/scala/src/main/scala/quanto/cosy/EquivalenceClasses.scala b/scala/src/main/scala/quanto/cosy/EquivalenceClasses.scala
index 212b93ee..78e2edf3 100644
--- a/scala/src/main/scala/quanto/cosy/EquivalenceClasses.scala
+++ b/scala/src/main/scala/quanto/cosy/EquivalenceClasses.scala
@@ -2,14 +2,13 @@ package quanto.cosy
import quanto.cosy.Interpreter._
import quanto.data._
+import quanto.util.Rational
import quanto.util.json.{Json, JsonAccessException, JsonArray, JsonObject}
-import scala.concurrent.{Await, Future}
-import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent.duration.Duration
-
/**
* Synthesises diagrams, holds the data and generates equivalence classes
+ *
+ * Everything here has essentially be superseded by the new CoSyRun system
*/
@@ -167,6 +166,11 @@ abstract class EquivClassRun[T](val tolerance: Double = 1e-14) {
this
}
+ def add(that: T, tensor: Tensor): EquivalenceClass[T] = {
+ compareAndAddToClass(that, tensor)
+ compareAndAddToClass(that, tensor, normalised = true)
+ }
+
// Finds the closest class and adds it, or creates a new class if outside tolerance
def compareAndAddToClass(candidate: T, tensor: Tensor, normalised: Boolean = false): EquivalenceClass[T] = {
@@ -190,6 +194,7 @@ abstract class EquivClassRun[T](val tolerance: Double = 1e-14) {
}
}
+ implicit def interpret(that: T): Tensor
// Returns (new Class, -1) if no EquivalenceClasses here
def closestClassTo(that: Tensor, normalised: Boolean = false): Option[(EquivalenceClass[T], Double)] = {
@@ -204,18 +209,11 @@ abstract class EquivClassRun[T](val tolerance: Double = 1e-14) {
}
- implicit def interpret(that: T): Tensor
-
def add(that: T): EquivalenceClass[T] = {
val tensor = interpret(that)
add(that, tensor)
}
- def add(that: T, tensor: Tensor): EquivalenceClass[T] = {
- compareAndAddToClass(that, tensor)
- compareAndAddToClass(that, tensor, normalised = true)
- }
-
def newClass(t: T, tensor: Tensor, normalised: Boolean): EquivalenceClass[T]
}
@@ -299,7 +297,7 @@ object EquivClassRunBlockStack {
Tensor.fromJson((js / "tensor").asObject))
)
val ecrr = new EquivClassRunBlockStack(tolerance)
- for ((adj, ten) <- results) {
+ for ((adj, _) <- results) {
ecrr.add(adj)
}
ecrr
@@ -365,7 +363,9 @@ object EquivClassRunAdjMat {
tolerance: Double,
theory: Theory,
rulesList: List[Rule]): EquivClassRunAdjMat = {
- def angleMap = (x: Int) => x * math.Pi * 2.0 / numAngles
+ def angleMap(x: Int): Rational = {
+ new Rational(x, numAngles)
+ }
val gdata = (for (i <- 0 until numAngles) yield {
NodeV(data = JsonObject("type" -> "Z", "value" -> angleMap(i).toString), theory = theory)
diff --git a/scala/src/main/scala/quanto/cosy/GraphAnalysis.scala b/scala/src/main/scala/quanto/cosy/GraphAnalysis.scala
index d62eea80..9a1f3108 100644
--- a/scala/src/main/scala/quanto/cosy/GraphAnalysis.scala
+++ b/scala/src/main/scala/quanto/cosy/GraphAnalysis.scala
@@ -1,9 +1,230 @@
package quanto.cosy
+import quanto.data.Theory.{ValueType, VertexDesc}
import quanto.data._
+import quanto.rewrite.Matcher
import quanto.util.Rational
+import quanto.util.json.{Json, JsonObject}
+
+import scala.util.matching.Regex
+import scala.util.parsing.combinator.RegexParsers
object GraphAnalysis {
+
+ def zxCircuitCompare(left: Graph, right: Graph): Int = {
+
+ // returns x where
+ // x < 0 iff this < that
+ // x == 0 iff this == that
+ // x > 0 iff this > that
+
+ // Circuit comparison of graphs
+ // Cares about T-count
+
+ val Angle = ValueType.AngleExpr
+
+ def phase(vdata: VData): PhaseExpression =
+ vdata match {
+ case NodeV(d, a, t) => vdata.asInstanceOf[NodeV].phaseData.first[PhaseExpression](ValueType.AngleExpr) match {
+ case Some(p) => p
+ case None => PhaseExpression.zero(ValueType.AngleExpr)
+ }
+ case _ => PhaseExpression.zero(ValueType.AngleExpr)
+ }
+
+
+ // Number of T-gates
+ def countT(graph: Graph): Int = graph.vdata.count(nd => {
+ val const = phase(nd._2).constant
+ // T-gate if an odd multiple of \pi/4
+ ((const.n * (const.d / 4)) % 2) == 1
+ })
+
+ val tDiff = countT(left) - countT(right)
+ if (tDiff != 0) return tDiff
+
+ // Number of nodes
+ def nodes(graph: Graph): Int = graph.vdata.size
+
+ val nodeDiff = nodes(left) - nodes(right)
+ if (nodeDiff != 0) return nodeDiff
+
+ // Number of edges
+ def edges(graph: Graph): Int = graph.edata.size
+
+ val edgeDiff = edges(left) - edges(right)
+ if (edgeDiff != 0) return edgeDiff
+
+ // Number of "Z" nodes
+ // We favour these!
+ // Purely for aesthetic reasons
+ def countZ(graph: Graph): Int = graph.vdata.count(nd => nd._2.typ == "Z")
+
+ val zDiff = countZ(left) - countZ(right)
+ if (zDiff != 0) return zDiff
+
+ // sum of the phases
+ def phaseSum(graph: Graph): Rational = graph.vdata.map(nd => phase(nd._2)).
+ foldLeft(Rational(0, 1)) { (s, a) => s + a.constant }
+
+ val phaseDiff = phaseSum(left) - phaseSum(right)
+ if (phaseDiff > 0) return 1
+ if (phaseDiff < 0) return -1
+
+ // Weighting by row
+ // example node name: r-2-bl-1-h-1
+ def positionWeighting(graph: Graph): Double = {
+ def nodeWeighting(node: NodeV): Double = {
+ (node.typ match {
+ case "X" =>
+ 2 * (1 + node.phaseData.values.head.constant)
+ case "Z" =>
+ 1 + node.phaseData.values.head.constant
+ case "hadamard" =>
+ 1
+ }) / 3 // Scale so any given node has weight at most 1, bust still > 0
+ }
+
+ graph.vdata.toList.map(nameData => {
+ nameData._2 match {
+ case v: NodeV =>
+ // Non-wire nodes are weighted biased towards the bottom left
+ val placement = CircuitPlacementParser.p(nameData._1.toString)
+ placement._1 + placement._2 + nodeWeighting(v)
+ case _ =>
+ // wires have no weight
+ 0
+ }
+ }).sum
+ }
+
+ val circuitWeightLeft = positionWeighting(left)
+ val circuitWeightRight = positionWeighting(right)
+ if (circuitWeightLeft > circuitWeightRight) {
+ return 1
+ }
+ if (circuitWeightLeft < circuitWeightRight) {
+ return -1
+ }
+
+ 0
+ }
+
+ case class CircuitPlacementParseException(input: String) extends Error
+
+ object CircuitPlacementParser extends RegexParsers {
+
+ override def skipWhitespace = true
+
+ def INT: Parser[Int] =
+ """[0-9]+""".r ^^ {
+ _.toInt
+ }
+
+ def IDENT: Parser[String] =
+ """[\\a-zA-Z_][a-zA-Z0-9_]*""".r ^^ {
+ _.toString
+ }
+
+ // example node name: r-2-bl-1-h-1
+ def expr: Parser[(Int, Int, String)] =
+ "r-" ~ INT ~ "-bl-" ~ INT ~ "-" ~ IDENT ~ "-" ~ INT ^^ { case _ ~ r ~ _ ~ bl ~ _ ~ s ~ _ ~ _ => (r, bl, s) }
+
+ def p(s: String): (Int, Int, String) = parseAll(expr, s) match {
+ case Success(e, _) => e
+ case Failure(msg, _) => throw CircuitPlacementParseException(msg)
+ case Error(msg, _) => throw CircuitPlacementParseException(msg)
+ }
+ }
+
+
+ def zxGraphCompare(left: Graph, right: Graph): Int = {
+
+ // returns x where
+ // x < 0 iff this < that
+ // x == 0 iff this == that
+ // x > 0 iff this > that
+
+
+ // Graph comparison for ZX diagrams
+ // Cares about node count, then phases
+
+ implicit def stringToPhase(s: String): PhaseExpression = {
+ PhaseExpression.parse(s, ValueType.AngleExpr)
+ }
+
+
+ // First count number of nodes
+ def nodes(graph: Graph): Int = graph.vdata.size
+
+ val node = nodes(left) - nodes(right)
+ if (node != 0) return node
+
+ // Number of edges
+ def edges(graph: Graph): Int = graph.edata.size
+
+ val edge = edges(left) - edges(right)
+ if (edge != 0) return edge
+
+ // Number of "Z" nodes
+ def countZ(graph: Graph): Int = graph.vdata.count(nd => nd._2.typ == "Z")
+
+ val zDiff = countZ(left) - countZ(right)
+ if (zDiff != 0) return zDiff
+
+ // Sum of Z angles
+ val Pi = math.Pi
+
+ def sumAngles(graph: Graph, filterType: String): Rational = graph.vdata.
+ filter(nd => nd._2.typ == filterType).
+ foldLeft(Rational(0, 1)) {
+ (angle, nd) => angle + stringToPhase(nd._2.asInstanceOf[NodeV].value).constant
+ }
+
+ // sumAngles returns a rational that is probably bigger than 2 (remember that the pi is left out)
+
+ val ZAngles: Rational = sumAngles(left, "Z") - sumAngles(right, "Z")
+ if (ZAngles > 0) return 1
+ if (ZAngles < 0) return -1
+
+ // Sum of X angles
+
+ val XAngles: Rational = sumAngles(left, "X") - sumAngles(right, "X")
+ if (XAngles > 0) return 1
+ if (XAngles < 0) return -1
+
+
+ 0
+ }
+
+ def connectionClasses(adjMat: AdjMat): Vector[Int] = {
+ val initialVector = (0 until adjMat.size).toVector
+
+ def join(v: Vector[Int], i: Int, j: Int): Vector[Int] = {
+ v.map { a => if (a == v(i) || a == v(j)) {
+ math.min(v(i), v(j))
+ } else a
+ }
+ }
+
+ adjMat.mat.zipWithIndex.flatMap(rowWithIndex => {
+
+ val rci = rowWithIndex._1.zipWithIndex
+ rci.map(bi => (bi._1, bi._2, rowWithIndex._2))
+ }).filter(_._1).map(bii => (bii._2, bii._3)).foldLeft(initialVector) {
+ (v, ii) => join(v, ii._1, ii._2)
+ }
+ }
+
+ def containsScalars(adjMat: AdjMat): Boolean = {
+ // Boundaries are at the front of the adjmat, and are given labels first
+ // So if any class still has labels higher than the number of boundaries
+ // it can't be connected to any of the boundaries
+ val cc = connectionClasses(adjMat)
+ !cc.forall(_ < adjMat.numBoundaries)
+ }
+
+
type BMatrix = Vector[Vector[Boolean]]
def tensorToBooleanMatrix(tensor: Tensor): BMatrix = {
@@ -20,39 +241,40 @@ object GraphAnalysis {
def namesToIndex(name: VName) = vertexList.indexOf(name)
val targets = ends.map(namesToIndex).toSet
- val errorNames = detectErrors(graph)
+ val errorNames = detectPiNodes(graph)
val errors = errorNames.map(namesToIndex)
val rawAdjacencyMatrix = adjacencyMatrix(graph)
- val bypassedAdjacencyMatrix = bypassSpecial(detectErrors)(graph, rawAdjacencyMatrix)
+ val bypassedAdjacencyMatrix = bypassSpecial(detectPiNodes)(graph, rawAdjacencyMatrix)
pathDistanceSet(bypassedAdjacencyMatrix, errors, targets) match {
case None => None
- case Some(d) => Some(d-1) //account for errors being over-counted in standard distance
+ case Some(d) => Some(d - 1) //account for errors being over-counted in standard distance
}
}
- def distanceOfSingleErrorFromEnd(ends: Set[VName])(graph : Graph, vNames : Set[VName]) : Option[Double] = {
+ def distanceOfSingleErrorFromEnd(ends: Set[VName])(graph: Graph, vNames: Set[VName]): Option[Double] = {
val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(graph)
val vertexList = graph.verts.toList
def namesToIndex(name: VName) = vertexList.indexOf(name)
- val targets = ends.map(namesToIndex).toSet
- val errorNames = detectErrors(graph)
+ val targets = ends.map(namesToIndex)
+ // val errorNames = detectPiNodes(graph)
- val errors = errorNames.map(namesToIndex)
- val bypassedAdjacencyMatrix = bypassSpecial(detectErrors)(graph, adjacencyMatrix)
+ // val errors = errorNames.map(namesToIndex)
+ // val bypassedAdjacencyMatrix = bypassSpecial(detectPiNodes)(graph, adjacencyMatrix)
+
implicit def nameToInt(name: VName): Int = {
adjacencyMatrix._1.indexOf(name)
}
GraphAnalysis.pathDistanceSet(adjacencyMatrix, vNames.map(nameToInt), targets) match {
case None => None
- case Some(d) => Some(d-1)
+ case Some(d) => Some(d - 1)
}
}
@@ -81,48 +303,12 @@ object GraphAnalysis {
(matrixWithNames._1, bypassedMatrix)
}
- def detectErrors(graph: Graph): Set[VName] = {
+ def detectPiNodes(graph: Graph): Set[VName] = {
graph.verts.
filterNot(name => graph.vdata(name).isBoundary || graph.vdata(name).isWireVertex).
- filter(name => graph.vdata(name).asInstanceOf[NodeV].angle.equals(AngleExpression(Rational(1))))
- }
-
- def distanceSpecialFromEnds(specials: List[VName])(ends: List[VName])(graph: Graph): Option[Double] = {
-
- val vertexList = graph.verts.toList
-
- def namesToIndex(name: VName) = vertexList.indexOf(name)
-
- val targets = ends.map(namesToIndex).toSet
- // Called errors for historical reasons
- val errors = specials.map(namesToIndex).toSet
- val aMatrix = adjacencyMatrix(graph)
- pathDistanceSet(aMatrix, errors, targets)
- }
-
- def pathConnectionMatrices(graph: Graph): List[(Int, Tensor)] = {
- val matrixWithNames = adjacencyMatrix(graph)
- pathConnectionMatrices(matrixWithNames)
- }
-
- def pathConnectionMatrices(matrixWithNames: (List[VName], BMatrix)): List[(Int, Tensor)] = {
- val adjTensor = booleanMatrixToTensor(matrixWithNames._2)
-
- var rollingPower = Tensor.id(adjTensor.width)
- (for (i <- matrixWithNames._1.indices) yield {
- (i, {
- rollingPower = rollingPower.compose(adjTensor)
- rollingPower
- })
- }).toList
- }
-
- def booleanMatrixToTensor(bMatrix: BMatrix): Tensor = {
- def complexify(v: Vector[Boolean]): Array[Complex] = v.map(b => if (b) Complex(1, 0) else Complex(0, 0)).toArray
-
- def complexify2(v: Vector[Vector[Boolean]]): Array[Array[Complex]] = v.map(b => complexify(b)).toArray
-
- Tensor(complexify2(bMatrix))
+ filter(name => graph.vdata(name).asInstanceOf[NodeV].phaseData.
+ firstOrError(ValueType.AngleExpr).
+ equals(PhaseExpression(Rational(1), ValueType.AngleExpr))) // Pull out those angle expressions with value \pi
}
def adjacencyMatrix(graph: Graph): (List[VName], BMatrix) = {
@@ -135,12 +321,24 @@ object GraphAnalysis {
(vertexNames, vertexNames.foldLeft(Vector[Vector[Boolean]]())((vs, v) => vs :+ setToVector(graph.adjacentVerts(v))))
}
+ def pathDistanceSet(matrixWithNames: (List[VName], BMatrix),
+ distancesToFind: Set[Int],
+ measuredFrom: Set[Int]):
+ Option[Double] = {
- def neighbours(matrixWithNames: (List[VName], BMatrix), target: VName) : Set[VName] = {
val names = matrixWithNames._1
- val matrix = matrixWithNames._2
- val index = names.indexOf(target)
- matrix(index).zipWithIndex.filter(_._1).map(_._2).map(names(_)).toSet
+ // val matrix = matrixWithNames._2
+
+ val distances = distancesFromInitial(matrixWithNames, Set(), measuredFrom.map(i => names(i)))
+
+ val importantDistances = distances.filter(nd => distancesToFind.contains(names.indexOf(nd._1)))
+
+ if (importantDistances.nonEmpty) {
+ val d = intsWithCount(importantDistances.values.toList)
+ if (d < 0) None else Some(d)
+ } else {
+ None
+ }
}
def distancesFromInitial(matrixWithNames: (List[VName], BMatrix), ignoring: Set[VName], initials: Set[VName]):
@@ -177,7 +375,6 @@ object GraphAnalysis {
names.zipWithIndex.map(name_index => name_index._1 -> finalDistances(name_index._2)).toMap
}
-
private def intsWithCount(ints: List[Int]): Double = {
val max = ints.max
val count = max match {
@@ -188,23 +385,147 @@ object GraphAnalysis {
max + ((count - 1).toDouble / count.toDouble)
}
- def pathDistanceSet(matrixWithNames : (List[VName], BMatrix),
- distancesToFind: Set[Int],
- measuredFrom: Set[Int]):
- Option[Double] = {
+ def distanceSpecialFromEnds(specials: List[VName])(ends: List[VName])(graph: Graph): Option[Double] = {
+
+ val vertexList = graph.verts.toList
+
+ def namesToIndex(name: VName) = vertexList.indexOf(name)
+
+ val targets = ends.map(namesToIndex).toSet
+ // Called errors for historical reasons
+ val errors = specials.map(namesToIndex).toSet
+ val aMatrix = adjacencyMatrix(graph)
+ pathDistanceSet(aMatrix, errors, targets)
+ }
+
+ def pathConnectionMatrices(graph: Graph): List[(Int, Tensor)] = {
+ val matrixWithNames = adjacencyMatrix(graph)
+ pathConnectionMatrices(matrixWithNames)
+ }
+
+ def pathConnectionMatrices(matrixWithNames: (List[VName], BMatrix)): List[(Int, Tensor)] = {
+ val adjTensor = booleanMatrixToTensor(matrixWithNames._2)
+
+ var rollingPower = Tensor.id(adjTensor.width)
+ (for (i <- matrixWithNames._1.indices) yield {
+ (i, {
+ rollingPower = rollingPower.compose(adjTensor)
+ rollingPower
+ })
+ }).toList
+ }
+
+ def booleanMatrixToTensor(bMatrix: BMatrix): Tensor = {
+ def complexify(v: Vector[Boolean]): Array[Complex] = v.map(b => if (b) Complex(1, 0) else Complex(0, 0)).toArray
+
+ def complexify2(v: Vector[Vector[Boolean]]): Array[Array[Complex]] = v.map(b => complexify(b)).toArray
+ Tensor(complexify2(bMatrix))
+ }
+
+ def neighbours(matrixWithNames: (List[VName], BMatrix), target: VName): Set[VName] = {
val names = matrixWithNames._1
- // val matrix = matrixWithNames._2
+ val matrix = matrixWithNames._2
+ val index = names.indexOf(target)
+ matrix(index).zipWithIndex.filter(_._1).map(_._2).map(names(_)).toSet
+ }
- val distances = distancesFromInitial(matrixWithNames, Set(), measuredFrom.map(i => names(i)))
+ def boundariesFromRegex(graph: Graph, regex: Option[Regex]) : Set[VName] = {
+ regex match {
+ case Some(r) => graph.verts.filter(vn => vn.s.matches(r.regex))
+ case None => graph.verts.filter(vn => graph.vdata(vn).isBoundary)
+ }
+ }
- val importantDistances = distances.filter(nd => distancesToFind.contains(names.indexOf(nd._1)))
+ def checkIsomorphic(theory: Theory = Theory.DefaultTheory, boundaryByRegex: Option[Regex] = None)
+ (g1: Graph, g2: Graph): Boolean = {
+ // Check whether graphs are isomorphic, after constraining their boundaries
- if (importantDistances.nonEmpty) {
- val d = intsWithCount(importantDistances.values.toList)
- if(d < 0) None else Some(d)
+ if(g1.verts.size != g2.verts.size) return false
+
+ // Currently the .isBoundary on wires is set by JSON, not programmatically, and as such I don't trust it.
+ def borderNodes(g: Graph): Set[VName] = boundariesFromRegex(g, boundaryByRegex)
+
+ val overlappingNodes = borderNodes(g1).intersect(borderNodes(g2))
+
+ //First check they have the same number of boundaries, and the same names between them.
+
+ if(overlappingNodes != borderNodes(g1) || overlappingNodes != borderNodes(g2)){
+ return false
+ }
+
+ val dummyVertexDesc = VertexDesc.fromJson(
+ Json.parse("""{
+ | "value": {
+ | "type": "empty",
+ | "latex_constants": true,
+ | "validate_with_core": false
+ | },
+ | "style": {
+ | "label": {
+ | "position": "center",
+ | "fg_color": [
+ | 0.0,
+ | 0.0,
+ | 0.0
+ | ]
+ | },
+ | "stroke_color": [
+ | 0.0,
+ | 0.0,
+ | 0.0
+ | ],
+ | "fill_color": [
+ | 0.0,
+ | 1.0,
+ | 1.0
+ | ],
+ | "shape": "rectangle"
+ | },
+ | "default_data": {
+ | "type": "dummyBoundary",
+ | "value": ""
+ | }
+ | } """.stripMargin))
+
+ def enforceUnique(s: String) : String = if(theory.vertexTypes.keys.exists(_ == s)){
+ enforceUnique(s + "1")
} else {
- None
+ s
+ }
+ val dummyVertexName: String = enforceUnique("dummyBoundary")
+ val dummyTheory = theory.mixin(newVertexTypes = Map(dummyVertexName -> dummyVertexDesc))
+ val boundaryData = NodeV(JsonObject(
+ "type" -> dummyVertexName,
+ "value" -> ""
+ ),
+ JsonObject(),
+ dummyTheory)
+
+ def makeSolidBoundaries(graph: Graph): Graph = {
+ overlappingNodes.foldLeft(graph.copy(data = graph.data.copy(theory = dummyTheory))) { (g, vn) =>
+ g.updateVData(vn) { _ => boundaryData }
+ }
}
+
+ val solid1 = makeSolidBoundaries(g1)
+ val solid2 = makeSolidBoundaries(g2)
+
+ // The graphs are isomorphic if there are matches in both directions that are the identity on boundaries
+
+ val matches12 = Matcher.findMatches(solid1, solid2).filter(
+ m => overlappingNodes.forall(
+ vn => m.map.v.dom.toList.contains(vn) && m.map.v.domf(vn).contains(vn)
+ )
+ )
+
+ val matches21 = Matcher.findMatches(solid2, solid1).filter(
+ m => overlappingNodes.forall(
+ vn => m.map.v.dom.toList.contains(vn) && m.map.v.domf(vn).contains(vn)
+ )
+ )
+
+ matches21.nonEmpty && matches12.nonEmpty
}
+
}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/cosy/Interpreter.scala b/scala/src/main/scala/quanto/cosy/Interpreter.scala
index 7d634e65..49fa2143 100644
--- a/scala/src/main/scala/quanto/cosy/Interpreter.scala
+++ b/scala/src/main/scala/quanto/cosy/Interpreter.scala
@@ -1,6 +1,8 @@
package quanto.cosy
+import quanto.data.Theory.ValueType
import quanto.data._
+import quanto.util.json.JsonObject
/**
* An interpreter is given a diagram (as an adjMat and variable assignment) and returns a tensor
@@ -8,8 +10,7 @@ import quanto.data._
object Interpreter {
// Converts a given graph or spider into tensor form
- type cachedSpiders = collection.mutable.Map[String, Tensor]
- type AngleMap = Int => Double
+ private type cachedSpiders = collection.mutable.Map[String, Tensor]
val cached: cachedSpiders = collection.mutable.Map.empty[String, Tensor]
def makeHadamards(n: Int, current: Tensor = Tensor.id(1)): Tensor = n match {
@@ -18,52 +19,67 @@ object Interpreter {
case _ => Tensor.hadamard x makeHadamards(n - 1, current)
}
- def interpretZXSpider(green: Boolean, angle: Double, inputs: Int, outputs: Int): Tensor = {
+ def interpretZXSpider(zxAngleData: ZXAngleData, inputs: Int, outputs: Int): Tensor = {
// Converts spider to tensor. If green==false then it is a red spider
- val toString = "ZX:" + green.toString + ":" + angle + ":" + inputs + ":" + outputs
+
+ val colour = if (zxAngleData.isGreen) {
+ "green"
+ } else {
+ "red"
+ }
+ val angle = zxAngleData.angle.constant
+ val toString = s"ZX:$colour:$angle:$inputs:$outputs"
+
+
if (cached.contains(toString)) cached(toString) else {
def gen(i: Int, j: Int): Complex = {
Complex.zero +
(if (i == 0 && j == 0) Complex.one else Complex.zero) +
(if (i == math.pow(2, outputs) - 1 && j == math.pow(2, inputs) - 1)
- Complex(math.cos(angle), math.sin(angle)) else Complex.zero)
+ Complex(math.cos(angle * math.Pi), math.sin(angle * math.Pi)) else Complex.zero)
}
val mid = Tensor(math.pow(2, outputs).toInt, math.pow(2, inputs).toInt, gen)
- val spider = if (green) mid else makeHadamards(outputs) o mid o makeHadamards(inputs)
+ val spider = if (zxAngleData.isGreen) mid else makeHadamards(outputs) o mid o makeHadamards(inputs)
cached += (toString -> spider)
spider
}
}
- def interpretZWSpider(black: Boolean, outputs: Int): Tensor = {
+
+ implicit def stringToPhase(s: String): PhaseExpression = {
+ PhaseExpression.parse(s, ValueType.AngleExpr)
+ }
+
+ // ASSUME EVERYTHING HAS ANGLE DATA
+ implicit def pullOutAngleData(composite: CompositeExpression): PhaseExpression = {
+ composite.firstOrError(ValueType.AngleExpr)
+ }
+
+ def interpretZWSpiderNoInputs(black: Boolean, outputs: Int): Tensor = {
require(outputs >= 0)
val toString = "ZW:" + black.toString + ":" + outputs
val spider = if (cached.contains(toString)) cached(toString) else {
- black match {
- case true =>
- // Black spider
- outputs match {
- case 0 => Tensor(Array(Array(0)))
- case 1 => Tensor(Array(Array(0, 1))).transpose
- case 2 => Tensor(Array(Array(0, 1, 1, 0))).transpose
- case 3 => Tensor(Array(Array(0, 1, 1, 0, 1, 0, 0, 0))).transpose
- case _ =>
- val bY = Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 0))).transpose
- val base = interpretZWSpider(black, outputs - 1)
- (Tensor.idWires(outputs - 2) x bY) o base
- }
-
- case false =>
- // White spider
- outputs match {
- case 0 => Tensor(Array(Array(0)))
- case _ =>
- Tensor(1,
- Math.pow(2, outputs).toInt,
- (i, j) => (if (j == 0) Complex.one else Complex.zero) -
- (if (j == Math.pow(2, outputs) - 1) Complex.one else Complex.zero)).transpose
- }
+ if (black) {
+ outputs match {
+ case 0 => Tensor(Array(Array(0)))
+ case 1 => Tensor(Array(Array(0, 1))).transpose
+ case 2 => Tensor(Array(Array(0, 1, 1, 0))).transpose
+ case 3 => Tensor(Array(Array(0, 1, 1, 0, 1, 0, 0, 0))).transpose
+ case _ =>
+ val bY = Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 0))).transpose
+ val base = interpretZWSpiderNoInputs(black, outputs - 1)
+ (Tensor.idWires(outputs - 2) x bY) o base
+ }
+ } else {
+ outputs match {
+ case 0 => Tensor(Array(Array(0)))
+ case _ =>
+ Tensor(1,
+ Math.pow(2, outputs).toInt,
+ (_, j) => (if (j == 0) Complex.one else Complex.zero) -
+ (if (j == Math.pow(2, outputs) - 1) Complex.one else Complex.zero)).transpose
+ }
}
}
@@ -75,7 +91,240 @@ object Interpreter {
interpretZXAdjMatSpidersFirst(adjMat, greenAM, redAM)
}
- def interpretZWAdjMat(adjMat: AdjMat): Tensor = interpretZWAdjMatSpidersFirst(adjMat)
+ def zwSpiderInterpreter(nodeV: NodeV, inputs: Int, outputs: Int) : Tensor = {
+ (inputs, outputs) match {
+ case (0, 0) => Tensor.id(1)
+ case (0, n) => interpretZWSpiderNoInputs( nodeV.typ.toLowerCase == "w",n)
+ case (m, n) =>
+ (Tensor.idWires(n) x Tensor(Array(Array(1, 0, 0, 1)))) o
+ (zwSpiderInterpreter(nodeV, m-1, n+1) x Tensor.idWires(1))
+ }
+ }
+
+ def interpretZWAdjMat(adjMat: AdjMat, inputs: List[VName], outputs: List[VName]): Tensor = {
+ val blackNodes = Vector(
+ NodeV(
+ JsonObject(
+ "type" -> "w",
+ "value" -> ""
+ )
+ )
+ )
+ val whiteNodes = Vector(
+ NodeV(
+ JsonObject(
+ "type" -> "z",
+ "value" -> "1"
+ )
+ )
+ )
+ val graph = Graph.fromAdjMat(adjMat, blackNodes, whiteNodes)
+ interpretSpiderGraph(zwSpiderInterpreter)(graph, inputs, outputs)
+ }
+
+ def interpretSpiderGraph(spiderInterpreter: (NodeV, Int, Int) => Tensor)(graph: Graph, inputList: List[VName], outputList: List[VName]): Tensor = {
+
+ // remove any wire vertices etc
+ val minGraph = graph.minimise
+
+ minGraph.vdata.count(nd => !nd._2.isWireVertex) match {
+ case 0 =>
+ stringGraph(minGraph, interpretZXSpider(
+ ZXAngleData(isGreen = true, PhaseExpression.zero(ValueType.AngleExpr)), 2, 0),
+ inputList,
+ outputList
+ )
+ case _ =>
+ pullOutVertexGraph(minGraph, inputList, outputList, interpretSpiderGraph(spiderInterpreter), spiderInterpreter)
+ }
+ }
+
+ def zxSpiderInterpreter(vdata: NodeV, inputs: Int, outputs: Int): Tensor = {
+
+ val zxData: ZXAngleData = {
+ val isGreen = vdata.typ == "Z"
+ val angle = PhaseExpression.parse(vdata.value, ValueType.AngleExpr)
+ ZXAngleData(isGreen, angle)
+ }
+
+ interpretZXSpider(zxData, inputs, outputs)
+ }
+
+ val interpretZXGraph : (Graph, List[VName], List[VName]) => Tensor = interpretSpiderGraph(zxSpiderInterpreter)
+
+ /**
+ * Given a graph of just boundaries and edges,
+ * return the tensor treating those edges as caps
+ *
+ * @param graph : Graph
+ * @return
+ */
+ def stringGraph(graph: Graph, cap: Tensor, inputList: List[VName], outputList: List[VName]): Tensor = {
+ if (graph.verts.size % 2 != 0) {
+ throw new Error("String graph should have an even number of boundaries")
+ }
+
+ // GRAPHS ARE READ BOTTOM TO TOP HERE.
+
+
+ // Errors here are from node names appearing in the input or output list but not in the graph
+ val numInternalInputs = inputList.count(vn => {
+ inputList.contains(graph.adjacentVerts(vn).head)
+ })
+ val numInternalOutputs = outputList.count(vn => {
+ outputList.contains(graph.adjacentVerts(vn).head)
+ })
+
+ val joinLowerList = inputList.filterNot(vn => inputList.contains(graph.adjacentVerts(vn).head))
+ val joinUpperList = outputList.filterNot(vn => outputList.contains(graph.adjacentVerts(vn).head))
+
+ var claimedInternalCaps: List[VName] = List()
+ var claimedInternalCups: List[VName] = List()
+
+ val numJoin = (inputList.size + outputList.size - numInternalInputs - numInternalOutputs) / 2
+
+ def toLowerNeighbour(place: Int): Int = {
+ val name: VName = inputList(place)
+ val neighbour: VName = graph.adjacentVerts(name).head
+ if (inputList.contains(neighbour) && inputList.contains(name)) {
+ val index = claimedInternalCaps.indexOf(name)
+ if (index < 0) {
+ claimedInternalCaps = claimedInternalCaps :+ name
+ claimedInternalCaps = claimedInternalCaps :+ neighbour
+ toLowerNeighbour(place)
+ } else {
+ index
+ }
+ } else {
+ numInternalInputs + joinLowerList.indexOf(name)
+ }
+ }
+
+
+ def toUpperNeighbour(place: Int): Int = {
+ val name: VName = outputList(place)
+ val neighbour: VName = graph.adjacentVerts(name).head
+ if (outputList.contains(neighbour) && outputList.contains(name)) {
+ val index = claimedInternalCups.indexOf(name)
+ if (index < 0) {
+ claimedInternalCups = claimedInternalCups :+ name
+ claimedInternalCups = claimedInternalCups :+ neighbour
+ toUpperNeighbour(place)
+ } else {
+ index
+ }
+ } else {
+ numInternalOutputs + joinUpperList.indexOf(name)
+ }
+ }
+
+ def toJoin(place: Int) : Int = {
+ val name: VName = joinLowerList(place)
+ val neighbour: VName = graph.adjacentVerts(name).head
+ joinUpperList.indexOf(neighbour)
+ }
+
+ val caps = cap.power(numInternalInputs / 2)
+ val cups = cap.transpose.power(numInternalOutputs / 2)
+
+ val sigmaLower = if (inputList.nonEmpty) {
+ Tensor.swap(inputList.length, toLowerNeighbour)
+ } else {
+ Tensor(Array(Array(Complex(1, 0))))
+ }
+
+ val sigmaUpper = if (outputList.nonEmpty) {
+ Tensor.swap(outputList.length, toUpperNeighbour).transpose
+ } else {
+ Tensor(Array(Array(Complex(1, 0))))
+ }
+
+ val sigmaMid = if (joinLowerList.nonEmpty) {
+ Tensor.swap(joinLowerList.length, toJoin)
+ } else {
+ Tensor(Array(Array(Complex(1, 0))))
+ }
+
+ sigmaUpper o (cups x Tensor.idWires(numJoin)) o sigmaMid o (caps x Tensor.idWires(numJoin)) o sigmaLower
+ }
+
+ private def pullOutVertexGraph(startingGraph: Graph,
+ inputList: List[VName],
+ outputList: List[VName],
+ graphInterpreter: (Graph, List[VName], List[VName]) => Tensor,
+ spiderInterpreter: (NodeV, Int, Int) => Tensor): Tensor = {
+ def ratioDanglingWires(name: VName): Double = {
+ // A measure of how many boundary vs non-boundary neighbours the node has
+ val numBoundary = startingGraph.adjacentVerts(name).count(vn => inputList.contains(vn))
+ (1 + numBoundary).toDouble / (1 + startingGraph.adjacentVerts(name).size - numBoundary)
+ }
+
+ // def boundaries(g: Graph): Set[VName] = g.verts.filter(g.isTerminalWire)
+
+ // Pick a vertex to shift from graph to tensor
+ val cutVertex = startingGraph.verts
+ .filterNot(vn => startingGraph.vdata(vn).isWireVertex)
+ .maxBy(ratioDanglingWires)
+ val numCutSpiderInOuts = startingGraph.adjacentVerts(cutVertex).size
+
+ // cut out that vertex, as well as any boundaries it was attached to
+ val (reducedGraph, freshMadeBoundaries, removedBoundaries) = startingGraph.cutVertex(cutVertex, inputList.toSet)
+ val uncutBoundariesVector: Vector[VName] =
+ inputList.filter(b => !removedBoundaries.contains(b)).toVector
+ val reducedGraphBoundaries = uncutBoundariesVector.toList ++ freshMadeBoundaries.toList
+
+ val spiderTensor = spiderInterpreter(
+ startingGraph.vdata(cutVertex).asInstanceOf[NodeV],
+ numCutSpiderInOuts - freshMadeBoundaries.size,
+ freshMadeBoundaries.size)
+
+ val reducedGraphBoundariesVector: Vector[VName] = reducedGraphBoundaries.toVector
+
+ val bottomSigma: Tensor = {
+ val swapList: List[Int] = {
+ var leftCount = 0
+ var rightCount = (reducedGraphBoundaries.toSet -- freshMadeBoundaries).size
+ inputList.indices.map(i =>
+ if (reducedGraphBoundaries.contains(inputList(i))) {
+ leftCount += 1
+ leftCount - 1
+ } else {
+ rightCount += 1
+ rightCount - 1
+ }
+ ).toList
+ }
+ Tensor.swap(inputList.size, swapList)
+ }
+
+ /*
+ val topSigma: Tensor = {
+ val inputs: Vector[VName] = uncutBoundariesVector ++ freshMadeBoundaries.toVector
+ Tensor.swap(reducedGraphBoundaries.size,
+ i => reducedGraphBoundariesVector.indexOf(inputs(i)))
+ }
+ */
+
+ /**
+ * Starting with the graph G, pulling out the spider S
+ * You want:
+ *
+ * [ G' ]
+ * [id] x [S]
+ * [ botSigma ]
+ *
+ * Where G' is the reduced graph, with vector of inputs: reducedGraphBoundariesVector
+ * S is the cut spider,
+ * the identity is on uncutBoundariesVector
+ * And the bottom sigma acts on allInputsVector
+ * (Rather than having topSigma now just have different lis tof inputs for G')
+ */
+
+ graphInterpreter(reducedGraph, reducedGraphBoundaries, outputList) o
+ // topSigma o
+ (Tensor.idWires(uncutBoundariesVector.size) x spiderTensor) o
+ bottomSigma
+ }
private def interpretAdjMat(adj: AdjMat, join: Tensor, vertexToTensor: (Int) => Tensor): Tensor = {
// Interpret the graph as (caps) o (crossings) o (vertices)
@@ -113,7 +362,7 @@ object Interpreter {
connectionList = connectionList ++ List(claimLeg(i), claimLeg(j))
}
}
- connectingCaps.plugAbove(allSpidersTensors, connectionList)
+ connectingCaps o Tensor.swap(connectionList).transpose o allSpidersTensors
}
}
@@ -124,27 +373,25 @@ object Interpreter {
} else {
// Tensor representation of a spider
def vertexToSpider(v: Int): Tensor = {
- def pullOutAngle(nv: NodeV) = if (!nv.value.isEmpty) {
- try {
- nv.value.toDouble
- } catch {
- case e: Error => nv.angle.evaluate(Map("pi" -> math.Pi)) * math.Pi
- }
+ def pullOutAngle(nv: NodeV): PhaseExpression = if (!nv.value.isEmpty) {
+ nv.value
} else {
- nv.angle.evaluate(Map("pi" -> math.Pi)) * math.Pi
+ "0"
}
val (colour, nodeType) = adj.vertexColoursAndTypes(v)
val numLegs = adj.mat(v).count(p => p)
- val green = true
colour match {
case VertexColour.Boundary => Tensor.id(2)
- case VertexColour.Green => interpretZXSpider(green, pullOutAngle(greenAM(nodeType)), 0, numLegs)
- case VertexColour.Red => interpretZXSpider(!green, pullOutAngle(redAM(nodeType)), 0, numLegs)
+ case VertexColour.Green => interpretZXSpider(
+ ZXAngleData(isGreen = true, pullOutAngle(greenAM(nodeType))), 0, numLegs)
+ case VertexColour.Red => interpretZXSpider(
+ ZXAngleData(isGreen = false, pullOutAngle(redAM(nodeType))), 0, numLegs)
}
+
}
- val cap = interpretZXSpider(green = true, 0, 2, 0)
+ val cap = interpretZXSpider(ZXAngleData(isGreen = true, "0"), 2, 0)
interpretAdjMat(adj, cap, vertexToSpider)
}
@@ -160,11 +407,14 @@ object Interpreter {
// Using ZX colours for now
colour match {
case VertexColour.Boundary => Tensor.id(2)
- case VertexColour.Green => interpretZWSpider(!black, numLegs)
- case VertexColour.Red => interpretZWSpider(black, numLegs)
+ case VertexColour.Green => interpretZWSpiderNoInputs(!black, numLegs)
+ case VertexColour.Red => interpretZWSpiderNoInputs(black, numLegs)
}
}
interpretAdjMat(adj, cup, vertexToSpider)
}
+
+ case class ZXAngleData(isGreen: Boolean, angle: PhaseExpression)
+
}
diff --git a/scala/src/main/scala/quanto/cosy/RuleSynthesis.scala b/scala/src/main/scala/quanto/cosy/RuleSynthesis.scala
index c8019902..fe46fa63 100644
--- a/scala/src/main/scala/quanto/cosy/RuleSynthesis.scala
+++ b/scala/src/main/scala/quanto/cosy/RuleSynthesis.scala
@@ -1,18 +1,35 @@
package quanto.cosy
-import quanto.data._
+import quanto.cosy.RuleSynthesis.GraphComparison
import quanto.data.Derivation.DerivationWithHead
+import quanto.data._
import quanto.rewrite._
-import quanto.util.json.Json
+import quanto.util.json.{Json, JsonObject}
import scala.annotation.tailrec
import scala.util.Random
+import scala.util.matching.Regex
+import quanto.data.Names._
/**
* Created by hector on 29/06/17.
*/
object RuleSynthesis {
+ type GraphComparison = (Graph, Graph) => Int
+ // returns x where
+ // x < 0 iff left < right
+ // x > 0 iff left > right
+ // x == 0 otherwise
+
+ def basicGraphComparison(left: Graph, right: Graph): Int = {
+ if (left < right) -1 else {
+ if (left > right) 1 else 0
+ }
+ }
+
+
+
def loadRuleDirectory(directory: String): List[Rule] = {
quanto.util.FileHelper.getListOfFiles(directory, raw".*\.qrule").
map(file => (file.getName.replaceFirst(raw"\.qrule", ""), Json.parse(file))).
@@ -21,6 +38,7 @@ object RuleSynthesis {
}
/** Given an equivalence class creates rules for any irreducible members of the class */
+ // Superseded by CoSyRun
def graphEquivClassReduction[T](makeGraph: (T => Graph),
equivalenceClass: EquivalenceClass[T],
knownRules: List[Rule]): List[Rule] = {
@@ -50,27 +68,117 @@ object RuleSynthesis {
}
}
- def discardDirectlyReducibleRules(rules: List[Rule], theory: Theory, seed: Random = new Random()): List[Rule] = {
- rules.filter(rule =>
- AutoReduce.genericReduce(graphToDerivation(rule.lhs, theory), rules.filter(r => r != rule), seed) >= rule.lhs
- )
+ def extendMatchingSpidersWithBBoxes(rule: Rule, boundariesRegex : Option[Regex]) : Rule = {
+ require(!rule.hasBBoxes)
+ // This is not safe to do if the rule already has bboxes.
+
+ val boundaries = GraphAnalysis.boundariesFromRegex(rule.lhs, boundariesRegex).toList
+ def nearestNeighbourType(graph: Graph, vName: VName) : Option[(VName, String)] = {
+ val neighbours = graph.adjacentNodesAndBoundaries(vName)
+ if(neighbours.size == 1){
+ val t : String = (graph.vdata(neighbours.head).data / "type").toString
+ Some((neighbours.head, t))
+ } else None
+ }
+
+ def addBBoxIfOkay(rule: Rule, vName: VName) : Rule = {
+ val lhsT = nearestNeighbourType(rule.lhs, vName)
+ val rhsT = nearestNeighbourType(rule.rhs, vName)
+ if(lhsT.nonEmpty && rhsT.nonEmpty && lhsT.get._2 == rhsT.get._2 && rule.lhs.vdata(lhsT.get._1).isInstanceOf[NodeV]){
+ val leftNeighbour = lhsT.get._1
+ val rightNeighbour = rhsT.get._1
+ val bBName = rule.lhs.bboxes.freshWithSuggestion("bb0")
+
+ val lhsB = rule.lhs.addBBox(bBName, BBData(), Set(vName))
+ val rhsB = rule.lhs.addBBox(bBName, BBData(), Set(vName))
+ Rule(lhsB, rhsB, rule.derivation, rule.description)
+ } else rule
+ }
+
+ def removeBoundaryIfSuperfluous(rule: Rule, vName: VName): Rule = {
+ val lhsT = nearestNeighbourType(rule.lhs, vName)
+ val rhsT = nearestNeighbourType(rule.rhs, vName)
+ if(lhsT.nonEmpty && rhsT.nonEmpty && lhsT.get == rhsT.get && rule.lhs.vdata(lhsT.get._1).isInstanceOf[NodeV]) {
+ val ln = lhsT.get._1
+ val rn = rhsT.get._1
+ val lNeighbourhood = rule.lhs.adjacentNodesAndBoundaries(ln).intersect(boundaries.toSet)
+ val rNeighbourhood = rule.rhs.adjacentNodesAndBoundaries(rn).intersect(boundaries.toSet)
+ if((lNeighbourhood intersect rNeighbourhood).size > 1) {
+ val lCut = rule.lhs.deleteVertex(vName)
+ val rCut = rule.rhs.deleteVertex(vName)
+ Rule(lCut, rCut)
+ } else rule
+ } else rule
+ }
+
+ val withBBoxes = boundaries.foldLeft(rule)(addBBoxIfOkay)
+ boundaries.foldLeft(withBBoxes)(removeBoundaryIfSuperfluous)
}
- def graphToDerivation(graph: Graph, theory: Theory): DerivationWithHead = {
- (new Derivation(theory, graph), None)
+ def removeIsomorphisms(theory: Theory, boundaryRegex: Option[Regex], rules: List[Rule]) : List[Rule] = {
+ def isIso(rule: Rule) : Boolean = GraphAnalysis.checkIsomorphic(theory, boundaryRegex)(rule.lhs, rule.rhs)
+ rules.filter(!isIso(_))
}
- def minimiseRuleset(rules: List[Rule], theory: Theory, seed: Random = new Random()): List[Rule] = {
- rules.map(rule => minimiseRuleInPresenceOf(rule, rules.filter(otherRule => otherRule != rule), theory))
+ def greedyReduceRules(comparison: GraphComparison, throwOutIsos : Option[(Theory, Option[Regex])] = None)
+ (rules: List[Rule]): List[Rule] = {
+
+ // Will automatically invert rules that head upwards
+
+ // This does not throw out isomorphisms for you; it returns the list with each entry altered
+
+ // Yes, this is imperative rather than functional. We really do want to update a list as we act on it.
+
+ var rulesAsMap = rules.zipWithIndex.map(ri => (ri._2, ri._1)).toMap // i -> rule_i
+ var updatedThisRun = true
+
+ def isomorphism(rule: Rule): Boolean = throwOutIsos match {
+ case Some(tor) =>
+ GraphAnalysis.checkIsomorphic(tor._1, tor._2)(rule.lhs, rule.rhs)
+ case None => false
+ }
+
+ while (updatedThisRun) {
+ updatedThisRun = false
+ for (i <- rulesAsMap.keys) {
+ val rule = rulesAsMap(i)
+ val otherRules = rulesAsMap.filterKeys(_ != i).values.toList.filter(!isomorphism(_)).map(
+ rule => {
+ if (comparison(rule.lhs, rule.rhs) < 0) {
+ rule.inverse
+ } else rule
+ }
+ )
+ val newLhs: Graph = AutoReduce.greedyReduce(comparison, graphToDerivation(rule.lhs), otherRules)
+ val newRhs: Graph = AutoReduce.greedyReduce(comparison, graphToDerivation(rule.rhs), otherRules)
+ if (newLhs != rule.lhs || newRhs != rule.rhs) {
+ rulesAsMap = rulesAsMap + (i -> new Rule(newLhs,
+ newRhs,
+ derivation = None,
+ RuleDesc(rule.name + " reduced")
+ ))
+ updatedThisRun = true
+ }
+ }
+ }
+
+ rulesAsMap.values.toList
}
- def minimiseRuleInPresenceOf(rule: Rule, otherRules: List[Rule], theory: Theory, seed: Random = new Random()): Rule = {
- val minLhs: Graph = AutoReduce.genericReduce(graphToDerivation(rule.lhs, theory), otherRules, seed)
- val minRhs: Graph = AutoReduce.genericReduce(graphToDerivation(rule.rhs, theory), otherRules, seed)
- val wasItReduced = (minLhs < rule.lhs) || (minRhs < rule.rhs)
- new Rule(minLhs, minRhs, description = RuleDesc(
- rule.name + (if (wasItReduced) " reduced" else "")))
+ def discardDirectlyReducibleRules(comparison: GraphComparison,
+ rules: List[Rule],
+ seed: Random = new Random()): List[Rule] = {
+ rules.filter(rule =>
+ AutoReduce.greedyReduce(comparison, graphToDerivation(rule.lhs), rules.filter(r => r != rule)) >= rule.lhs
+ ).filter(rule =>
+ AutoReduce.greedyReduce(comparison, graphToDerivation(rule.rhs), rules.filter(r => r != rule)) >= rule.rhs
+ )
}
+
+ def graphToDerivation(graph: Graph): DerivationWithHead = {
+ (new Derivation(graph), None)
+ }
+
}
/**
@@ -82,8 +190,6 @@ object AutoReduce {
* Automatically reduce, with no handler or multithreading
*/
- implicit def inverseToRuleVariant(inverse: Boolean): RuleVariant = if (inverse) RuleInverse else RuleNormal
-
def smallestStepNameBelow(derivationHeadPair: (Derivation, Option[DSName])): Option[DSName] = {
derivationHeadPair._2 match {
case Some(head) =>
@@ -101,19 +207,20 @@ object AutoReduce {
}
// Tries multiple methods and is sure to return nothing larger than what you started with
- def genericReduce(derivationAndHead: DerivationWithHead,
+ def genericReduce(comparison: GraphComparison)
+ (derivationAndHead: DerivationWithHead,
rules: List[Rule],
seed: Random = new Random()):
DerivationWithHead = {
var latestDerivation: DerivationWithHead = derivationAndHead
latestDerivation._2 match {
case Some(initialHead) =>
- latestDerivation = annealingReduce(latestDerivation, rules, seed)
- latestDerivation = greedyReduce(latestDerivation, rules)
- latestDerivation = greedyReduce((latestDerivation._1.addHead(initialHead), Some(initialHead)), rules)
+ latestDerivation = annealingReduce(comparison, latestDerivation, rules, seed)
+ latestDerivation = greedyReduce(comparison, latestDerivation, rules)
+ latestDerivation = greedyReduce(comparison, (latestDerivation._1.addHead(initialHead), Some(initialHead)), rules)
case None =>
- latestDerivation = annealingReduce(latestDerivation, rules, seed)
- latestDerivation = greedyReduce(latestDerivation, rules)
+ latestDerivation = annealingReduce(comparison, latestDerivation, rules, seed)
+ latestDerivation = greedyReduce(comparison, latestDerivation, rules)
}
// Go back to original request, find smallest child
@@ -121,17 +228,19 @@ object AutoReduce {
}
// Simplest entry point
- def annealingReduce(derivationHeadPair: DerivationWithHead,
+ def annealingReduce(comparison: GraphComparison,
+ derivationHeadPair: DerivationWithHead,
rules: List[Rule],
seed: Random = new Random(),
vertexLimit: Option[Int] = None): DerivationWithHead = {
- val maxTime = math.pow(derivationHeadPair.verts.size, 2).toInt // Set as squaring #vertices for now
+ val maxTime = 100 + math.pow(derivationHeadPair.verts.size, 2).toInt // Set as squaring #vertices for now
val timeDilation = 3 // Gives an e^-3 ~ 0.05% chance of a non-reduction rule on the final step
- annealingReduce(derivationHeadPair, rules, maxTime, timeDilation, seed, vertexLimit)
+ annealingReduce(comparison, derivationHeadPair, rules, maxTime, timeDilation, seed, vertexLimit)
}
// Enter here to have control over e.g. how long it runs for
- def annealingReduce(derivationHeadPair: DerivationWithHead,
+ def annealingReduce(comparison: GraphComparison,
+ derivationHeadPair: DerivationWithHead,
rules: List[Rule],
maxTime: Int,
timeDilation: Double,
@@ -144,41 +253,45 @@ object AutoReduce {
val suggestedNextStep = randomSingleApply(d, randRule, seed, None, None, None)
val head = Derivation.derivationHeadPairToGraph(d)
val smallEnough = vertexLimit.isEmpty || (head.verts.size < vertexLimit.get)
- if ((allowIncrease && smallEnough) || suggestedNextStep < head) suggestedNextStep else d
+ if ((allowIncrease && smallEnough) || comparison(suggestedNextStep, head) < 0) suggestedNextStep else d
} else d
}
}
@tailrec
- def greedyReduce(derivationHeadPair: DerivationWithHead,
+ def greedyReduce(comparison: GraphComparison,
+ derivationHeadPair: DerivationWithHead,
rules: List[Rule],
remainingRules: List[Rule]): DerivationWithHead = {
remainingRules match {
case r :: tailRules => Matcher.findMatches(r.lhs, derivationHeadPair) match {
- case ruleMatch #:: t =>
+ case ruleMatch #:: _ =>
val reducedGraph = Rewriter.rewrite(ruleMatch, r.rhs)._1.minimise
val stepName = quanto.data.Names.mapToNameMap(derivationHeadPair._1.steps).
freshWithSuggestion(DSName(r.description.name))
- greedyReduce((derivationHeadPair._1.addStep(
- derivationHeadPair._2,
- DStep(stepName, r, reducedGraph)
- ), Some(stepName)),
+ greedyReduce(comparison,
+ (derivationHeadPair._1.addStep(
+ derivationHeadPair._2,
+ DStep(stepName, r, reducedGraph)
+ ), Some(stepName)),
rules,
rules)
case Stream.Empty =>
- greedyReduce(derivationHeadPair, rules, tailRules)
+ greedyReduce(comparison, derivationHeadPair, rules, tailRules)
}
case Nil => derivationHeadPair
}
}
// Simply apply the first reduction rule it can find until there are none left
- def greedyReduce(derivationHeadPair: DerivationWithHead, rules: List[Rule]): DerivationWithHead = {
+ def greedyReduce(comparison: GraphComparison,
+ derivationHeadPair: DerivationWithHead,
+ rules: List[Rule]): DerivationWithHead = {
val reducingRules = rules.filter(rule => rule.lhs > rule.rhs)
- val reduced = greedyReduce(derivationHeadPair, reducingRules, reducingRules)
+ val reduced = greedyReduce(comparison, derivationHeadPair, reducingRules, reducingRules)
// Go round again until it stops reducing
- if (reduced < derivationHeadPair) {
- greedyReduce(reduced, rules)
+ if (comparison(reduced, derivationHeadPair) < 0) {
+ greedyReduce(comparison, reduced, rules)
} else reduced
}
diff --git a/scala/src/main/scala/quanto/cosy/SimplificationProcedure.scala b/scala/src/main/scala/quanto/cosy/SimplificationProcedure.scala
index 7b9ca718..1aabc1b7 100644
--- a/scala/src/main/scala/quanto/cosy/SimplificationProcedure.scala
+++ b/scala/src/main/scala/quanto/cosy/SimplificationProcedure.scala
@@ -134,9 +134,9 @@ object SimplificationProcedure {
(DerivationWithHead, State) = {
import state._
val d = derivation
- val randRule = rules.toList(seed.nextInt(rules.size))
+ val randRule = rules(seed.nextInt(rules.size))
val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(d)
- val errors = GraphAnalysis.detectErrors(d)
+ val errors = GraphAnalysis.detectPiNodes(d)
val specialsDistances = errors.map(eName => (eName, state.weightFunction(d, Set(eName))))
val maxDistance = specialsDistances.maxBy[Double](ed => ed._2.getOrElse(0))._2.getOrElse(0)
@@ -153,8 +153,9 @@ object SimplificationProcedure {
val changed = suggestedNextStep._1.steps.size > derivation._1.steps.size
- val shrunkNextStep = AutoReduce.greedyReduce(suggestedNextStep, greedyRules.getOrElse(Set()).toList)
- val newErrors = GraphAnalysis.detectErrors(shrunkNextStep)
+ val shrunkNextStep = AutoReduce.greedyReduce(RuleSynthesis.basicGraphComparison,
+ suggestedNextStep, greedyRules.getOrElse(Set()).toList)
+ val newErrors = GraphAnalysis.detectPiNodes(shrunkNextStep)
val suggestedNewSize: Double = state.weightFunction(shrunkNextStep, newErrors).getOrElse(0)
// Bias towards strict reduction
@@ -168,7 +169,7 @@ object SimplificationProcedure {
println(randRule.description)
(shrunkNextStep, state.next(Some(suggestedNewSize)))
} else {
- println("rej " + suggestedNewSize)
+ println("rej " + suggestedNewSize)
(d, state.next(currentDistance))
}
}
diff --git a/scala/src/main/scala/quanto/cosy/SimprocBatch.scala b/scala/src/main/scala/quanto/cosy/SimprocBatch.scala
new file mode 100644
index 00000000..adf5e931
--- /dev/null
+++ b/scala/src/main/scala/quanto/cosy/SimprocBatch.scala
@@ -0,0 +1,199 @@
+package quanto.cosy
+
+import java.util.Calendar
+
+import quanto.data.Names._
+import quanto.data._
+import quanto.gui.{BatchDerivationCreatorPanel, QuantoDerive}
+import quanto.rewrite.Simproc
+import quanto.util.UserAlerts.{Elevation, alert}
+import quanto.util.json.{Json, JsonArray, JsonObject}
+import quanto.util.{FileHelper, UserOptions}
+
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.Future
+import scala.swing.Publisher
+import scala.swing.event.Event
+import scala.util.{Failure, Success}
+
+
+// Each simproc, graph pair generates a SimprocSingleRun result
+// The result holds the name of the simproc, the generated derivation, and the timings for each step
+case class SimprocSingleRun(simprocName: String, derivation: Derivation, derivationTimings: List[(String, Double)])
+
+object SimprocSingleRun {
+ def toJson(ssr: SimprocSingleRun): Json = {
+ JsonObject(
+ "simproc" -> ssr.simprocName,
+ "derivation" -> Derivation.toJson(ssr.derivation),
+ "timings" -> JsonArray(
+ ssr.derivationTimings.map(t =>
+ JsonObject(
+ "step" -> t._1,
+ "time" -> t._2
+ )
+ )
+ )
+ )
+ }
+
+ def fromJson(json: Json): SimprocSingleRun = {
+ val name: String = (json / "simproc").stringValue
+ val derivation: Derivation = Derivation.fromJson(json / "derivation")
+ val timings: List[(String, Double)] = (json / "timings").asArray.map(j => (
+ (j / "step").stringValue,
+ (j / "time").doubleValue)).toList
+ SimprocSingleRun(name, derivation, timings)
+ }
+}
+
+// Collect the single runs and the simproc definitions in one place
+case class SimprocBatchResult(selectedSimprocs: List[String],
+ allSimprocs: Map[String, String],
+ singleResults: List[SimprocSingleRun],
+ notes: String) {
+
+ lazy val toJson: JsonObject = {
+ JsonObject(
+ "python" -> JsonArray(
+ allSimprocs.map(
+ ss => JsonObject(
+ ss._1 -> ss._2
+ )
+ )
+ ),
+ "selected_simprocs" -> JsonArray(selectedSimprocs),
+ "results" -> JsonArray(
+ singleResults.map(sr => SimprocSingleRun.toJson(sr))
+ ),
+ "notes" -> notes
+ )
+ }
+}
+
+// This class contains just the metadata for the run
+// It is used when loading the results file when you don't want everything to be pulled from the json
+case class SimprocLazyBatchResult(notes: String, selectedSimprocs: List[String], resultCount: Int)
+
+object SimprocBatchResult {
+ def collate(SBResult: SimprocBatchResult): Map[String, List[(Derivation, List[(String, Double)])]] = {
+ SBResult.singleResults.
+ groupBy { case SimprocSingleRun(a, _, _) => a }.
+ mapValues(_.map { case SimprocSingleRun(_, b, c) => (b, c) })
+ }
+
+ // Import everything into memory
+ def fromJson(json: Json): SimprocBatchResult = {
+ val notes: String = (json / "notes").stringValue
+ val selectedSimprocs: List[String] = (json / "selected_simprocs").asArray.map(a => a.stringValue).toList
+ val allSimprocs: Map[String, String] = (json / "python").asArray.flatMap(j => j.asObject.v.map(sj => sj._1 -> sj._2.stringValue)).toMap
+ val results: List[SimprocSingleRun] = (json / "results").asArray.map(j => SimprocSingleRun.fromJson(j.asObject)).toList
+ SimprocBatchResult(selectedSimprocs, allSimprocs, results, notes)
+ }
+
+ // Just import the metadata into memory
+ def lazyFromJson(json: Json): SimprocLazyBatchResult = {
+ val notes: String = (json / "notes").stringValue
+ val selectedSimprocs: List[String] = (json / "selected_simprocs").asArray.map(a => a.stringValue).toList
+ val resultCount: Int = (json / "results").asArray.size
+ SimprocLazyBatchResult(notes, selectedSimprocs, resultCount)
+ }
+}
+
+
+// Set up the run as a SimprocBatch
+// This contains all the user supplied information
+// Then pulls loaded simprocs etc. from the project
+case class SimprocBatch(selectedSimprocs: List[String], selectedGraphs: List[Graph], notes: String) {
+ def run(): Unit = {
+ alert(s"Beginning simproc batch run on ${selectedGraphs.length} graphs, " +
+ s"with simprocs\n${selectedSimprocs.mkString("\n")}")
+
+ val simprocGraphPairs = for (simprocName <- selectedSimprocs; graph <- selectedGraphs) yield (simprocName, graph)
+ val listFutureResults = simprocGraphPairs.map(sg => {
+ Future {
+ val derivationData = SimprocBatch.runSimprocGetTimings(sg._1, sg._2)
+ SimprocSingleRun(sg._1, derivationData._1, derivationData._2)
+ }
+ })
+ val futureListResults = Future.sequence(listFutureResults)
+
+ futureListResults.onComplete {
+ case Success(list) =>
+ val result = SimprocBatchResult(selectedSimprocs,
+ SimprocBatch.loadedSimprocs.map(ss => (ss._1, ss._2.sourceCode)),
+ list,
+ notes)
+ SimprocBatch.publish(SimprocBatchRunComplete(result))
+ case Failure(e) =>
+ alert("Simproc batch run failed!", Elevation.ERROR)
+ e.printStackTrace()
+ }
+ }
+}
+
+case class SimprocNotLoaded(simprocName: String) extends Exception
+
+case class SimprocBatchRunComplete(result: SimprocBatchResult) extends Event
+
+object SimprocBatch extends Publisher {
+
+ var timeout: Long = 60 * 1000 // timeout time in milliseconds
+
+ // CONCURRENTLY run a simproc on a graph
+ // Makes its own derivation with timing data
+ // Runs until completion or timeout
+ def runSimprocGetTimings(simprocName: String, graph: Graph): (Derivation, List[(String, Double)]) = {
+ val simproc = simprocFromName(simprocName)
+ val startTime: Long = now
+ val jobIDAtStart = BatchDerivationCreatorPanel.jobID
+
+ def timeElapsed = now - startTime
+
+ var timings: List[(String, Double)] = List()
+ var derivation = new Derivation(graph)
+ var parentOpt: Option[DSName] = None
+ for ((graph, rule) <- simproc.simp(graph)) {
+ // Stop if taking too long
+ // Stop if the user has incremented the jobID (indicating they want the job to halt)
+ if (timeElapsed < timeout || BatchDerivationCreatorPanel.jobID > jobIDAtStart) {
+ val suggest = rule.name.replaceFirst("^.*\\/", "") + "-0"
+ val step = DStep(
+ name = derivation.steps.freshWithSuggestion(DSName(suggest)),
+ rule = rule,
+ graph = graph.minimise) // layout is already done by simproc now
+
+ derivation = derivation.addStep(parentOpt, step)
+ timings = (step.name.toString, timeElapsed.toDouble) :: timings
+ parentOpt = Some(step.name)
+ }
+ }
+
+ (derivation, timings)
+ }
+
+ listenTo(this)
+ reactions += {
+ case SimprocBatchRunComplete(result) =>
+ alert("Simproc Batch completed")
+ val fileName = UserOptions.preferredDateTimeFormat.format(Calendar.getInstance().getTime)
+ .replace(":","-").replace(".","--") + ".qsbr"
+ val projectRoot = QuantoDerive.CurrentProject.map(p => p.rootFolder + "/").getOrElse("")
+ FileHelper.printJson(projectRoot + "batch_results/" + fileName, result.toJson)
+ }
+
+ implicit def simprocFromName(name: String): Simproc = {
+ try {
+ loadedSimprocs(name)
+ } catch {
+ case _: Exception =>
+ alert(s"Requested simproc $name was not found. Please load it first.", Elevation.ERROR)
+ throw SimprocNotLoaded(name)
+ }
+ }
+
+ def now: Long = Calendar.getInstance().getTimeInMillis
+
+ def loadedSimprocs: Map[String, Simproc] = QuantoDerive.CurrentProject.map(p => p.simprocs).getOrElse(Map())
+
+}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/cosy/Tensor.scala b/scala/src/main/scala/quanto/cosy/Tensor.scala
index d845bd59..79febda3 100644
--- a/scala/src/main/scala/quanto/cosy/Tensor.scala
+++ b/scala/src/main/scala/quanto/cosy/Tensor.scala
@@ -2,8 +2,6 @@ package quanto.cosy
import quanto.util.json.{JsonArray, JsonObject}
-import scala.runtime.RichInt
-
/**
* A tensor-valued interpretation of a graph
*
@@ -81,16 +79,16 @@ class Tensor(c: Array[Array[Complex]]) {
def t: Tensor = this.transpose
- def transpose: Tensor = {
- // transpose
- Tensor(this.width, this.height, (i, j) => this.c(j)(i))
- }
-
def dagger: Tensor = {
// conjugate transpose
this.conjugate.transpose
}
+ def transpose: Tensor = {
+ // transpose
+ Tensor(this.width, this.height, (i, j) => this.c(j)(i))
+ }
+
def conjugate: Tensor = {
Tensor(this.height, this.width, (i, j) => this.c(i)(j).conjugate)
}
@@ -164,20 +162,41 @@ class Tensor(c: Array[Array[Complex]]) {
def entry(down: Int, across: Int): Complex = contents(down)(across)
def isRoughlyUpToScalar(that: Tensor, distance: Double = Tensor.defaultDistance): Boolean = {
- //
+
+ if (!this.isSameShapeAs(that)) return false
+
+ val thisIsRoughly0 = this.isRoughly(Tensor.zero(this.height, this.width), distance)
+ val thatIsRoughly0 = that.isRoughly(Tensor.zero(this.height, this.width), distance)
+
+ if (thisIsRoughly0 && thatIsRoughly0) return true
+ if (thisIsRoughly0 && !thatIsRoughly0) return false
+ if (!thisIsRoughly0 && thatIsRoughly0) return false
this.distanceAfterScaling(that) < distance
}
def distanceAfterScaling(that: Tensor): Double = {
if (this.isSameShapeAs(that)) {
- var maxEntry = (0, 0)
- var maxEntryValue = Complex.zero
- for (i <- this.c.indices; j <- this.c.head.indices) {
+
+ val (maxEntry, maxEntryValue) =
+ (for (i <- this.c.indices; j <- this.c.head.indices) yield (i, j)).toList.foldLeft((0, 0), Complex.zero) {
+ (agg, coord) => {
+ val value = this.c(coord._1)(coord._2)
+ if (value.abs > agg._2.abs) {
+ (coord, value)
+ } else {
+ agg
+ }
+ }
+ }
+
+ /*
+ {
if (this.c(i)(j).abs > maxEntryValue.abs) {
maxEntry = (i, j)
maxEntryValue = this.c(i)(j)
}
}
+ */
if (maxEntryValue == Complex.zero) {
this.distance(that)
} else {
@@ -189,7 +208,7 @@ class Tensor(c: Array[Array[Complex]]) {
}
}
} else {
- -1
+ 1
}
}
@@ -198,6 +217,41 @@ class Tensor(c: Array[Array[Complex]]) {
isRoughly(_, maxDist)
}
+ def isRoughly(that: Tensor, maxDistance: Double = 1e-14): Boolean =
+ // Compare two tensors up to a given distance
+ if (this.isSameShapeAs(that)) {
+ this.distance(that) < maxDistance
+ } else {
+ false
+ }
+
+ /** Returns max abs distance */
+ def distance(that: Tensor): Double = {
+ require(this.isSameShapeAs(that))
+ (this - that).contents.flatten.foldLeft(0.0) { (a: Double, b: Complex) => math.max(a, b.abs) }
+ }
+
+ def isSameShapeAs(that: Tensor): Boolean = {
+ this.width == that.width && this.height == that.height
+ }
+
+ def -(that: Tensor): Tensor = {
+ // this - that
+ this + that.scaled(Complex(-1, 0))
+ }
+
+ def +(that: Tensor): Tensor = {
+ // this + that
+ require(this.width == that.width)
+ require(this.height == that.height)
+ Tensor(this.height, this.width, (i, j) => this.c(i)(j) + that.contents(i)(j))
+ }
+
+ def scaled(factor: Complex): Tensor = {
+ // scalar multiplication
+ Tensor(this.height, this.width, (i, j) => this.c(i)(j) * factor)
+ }
+
override def equals(other: Any): Boolean =
// Compares matrix-entry by matrix-entry
other match {
@@ -250,37 +304,6 @@ class Tensor(c: Array[Array[Complex]]) {
this.scaled(maxAbsEntry.inverse())
}
}
-
- def isRoughly(that: Tensor, maxDistance: Double = 1e-14): Boolean =
- // Compare two tensors up to a given distance
- this.distance(that) < maxDistance
-
- /** Returns max abs distance */
- def distance(that: Tensor): Double = {
- require(this.isSameShapeAs(that))
- (this - that).contents.flatten.foldLeft(0.0) { (a: Double, b: Complex) => math.max(a, b.abs) }
- }
-
- def isSameShapeAs(that: Tensor): Boolean = {
- this.width == that.width && this.height == that.height
- }
-
- def -(that: Tensor): Tensor = {
- // this - that
- this + that.scaled(Complex(-1, 0))
- }
-
- def +(that: Tensor): Tensor = {
- // this + that
- require(this.width == that.width)
- require(this.height == that.height)
- Tensor(this.height, this.width, (i, j) => this.c(i)(j) + that.contents(i)(j))
- }
-
- def scaled(factor: Complex): Tensor = {
- // scalar multiplication
- Tensor(this.height, this.width, (i, j) => this.c(i)(j) * factor)
- }
}
object Tensor {
@@ -293,6 +316,8 @@ object Tensor {
Tensor(Array(Array(1, 1), Array(1, -1))).scaled(math.pow(2, -0.5))
}
val defaultDistance = 1e-14
+ private var permutationCache: Map[List[Int], Matrix] = Map()
+ private var swapCache: Map[List[Int], Tensor] = Map()
def idWires(n: Int): Tensor = {
// Identity on n wires, i.e. 2^n * 2^n matrix
@@ -304,13 +329,6 @@ object Tensor {
Tensor(n, n, (a: Int, b: Int) => if (a == b) new Complex(1) else new Complex(0))
}
- def apply(c: Array[Array[Complex]]) = new Tensor(c)
-
- def apply(cInt: Array[Array[Int]]): Tensor = {
- // Convert integer matrix to complex matrix
- Tensor(cInt.length, cInt(0).length, (i, j) => Complex.doubleToComplex(cInt(i)(j)))
- }
-
def apply(height: Int, width: Int, generator: Tensor.Generator): Tensor = {
// create tensor based on size and generating function
new Tensor(generatorToMatrix(height, width, generator))
@@ -335,47 +353,72 @@ object Tensor {
(for (_ <- 0 until width) yield Complex.zero).toArray).toArray
}
+ def apply(c: Array[Array[Complex]]) = new Tensor(c)
+
+ def apply(cInt: Array[Array[Int]]): Tensor = {
+ // Convert integer matrix to complex matrix
+ Tensor(cInt.length, cInt(0).length, (i, j) => Complex.doubleToComplex(cInt(i)(j)))
+ }
+
def permutation(asList: List[Int]): Tensor = {
// Produce the matrix that sends i -> asList(i)
val gen = (x: Int) => asList(x)
new Tensor(permutationMatrix(asList.length, gen))
}
- def permutationMatrix(size: Int, gen: Int => Int): Matrix = {
- val base = emptyMatrix(size, size)
- for (i <- 0 until size) {
- base(gen(i))(i) = Complex.one
- }
- base
- }
-
def permutation(asArray: Array[Int]): Tensor = {
// Produce the matrix that sends i -> asArray(i)
val gen = (x: Int) => asArray(x)
new Tensor(permutationMatrix(asArray.length, gen))
}
+ def permutationMatrix(size: Int, gen: Int => Int): Matrix = {
+ val genAsList: List[Int] = (0 until size).map(gen(_)).toList
+ if (permutationCache.contains(genAsList)) {
+ permutationCache(genAsList)
+ } else {
+ val base = emptyMatrix(size, size)
+ for (i <- 0 until size) {
+ base(genAsList(i))(i) = Complex.one
+ }
+ permutationCache += genAsList -> base
+ base
+ }
+ }
+
def swap(asList: List[Int]): Tensor = {
// Produce the matrix that sends WIRE i to WIRE asList(i)
+ // READ BOTTOM TO TOP
val gen = (x: Int) => asList(x)
swap(asList.length, gen)
}
def swap(size: Int, gen: Int => Int): Tensor = {
// Produce the matrix that sends WIRE i to WIRE gen(i)
- def padLeft(s: String, n: Int): String = if (s.length < n) padLeft("0" + s, n) else s
-
- def permGen(i: Int): Int = {
- val binaryStringIn = padLeft((i: RichInt).toBinaryString, size)
- val permedString = (for (j <- 0 until size) yield binaryStringIn(gen(j))).mkString("")
- permedString match {
- case "" => 0
- case s => Integer.parseInt(s, 2)
+ // READ BOTTOM TO TOP
+ val genAsList: List[Int] = (0 until size).map(gen).toList
+
+ if (swapCache.contains(genAsList)) {
+ swapCache(genAsList)
+ } else {
+
+ def padLeft(s: String, n: Int): String = if (s.length < n) padLeft("0" + s, n) else s
+
+ def permGen(i: Int): Int = {
+ val binaryStringIn = padLeft(i.toBinaryString, size)
+ val permedString = (for (j <- 0 until size) yield binaryStringIn(gen(j))).mkString("")
+ permedString match {
+ case "" => 0
+ case s => Integer.parseInt(s, 2)
+ }
+
}
+ val answer = permutation(math.pow(2, size).toInt, permGen).transpose
+ swapCache += genAsList -> answer
+ answer
}
- permutation(math.pow(2, size).toInt, permGen)
}
def permutation(size: Int, gen: Int => Int): Tensor = {
diff --git a/scala/src/main/scala/quanto/cosy/qutrits.scala b/scala/src/main/scala/quanto/cosy/qutrits.scala
deleted file mode 100644
index a3ae36b7..00000000
--- a/scala/src/main/scala/quanto/cosy/qutrits.scala
+++ /dev/null
@@ -1,17 +0,0 @@
-package quanto.cosy
-
-object qutrits {
-
- def ei ( arg : Double) = Complex(math.cos(arg),math.sin(arg))
- def conj (T : Tensor) = Tensor(T.height,T.width,(i, j) => T(j,i).conjugate)
- val w : Complex = ei(2*math.Pi/3)
- val H_unnormalised = Tensor(Array(Array[Complex](1,1,1),Array[Complex](1,w,w*w),Array[Complex](1,w*w,w)))
- val H = H_unnormalised.scaled(1.0 / math.sqrt(3))
- val e0 = Tensor(Array(Array(1,0,0))).transpose
- val e1 = Tensor(Array(Array(0,1,0))).transpose
- val e2 = Tensor(Array(Array(0,0,1))).transpose
- val f0 = Tensor(Array(Array(1,1,1))).transpose.scaled(1.0 / math.sqrt(3))
- val f1 = Tensor(Array(Array(Complex.one,w,w*w))).transpose.scaled(1.0 / math.sqrt(3))
- val f2 = Tensor(Array(Array(Complex.one,w*w,w))).transpose.scaled(1.0 / math.sqrt(3))
-}
-
diff --git a/scala/src/main/scala/quanto/data/AngleExpression.scala b/scala/src/main/scala/quanto/data/AngleExpression.scala
deleted file mode 100644
index e252485a..00000000
--- a/scala/src/main/scala/quanto/data/AngleExpression.scala
+++ /dev/null
@@ -1,139 +0,0 @@
-package quanto.data
-
-// ported from linrat_angle_expr.ML
-
-import quanto.util.Rational
-
-import scala.util.parsing.combinator._
-
-class AngleParseException(message: String)
- extends Exception(message)
-
-class AngleEvaluationException(message: String)
- extends Exception(message)
-
-class AngleExpression(val const : Rational, val coeffs : Map[String,Rational]) {
- lazy val vars = coeffs.keySet
-
- def *(r : Rational): AngleExpression =
- AngleExpression(const * r, coeffs.mapValues(x => x * r))
-
- def *(i : Int): AngleExpression = this * Rational(i)
-
- def +(e : AngleExpression) = AngleExpression(const + e.const, e.coeffs.foldLeft(coeffs) {
- case (m, (k,v)) => m + (k -> (v + m.getOrElse(k, Rational(0))))
- })
-
- def -(e: AngleExpression) : AngleExpression = this + (e * -1)
-
- def subst(v : String, e : AngleExpression) : AngleExpression = {
- val c = coeffs.getOrElse(v,Rational(0))
- this - AngleExpression(Rational(0), Map(v -> c)) + (e * c)
- }
-
- def subst(mp : Map[String, AngleExpression]): AngleExpression =
- mp.foldLeft(this) { case (e, (v,e1)) => e.subst(v,e1) }
-
- def evaluate(mp: Map[String, Double]) : Double = {
- try {
- const + coeffs.foldLeft(0.0) { (a, b) => a + (mp(b._1) * Rational.rationalToDouble(b._2)) }
- } catch {
- case e : Exception => new AngleEvaluationException("Given arguments do not match those in the coefficient list")
- 0
- }
- }
-
- override def equals(that: Any): Boolean = that match {
- case e : AngleExpression =>
- const == e.const && coeffs == e.coeffs
- case _ => false
- }
-
- override def toString: String = {
- var fst = true
- var s = ""
- if (!const.isZero) {
- fst = false
- if (const.isOne) s += "\\pi"
- else s += const.toString + " \\pi"
- }
-
- coeffs.foreach { case (x,c) =>
- if (fst) {
- fst = false
- s = s + (if (c == Rational(1)) "" else c.toString + " ") + x
- } else {
- if (c < Rational(0)) {
- s = s + " - " + (if (c == Rational(-1)) "" else (c * -1).toString + " ") + x
- } else {
- s = s + " + " + (if (c == Rational(1)) "" else c.toString + " ") + x
- }
- }
- }
-
- s
- }
-
-
-}
-
-object AngleExpression {
- def apply(const : Rational = Rational(0),
- coeffs : Map[String,Rational] = Map()) =
- new AngleExpression(const mod 2, coeffs.filter { case (_,c) => !c.isZero })
-
- val ZERO = AngleExpression(Rational(0))
- val ONE_PI = AngleExpression(Rational(1))
-
- def parse(s : String) = AngleExpressionParser.p(s)
-
- private object AngleExpressionParser extends RegexParsers {
- override def skipWhitespace = true
- def INT: Parser[Int] = """[0-9]+""".r ^^ { _.toInt }
- def INT_OPT : Parser[Int] = INT.? ^^ { _.getOrElse(1) }
- def IDENT : Parser[String] = """[\\a-zA-Z_][a-zA-Z0-9_]*""".r ^^ { _.toString }
- def PI : Parser[Unit] = """\\?[pP][iI]""".r ^^ { _ => Unit }
-
-
- def coeff : Parser[Rational] =
- INT ~ "/" ~ INT ^^ { case n ~ _ ~ d => Rational(n,d) } |
- "(" ~ coeff ~ ")" ^^ { case _ ~ c ~ _ => c } |
- INT ^^ { n => Rational(n) }
-
-
- def frac : Parser[AngleExpression] =
- INT_OPT ~ "*".? ~ PI ~ "/" ~ INT ^^ { case n ~ _ ~ _ ~ _ ~ d => AngleExpression(Rational(n,d)) } |
- INT_OPT ~ "*".? ~ IDENT ~ "/" ~ INT ^^ {
- case n ~ _ ~ x ~ _ ~ d => AngleExpression(Rational(0), Map(x -> Rational(n,d)))
- }
-
- def term : Parser[AngleExpression] =
- frac |
- "-" ~ term ^^ { case _ ~ t => t * -1 } |
- coeff ~ "*".? ~ PI ^^ { case c ~ _ ~ _ => AngleExpression(c) } |
- PI ^^ { _ => ONE_PI } |
- coeff ~ "*".? ~ IDENT ^^ { case c ~ _ ~ x => AngleExpression(Rational(0), Map(x -> c)) } |
- IDENT ^^ { case x => AngleExpression(Rational(0), Map(x -> Rational(1))) } |
- coeff ^^ { AngleExpression(_) } |
- "(" ~ expr ~ ")" ^^ { case _ ~ t ~ _ => t }
-
- def term1 : Parser[AngleExpression] =
- "+" ~ term ^^ { case _ ~ t => t } |
- "-" ~ term ^^ { case _ ~ t => t * -1 }
-
- def terms : Parser[AngleExpression] =
- term1 ~ terms ^^ { case s ~ t => s + t } |
- term1
-
- def expr : Parser[AngleExpression] =
- term ~ terms ^^ { case s ~ t => s + t } |
- term |
- "" ^^ { _ => ZERO }
-
- def p(s : String) = parseAll(expr, s) match {
- case Success(e, _) => e
- case Failure(msg, _) => throw new AngleParseException(msg)
- case Error(msg, _) => throw new AngleParseException(msg)
- }
- }
-}
diff --git a/scala/src/main/scala/quanto/data/BBData.scala b/scala/src/main/scala/quanto/data/BBData.scala
index 95549032..b41d1552 100644
--- a/scala/src/main/scala/quanto/data/BBData.scala
+++ b/scala/src/main/scala/quanto/data/BBData.scala
@@ -3,14 +3,14 @@ package quanto.data
import quanto.util.json.JsonObject
/**
- * A class which represents the data for bang boxes
- *
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BBData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- * @author Aleks Kissinger
- */
+ * A class which represents the data for bang boxes
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BBData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ * @author Aleks Kissinger
+ */
case class BBData(
- data: JsonObject = JsonObject(),
- annotation: JsonObject = JsonObject(),
- theory: Theory = Theory.DefaultTheory
-) extends GraphElementData
+ data: JsonObject = JsonObject(),
+ annotation: JsonObject = JsonObject(),
+ theory: Theory = Theory.DefaultTheory
+ ) extends GraphElementData
diff --git a/scala/src/main/scala/quanto/data/BinRel.scala b/scala/src/main/scala/quanto/data/BinRel.scala
index 521c7834..4e7f29b5 100644
--- a/scala/src/main/scala/quanto/data/BinRel.scala
+++ b/scala/src/main/scala/quanto/data/BinRel.scala
@@ -1,32 +1,31 @@
/** A package which contains the data objects */
package quanto.data
-import collection.immutable.TreeSet
-import scala.collection.{mutable, IterableLike}
+import scala.collection.immutable.TreeSet
+import scala.collection.{IterableLike, mutable}
/**
- * A trait which contains useful methods for binary relations
- *
- * @tparam A type of the domain
- * @tparam B type of the codomain
- *
- * @author Aleks Kissinger
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BinRel.scala Source code]]
- */
-trait BinRel[A,B] extends IterableLike[(A,B), BinRel[A,B]] {
+ * A trait which contains useful methods for binary relations
+ *
+ * @tparam A type of the domain
+ * @tparam B type of the codomain
+ * @author Aleks Kissinger
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BinRel.scala Source code]]
+ */
+trait BinRel[A, B] extends IterableLike[(A, B), BinRel[A, B]] {
/** Domain function - assigns a set to each element of the domain */
- def domf : Map[A,Set[B]]
+ def domf: Map[A, Set[B]]
/** Codomain function - assigns a set to each element of the codomain */
- def codf : Map[B,Set[A]]
+ def codf: Map[B, Set[A]]
- /**
- * Add an element to relation
- *
- * @param kv element to be added
- * @return New relation containing '''kv''' in addition to the current one
- */
- def +(kv: (A,B)): BinRel[A,B]
+ /**
+ * Add an element to relation
+ *
+ * @param kv element to be added
+ * @return New relation containing '''kv''' in addition to the current one
+ */
+ def +(kv: (A, B)): BinRel[A, B]
/**
* Check if an element is in a relation
@@ -34,166 +33,173 @@ trait BinRel[A,B] extends IterableLike[(A,B), BinRel[A,B]] {
* @param kv element to be checked
* @return boolean
*/
- def contains(kv: (A,B)): Boolean
+ def contains(kv: (A, B)): Boolean
/**
- * Remove an element from relation
- *
- * @param kv element to be removed
- * @return New relation containing the same elements as the
- * current one except '''kv'''
- */
- def unmap(kv: (A,B)): BinRel[A,B]
- def -(kv: (A,B)) : BinRel[A,B] = unmap(kv)
+ * Remove an element from relation
+ *
+ * @param kv element to be removed
+ * @return New relation containing the same elements as the
+ * current one except '''kv'''
+ */
+ def unmap(kv: (A, B)): BinRel[A, B]
+
+ def -(kv: (A, B)): BinRel[A, B] = unmap(kv)
/**
- * Remove all relation pairs '''(d,_)'''
- *
- * @param d Element to be removed from domain
- * @return New relation with relation pairs '''(d,_)''' removed
- */
- def unmapDom(d: A): BinRel[A,B]
+ * Remove all relation pairs '''(d,_)'''
+ *
+ * @param d Element to be removed from domain
+ * @return New relation with relation pairs '''(d,_)''' removed
+ */
+ def unmapDom(d: A): BinRel[A, B]
/**
- * Remove all relation pairs '''(_,c)'''
- *
- * @param c Element to be removed from codomain
- * @return New relation with relation pairs '''(_,c)''' removed
- */
- def unmapCod(c: B): BinRel[A,B]
+ * Remove all relation pairs '''(_,c)'''
+ *
+ * @param c Element to be removed from codomain
+ * @return New relation with relation pairs '''(_,c)''' removed
+ */
+ def unmapCod(c: B): BinRel[A, B]
/** The domain set of the relation */
- def dom = domf.keys
- def domSet = domf.keySet
+ def dom: Iterable[A] = domf.keys
+
+ def domSet: Set[A] = domf.keySet
/** The codomain set of the relation */
- def cod = codf.keys
- def codSet = codf.keySet
+ def cod: Iterable[B] = codf.keys
+
+ def codSet: Set[B] = codf.keySet
/**
- * The codomain image of a set of domain elements under this relation
- *
- * @param set A set containing elements of type '''A'''
- * @return The set of codomain elements which are in relation with some
- * element from '''set'''
- */
- def directImage(set: Set[A]) = set.foldLeft(Set[B]()){ (s,x) => s union domf.getOrElse(x, Set()) }
+ * The codomain image of a set of domain elements under this relation
+ *
+ * @param set A set containing elements of type '''A'''
+ * @return The set of codomain elements which are in relation with some
+ * element from '''set'''
+ */
+ def directImage(set: Set[A]): Set[B] = set.foldLeft(Set[B]()) { (s, x) => s union domf.getOrElse(x, Set()) }
/**
- * The domain image of a set of codomain elements under this relation
- *
- * @param set A set containing elements of type '''B'''
- * @return The set of domain elements which are in relation with some
- * element from '''set'''
- */
- def inverseImage(set: Set[B]) = set.foldLeft(Set[A]()) { (s,y) => s union codf.getOrElse(y, Set()) }
-
-
+ * The domain image of a set of codomain elements under this relation
+ *
+ * @param set A set containing elements of type '''B'''
+ * @return The set of domain elements which are in relation with some
+ * element from '''set'''
+ */
+ def inverseImage(set: Set[B]): Set[A] = set.foldLeft(Set[A]()) { (s, y) => s union codf.getOrElse(y, Set()) }
+
+
/**
- * BinRel inherits equality from '''domf'''
- *
- * @return The hashcode of '''domf'''
- */
- override def hashCode = domf.hashCode()
+ * BinRel inherits equality from '''domf'''
+ *
+ * @return The hashcode of '''domf'''
+ */
+ override def hashCode: Int = domf.hashCode()
- /** True iff '''other''' is of type '''BinRel[_,_]''' */
- override def canEqual(other: Any) = other match {
- case _: BinRel[_,_] => true
+ /** BinRel inherits equality from '''domf''' */
+ override def equals(other: Any): Boolean = other match {
+ case that: BinRel[_, _] => (that canEqual this) && (this.domf == that.domf)
case _ => false
}
- /** BinRel inherits equality from '''domf''' */
- override def equals(other: Any) = other match {
- case that: BinRel[_,_] => (that canEqual this) && (this.domf == that.domf)
+ /** True iff '''other''' is of type '''BinRel[_,_]''' */
+ override def canEqual(other: Any): Boolean = other match {
+ case _: BinRel[_, _] => true
case _ => false
}
- override def toString() = {
- getClass.getSimpleName + "(" + seq.map{ case (k,v) => k.toString + " -> " + v.toString }.mkString(", ") + ")"
+ override def toString: String = {
+ getClass.getSimpleName + "(" + seq.map { case (k, v) => k.toString + " -> " + v.toString }.mkString(", ") + ")"
}
// TODO: get ++ implemented correctly (i.e. using builders etc) for PFun/BinRel
- def ++(r:BinRel[A,B]) = r.foldLeft(this) { case (mp, kv) => mp + kv }
+ def ++(r: BinRel[A, B]): BinRel[A, B] = r.foldLeft(this) { case (mp, kv) => mp + kv }
}
/**
- * A class which represents a binary relation as a pair of two functions -
- * the domain map and the codomain map
- *
- * @tparam A The type of the elements in the domain of the relation
- * @tparam B The type of the elements in the codomain of the relation
- *
- * @constructor Create an instance of the class from two functions mapping
- * elements to sets of elements
- *
- * @param domMap The domain map - maps an element of type '''A''' to the
- * set of elements of type '''B''' which are in relation with it
- *
- * @param codMap The codomain map - maps an element of type '''B''' to the
- * set of elements of type '''A''' which are in relation with it
- *
- * @author Aleks Kissinger
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BinRel.scala Source code]]
- */
-class MapPairBinRel[A,B](domMap: Map[A,TreeSet[B]], codMap: Map[B,TreeSet[A]])
- (implicit domOrd: Ordering[A], codOrd: Ordering[B])
- extends BinRel[A,B] {
-
- def domf = domMap.withDefaultValue(TreeSet()(codOrd))
- def codf = codMap.withDefaultValue(TreeSet()(domOrd))
-
- def +(kv: (A,B)) = {
- new MapPairBinRel[A,B](
+ * A class which represents a binary relation as a pair of two functions -
+ * the domain map and the codomain map
+ *
+ * @tparam A The type of the elements in the domain of the relation
+ * @tparam B The type of the elements in the codomain of the relation
+ * @constructor Create an instance of the class from two functions mapping
+ * elements to sets of elements
+ * @param domMap The domain map - maps an element of type '''A''' to the
+ * set of elements of type '''B''' which are in relation with it
+ * @param codMap The codomain map - maps an element of type '''B''' to the
+ * set of elements of type '''A''' which are in relation with it
+ * @author Aleks Kissinger
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BinRel.scala Source code]]
+ */
+class MapPairBinRel[A, B](domMap: Map[A, TreeSet[B]], codMap: Map[B, TreeSet[A]])
+ (implicit domOrd: Ordering[A], codOrd: Ordering[B])
+ extends BinRel[A, B] {
+
+ def +(kv: (A, B)): MapPairBinRel[A, B] = {
+ new MapPairBinRel[A, B](
domMap + (kv._1 -> (domMap.getOrElse(kv._1, TreeSet()(codOrd)) + kv._2)),
codMap + (kv._2 -> (codMap.getOrElse(kv._2, TreeSet()(domOrd)) + kv._1))
)
}
- def contains(kv: (A,B)) = domf.get(kv._1).contains(kv._2)
+ def contains(kv: (A, B)) : Boolean = domf.get(kv._1).contains(kv._2)
- def unmap(kv: (A,B)) = {
- new MapPairBinRel[A,B](
+ def unmapDom(d: A): MapPairBinRel[A, B] = domf(d).foldLeft(this) { (rel, c) => rel unmap(d, c) }
+
+ def domf : Map[A, TreeSet[B]] = domMap.withDefaultValue(TreeSet()(codOrd))
+
+ def unmap(kv: (A, B)): MapPairBinRel[A, B] = {
+ new MapPairBinRel[A, B](
domMap.get(kv._1) match {
- case Some(xs) if (xs.size == 1 && xs.contains(kv._2)) => domMap - kv._1
+ case Some(xs) if xs.size == 1 && xs.contains(kv._2) => domMap - kv._1
case Some(xs) => domMap + (kv._1 -> (xs - kv._2))
case None => domMap
},
codMap.get(kv._2) match {
- case Some(xs) if (xs.size == 1 && xs.contains(kv._1)) => codMap - kv._2
+ case Some(xs) if xs.size == 1 && xs.contains(kv._1) => codMap - kv._2
case Some(xs) => codMap + (kv._2 -> (xs - kv._1))
case None => codMap
}
)
}
- def unmapDom(d: A) = domf(d).foldLeft(this) { (rel,c) => rel unmap (d, c) }
- def unmapCod(c: B) = codf(c).foldLeft(this) { (rel,d) => rel unmap (d, c) }
+ def unmapCod(c: B): MapPairBinRel[A, B] = codf(c).foldLeft(this) { (rel, d) => rel unmap(d, c) }
+
+ def codf: Map[B, TreeSet[A]] = codMap.withDefaultValue(TreeSet()(domOrd))
+
+ def seq: Seq[(A, B)] = iterator.toSeq
/** Returns an iterator over pairs '''(a,b)''' of the relation */
- def iterator = domMap.foldLeft(Iterator[(A,B)]()) { case (iter, (domElement, codSet)) =>
+ def iterator: Iterator[(A, B)] = domMap.foldLeft(Iterator[(A, B)]()) { case (iter, (domElement, codSet)) =>
iter ++ (Iterator.continually(domElement) zip codSet.iterator)
}
- protected[this] def newBuilder = new mutable.Builder[(A,B),BinRel[A,B]] {
- val s = collection.mutable.Buffer[(A,B)]()
+ protected[this] def newBuilder : mutable.Builder[(A, B), BinRel[A, B]] = new mutable.Builder[(A, B), BinRel[A, B]] {
+ val s: mutable.Buffer[(A, B)] = collection.mutable.Buffer[(A, B)]()
+
def result() = BinRel(s: _*)
- def clear() = s.clear()
- def +=(elem: (A,B)) = { s += elem; this }
- }
- def seq = iterator.toSeq
+ def clear(): Unit = s.clear()
+
+ def +=(elem: (A, B)): this.type = {
+ s += elem
+ this
+ }
+ }
}
/**
- * Companion object for the BinRel trait
- *
- * @author Aleks Kissinger
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BinRel.scala Source code]]
- */
+ * Companion object for the BinRel trait
+ *
+ * @author Aleks Kissinger
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BinRel.scala Source code]]
+ */
object BinRel {
/** Construct a binary relation from a sequence of pairs '''(a,b)''' */
- def apply[A,B](kvs: (A,B)*)(implicit domOrd: Ordering[A], codOrd: Ordering[B]) : BinRel[A,B] = {
- kvs.foldLeft(new MapPairBinRel[A,B](Map(),Map())(domOrd,codOrd)){ (rel, kv) => rel + kv }
+ def apply[A, B](kvs: (A, B)*)(implicit domOrd: Ordering[A], codOrd: Ordering[B]): BinRel[A, B] = {
+ kvs.foldLeft(new MapPairBinRel[A, B](Map(), Map())(domOrd, codOrd)) { (rel, kv) => rel + kv }
}
}
diff --git a/scala/src/main/scala/quanto/data/CompositeExpression.scala b/scala/src/main/scala/quanto/data/CompositeExpression.scala
new file mode 100644
index 00000000..5c87ba58
--- /dev/null
+++ b/scala/src/main/scala/quanto/data/CompositeExpression.scala
@@ -0,0 +1,207 @@
+package quanto.data
+
+import quanto.data.Theory.ValueType
+import quanto.util.Rational
+
+import scala.util.parsing.combinator.RegexParsers
+
+
+case class MismatchedPhaseData(data1: Vector[ValueType], data2: Vector[ValueType])
+ extends Exception(data1.toString + " != " + data2.toString)
+
+case class GenericParseException(message: String) extends Exception(message)
+
+case class TypeNotFoundException(message: String) extends Exception(message)
+
+
+// The types are stored as a vector of enums
+// The data is stored as a vector of PhaseExpressions, which are then cast according to the type data
+case class CompositeExpression(valueTypes: Vector[ValueType], values: Vector[PhaseExpression]) {
+
+ lazy val varsWithType: Set[(ValueType, String)] =
+ values.zipWithIndex.flatMap(valueWithIndex =>
+ valueWithIndex._1.vars.map(variableName =>
+ (valueTypes(valueWithIndex._2), variableName))
+ ).toSet
+ val vars: Set[String] = values.flatMap(_.vars).toSet
+ val description: ValueType = ValueType.Empty
+
+ // Addition
+ def +(e: CompositeExpression): CompositeExpression = {
+
+ if (valueTypes != e.valueTypes) {
+ throw MismatchedPhaseData(valueTypes, e.valueTypes)
+ }
+
+
+ val summedValues: Vector[PhaseExpression] =
+ valueTypes.zipWithIndex.map(vi => values(vi._2) + e.values(vi._2))
+ CompositeExpression(valueTypes, summedValues)
+ }
+
+ // Subtraction
+ def -(e: CompositeExpression): CompositeExpression = {
+
+
+ val negatedValues: Vector[PhaseExpression] =
+ valueTypes.zipWithIndex.map(vi => values(vi._2) - e.values(vi._2))
+ CompositeExpression(valueTypes, negatedValues)
+ }
+
+ // Combine strings of each subvalue
+ override def toString: String = {
+
+ if (values.forall(e => e == PhaseExpression.zero(e.description))) {
+ ""
+ } else {
+ val stringValues = values.zipWithIndex.map(pi => (pi._1.description, pi._2) match {
+ case (ValueType.String, _) =>
+ pi._1.toString // Always render strings directly
+ case (_, 0) => // Always render the first entry directly
+ pi._1.toString
+ case (_, _) => // Put a space before anything else to aid legibiliy
+ " " + pi._1.toString
+ })
+ stringValues.mkString(",")
+ }
+ }
+
+ def *(r: Rational): CompositeExpression = {
+
+ val scaledValues: Vector[PhaseExpression] = values.map(v => v * r)
+ CompositeExpression(valueTypes, scaledValues)
+ }
+
+ def firstOrError[T <: PhaseExpression](valueType: ValueType): T = {
+ val typeIndex = valueTypes.zipWithIndex.find(x => x._1 == valueType)
+ if (typeIndex.nonEmpty) {
+ values(typeIndex.get._2).asInstanceOf[T]
+ } else {
+ throw TypeNotFoundException(valueType.toString + " was not present in " + valueTypes.mkString(","))
+ }
+ }
+
+ def first[T <: PhaseExpression](valueType: ValueType): Option[T] = {
+ try {
+ Some(this.firstOrError(valueType))
+ }
+ catch {
+ case _: Throwable => None
+ }
+ }
+
+ def substSubValues(mp: Map[String, PhaseExpression]): CompositeExpression =
+ mp.foldLeft(this) { case (e, (v, e1)) => e.substSubValue(v, e1) }
+
+ def substSubVariables(mp: Map[(ValueType, String), String]): CompositeExpression = {
+ substSubValues(mp.map(vss => vss._1._2 -> PhaseExpression.parse(vss._2, vss._1._1) ))
+ }
+
+ def substSubValue(variableName: String, phase: PhaseExpression): CompositeExpression = {
+ // Apply substitution to subvalues with the correct valueType
+
+ val newValues: Vector[PhaseExpression] =
+ valueTypes.zipWithIndex.map(vi => {
+ val current = values(vi._2)
+ if (phase.description == current.description) {
+ current.subst(variableName, phase)
+ }
+ else {
+ current
+ }
+ })
+ CompositeExpression(valueTypes, newValues)
+ }
+
+}
+
+object CompositeExpression {
+
+ implicit val modulus: Option[Int] = None
+
+ def empty: CompositeExpression = {
+ CompositeExpression(Vector(), Vector())
+ }
+
+ def zero(valueTypes: Vector[ValueType]): CompositeExpression = {
+ val zeroValues: Vector[PhaseExpression] = valueTypes.map(t => PhaseExpression.zero(t))
+ CompositeExpression(valueTypes, zeroValues)
+ }
+
+ def one(valueTypes: Vector[ValueType]): CompositeExpression = {
+ val oneValues: Vector[PhaseExpression] = valueTypes.map(t => PhaseExpression.one(t))
+ CompositeExpression(valueTypes, oneValues)
+ }
+
+ def parse(types: String, values: String): CompositeExpression = {
+ val typeVector = parseTypes(types)
+ CompositeExpression(typeVector, parseKnowingTypes(values, typeVector))
+ }
+
+ def parseKnowingTypes(s: String, v: Vector[ValueType]): Vector[PhaseExpression] = {
+ // Will fill with empties if more types requested than string elements given
+ val split: Array[String] = s.split(",")
+ v.zipWithIndex.map(si => parseSingle(split.lift(si._2).getOrElse(""), si._1))
+ }
+
+ def parseSingle(s: String, v: ValueType): PhaseExpression = PhaseExpression.parse(s, v)
+
+ def parseTypes(s: String): Vector[ValueType] = TypeExpressionParser.p(s).toVector
+
+ def wrap[T <: PhaseExpression](expression: T): CompositeExpression = {
+ CompositeExpression(Vector(expression.description), Vector(expression))
+ }
+
+ private object TypeExpressionParser extends RegexParsers {
+ override def skipWhitespace = true
+
+ // Partial matches will confuse things!
+ // Make sure a supermatch comes before a submatch
+ def ANGLE: Parser[ValueType] =
+ """(angle_expr|(LinRat|)[Aa]ngle)""".r ^^ { _ => ValueType.AngleExpr }
+
+ def BOOL: Parser[ValueType] = """[bB]ool(ean|)""".r ^^ { _ => ValueType.Boolean }
+
+ def RATIONAL: Parser[ValueType] = """[Rr]ational""".r ^^ { _ => ValueType.Rational }
+
+ def INTEGER: Parser[ValueType] = """[Ii]nt(eger|)""".r ^^ { _ => ValueType.Integer }
+
+ def STRING: Parser[ValueType] = """[Ss]tring""".r ^^ { _ => ValueType.String }
+
+ def LONG: Parser[ValueType] = """long(_string|)""".r ^^ { _ => ValueType.Long }
+
+ def ENUM: Parser[ValueType] = """enum""".r ^^ { _ => ValueType.Enum }
+
+ def EMPTY: Parser[ValueType] = """[Ee]mpty""".r ^^ { _ => ValueType.Empty }
+
+
+ def term: Parser[ValueType] =
+ ANGLE |
+ BOOL |
+ RATIONAL |
+ INTEGER |
+ STRING |
+ LONG |
+ ENUM |
+ EMPTY
+
+
+ def terms: Parser[List[ValueType]] =
+ "(" ~ terms ~ ")" ^^ { case _ ~ t ~ _ => t } |
+ term ~ "," ~ terms ^^ { case s ~ _ ~ t => s :: t } |
+ term ^^ { t => List(t) }
+
+ def expr: Parser[List[ValueType]] =
+ terms |
+ term ^^ { t => List(t) } |
+ "" ^^ { _ => List(ValueType.Empty) }
+
+ def p(s: String): List[ValueType] = parseAll(expr, s) match {
+ case Success(e, _) => e
+ case Failure(msg, _) => throw GenericParseException(msg)
+ case Error(msg, _) => throw GenericParseException(msg)
+ }
+ }
+
+}
+
diff --git a/scala/src/main/scala/quanto/data/Derivation.scala b/scala/src/main/scala/quanto/data/Derivation.scala
index 055e666b..eb515168 100644
--- a/scala/src/main/scala/quanto/data/Derivation.scala
+++ b/scala/src/main/scala/quanto/data/Derivation.scala
@@ -1,11 +1,11 @@
package quanto.data
+import quanto.gui.{DeriveState, HeadState, StepState}
+import quanto.layout.ForceLayout
+import quanto.util.TreeSeq
import quanto.util.json._
-import javax.management.remote.rmi._RMIConnectionImpl_Tie
+
import scala.collection.SortedSet
-import quanto.gui.{StepState, HeadState, DeriveState}
-import quanto.util.TreeSeq
-import quanto.layout.ForceLayout
trait DerivationException
@@ -13,20 +13,9 @@ case class DerivationLoadException(message: String, cause: Throwable = null)
extends Exception(message, cause)
with DerivationException
-sealed abstract class RuleVariant
-
-case object RuleNormal extends RuleVariant {
- override def toString = "normal"
-}
-
-case object RuleInverse extends RuleVariant {
- override def toString = "inverse"
-}
-
case class DStep(name: DSName,
ruleName: String,
rule: Rule,
- variant: RuleVariant,
graph: Graph) {
def layout: DStep = {
val layoutProc = new ForceLayout
@@ -36,7 +25,7 @@ case class DStep(name: DSName,
//layoutProc.edgeLength = 0.1
layoutProc.gravity = 0.0
- val rhsi = rule.rhs.verts.filter(!rule.rhs.isBoundary(_))
+ val rhsi = rule.rhs.verts.filter(!rule.rhs.isTerminalWire(_))
graph.verts.foreach { v =>
if (!rhsi.contains(v)) layoutProc.lockVertex(v)
@@ -52,19 +41,15 @@ case class DStep(name: DSName,
def copy(name: DSName = name,
ruleName: String = ruleName,
rule: Rule = rule,
- variant: RuleVariant = variant,
graph: Graph = graph)
- = DStep(name, ruleName, rule, variant, graph)
+ = DStep(name, ruleName, rule, graph)
}
object DStep {
- implicit def booleanToVariant(isInverted : Boolean) : RuleVariant = {
- if (isInverted) RuleNormal else RuleInverse
- }
- def apply(name: DSName, rule: Rule, graph: Graph) : DStep =
- DStep(name, rule.name, rule, rule.description.inverse, graph)
+ def apply(name: DSName, rule: Rule, graph: Graph): DStep =
+ DStep(name, rule.name, rule, graph)
def toJson(dstep: DStep, parent: Option[DSName], thy: Theory = Theory.DefaultTheory): Json = {
JsonObject(
@@ -72,23 +57,25 @@ object DStep {
"parent" -> parent.map(_.toString),
"rule_name" -> dstep.ruleName,
"rule" -> Rule.toJson(dstep.rule, thy),
- "rule_variant" -> (dstep.variant match {
- case RuleNormal => JsonNull;
- case v => v.toString
+ "rule_variant" -> (if (dstep.rule.description.inverse) {
+ "inverse"
+ } else {
+ "forwards"
}),
"graph" -> Graph.toJson(dstep.graph, thy)
- ).noEmpty
+ )
}
def fromJson(name: DSName, json: Json, thy: Theory = Theory.DefaultTheory): DStep = try {
+ val baseRule = Rule.fromJson(json / "rule", thy)
+ val rule: Rule = json ? "rule_variant" match {
+ case JsonString("inverse") => baseRule.inverse
+ case _ => baseRule
+ }
DStep(
name = name,
ruleName = (json / "rule_name").stringValue,
- rule = Rule.fromJson(json / "rule", thy),
- variant = json ? "rule_variant" match {
- case JsonString("inverse") => RuleInverse;
- case _ => RuleNormal
- },
+ rule = rule,
graph = Graph.fromJson(json / "graph", thy)
)
} catch {
@@ -104,8 +91,7 @@ object DStep {
}
}
-case class Derivation(theory: Theory,
- root: Graph,
+case class Derivation(root: Graph,
steps: Map[DSName, DStep] = Map(),
heads: SortedSet[DSName] = SortedSet(),
parentMap: PFun[DSName, DSName] = PFun())
@@ -118,7 +104,7 @@ case class Derivation(theory: Theory,
def stepsTo(head: DSName): Array[DSName] =
(parentMap.get(head) match {
case Some(p) => stepsTo(p)
- case None => Array()
+ case None => Array.empty[DSName]
}) :+ head
def graphsTo(head: DSName): Array[Graph] = root +: stepsTo(head).map(s => steps(s).graph)
@@ -129,12 +115,6 @@ case class Derivation(theory: Theory,
copy(steps = steps + (s -> s1))
}
- def copy(theory: Theory = theory,
- root: Graph = root,
- steps: Map[DSName, DStep] = steps,
- heads: SortedSet[DSName] = heads,
- parent: PFun[DSName, DSName] = parentMap) = Derivation(theory, root, steps, heads, parent)
-
def allChildren(s: DSName): Set[DSName] =
children(s).foldLeft(Set[DSName]()) { case (set, c) => set union allChildren(c) } + s
@@ -146,6 +126,12 @@ case class Derivation(theory: Theory,
def addHead(h: DSName): Derivation = copy(heads = heads + h)
+ def copy(
+ root: Graph = root,
+ steps: Map[DSName, DStep] = steps,
+ heads: SortedSet[DSName] = heads,
+ parent: PFun[DSName, DSName] = parentMap) = Derivation(root, steps, heads, parent)
+
def deleteHead(h: DSName): Derivation = copy(heads = heads - h)
def addStep(parentOpt: Option[DSName], step: DStep): Derivation = parentOpt match {
@@ -169,7 +155,7 @@ case class Derivation(theory: Theory,
copy(
steps = steps1,
heads = parentOpt match {
- case Some(s) => heads1 + s;
+ case Some(`s`) => heads1 + s;
case _ => heads1
},
parent = parent1
@@ -222,7 +208,7 @@ case class Derivation(theory: Theory,
object Derivation {
type DerivationWithHead = (Derivation, Option[DSName])
- def fromJson(json: Json, thy: Theory = Theory.DefaultTheory) = try {
+ def fromJson(json: Json, thy: Theory = Theory.DefaultTheory): Derivation = try {
val parent = (json ? "steps").asObject.foldLeft(PFun[DSName, DSName]()) {
case (pf, (step, obj)) => obj.get("parent") match {
case Some(JsonString(p)) => pf + (DSName(step) -> DSName(p))
@@ -237,23 +223,22 @@ object Derivation {
val heads = (json ? "heads").asArray.foldLeft(SortedSet[DSName]()) { case (set, h) => set + DSName(h.stringValue) }
Derivation(
- theory = thy,
root = Graph.fromJson(json / "root", thy),
steps = steps,
heads = heads,
parentMap = parent
)
} catch {
- case e: JsonAccessException => throw new DerivationLoadException(e.getMessage, e)
+ case e: JsonAccessException => throw DerivationLoadException(e.getMessage, e)
case e: GraphLoadException =>
- throw new DerivationLoadException("Graph 'root': " + e.getMessage, e)
+ throw DerivationLoadException("Graph 'root': " + e.getMessage, e)
case e: DerivationLoadException => throw e
case e: Exception =>
e.printStackTrace()
- throw new DerivationLoadException("Error reading JSON", e)
+ throw DerivationLoadException("Error reading JSON", e)
}
- def toJson(derive: Derivation, thy: Theory = Theory.DefaultTheory) = {
+ def toJson(derive: Derivation, thy: Theory = Theory.DefaultTheory): JsonObject = {
val steps = derive.steps.map { case (k, v) => (k.toString, DStep.toJson(v, derive.parentMap.get(k), thy)) }
JsonObject(
"root" -> Graph.toJson(derive.root, thy),
diff --git a/scala/src/main/scala/quanto/data/EData.scala b/scala/src/main/scala/quanto/data/EData.scala
index 7644e342..6ff8e3f6 100644
--- a/scala/src/main/scala/quanto/data/EData.scala
+++ b/scala/src/main/scala/quanto/data/EData.scala
@@ -3,73 +3,85 @@ package quanto.data
import quanto.util.json._
/**
- * An abstract class providing an interface for accessing edge data
- *
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- * @author Aleks Kissinger
- */
+ * An abstract class providing an interface for accessing edge data
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ * @author Aleks Kissinger
+ */
abstract class EData extends GraphElementData {
def isDirected: Boolean
+ def typeInfo = theory.edgeTypes(typ)
+
/** type of the edge */
- def typ = (data / "type").stringValue
+ def typ: String = (data / "type").stringValue
- def typeInfo = theory.edgeTypes(typ)
- def label = data.getOrElse("label","").stringValue
- def value = data ? "value"
+ def label: String = data.getOrElse("label", "").stringValue
+
+ def value: Json = data ? "value"
/** Create a copy of the current edge data, but with the new value */
def withValue(v: String): EData
def toDirEdge = DirEdge(data, annotation, theory)
+
def toUndirEdge = UndirEdge(data, annotation, theory)
}
/**
- * A class which represents directed edge data.
- *
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- */
+ * A class which represents directed edge data.
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ */
case class DirEdge(
- data: JsonObject = Theory.DefaultTheory.defaultEdgeData,
- annotation: JsonObject = JsonObject(),
- theory: Theory = Theory.DefaultTheory
-) extends EData {
+ data: JsonObject = Theory.DefaultTheory.defaultEdgeData,
+ annotation: JsonObject = JsonObject(),
+ theory: Theory = Theory.DefaultTheory
+ ) extends EData {
def isDirected = true
+
def withValue(v: String): DirEdge = copy(data = data.setPath("$.value", v).setPath("$.label", v).asObject)
- override def toJson = DirEdge.toJson(this, theory)
+
+ override def toJson: JsonObject = DirEdge.toJson(this, theory)
}
/**
- * A class which represents undirected edge data.
- *
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- */
+ * A class which represents undirected edge data.
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ */
case class UndirEdge(
- data: JsonObject = Theory.DefaultTheory.defaultEdgeData,
- annotation: JsonObject = JsonObject(),
- theory: Theory = Theory.DefaultTheory
-) extends EData {
+ data: JsonObject = Theory.DefaultTheory.defaultEdgeData,
+ annotation: JsonObject = JsonObject(),
+ theory: Theory = Theory.DefaultTheory
+ ) extends EData {
def isDirected = false
+
def withValue(v: String): UndirEdge = copy(data = data.setPath("$.value", v).setPath("$.label", v).asObject)
- override def toJson = UndirEdge.toJson(this, theory)
+
+ override def toJson: JsonObject = UndirEdge.toJson(this, theory)
}
/**
- * Companion object for the DirEdge class. Contains methods to convert to/from
- * JSON
- *
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- */
+ * Companion object for the DirEdge class. Contains methods to convert to/from
+ * JSON
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ */
object DirEdge {
- def toJson(d: EData, theory: Theory) = JsonObject(
- "data" -> (if (d.data == theory.edgeTypes(d.typ).defaultData) JsonNull else d.data),
+ def toJson(d: EData, theory: Theory): JsonObject = JsonObject(
+ "data" -> (if (d.typ == theory.defaultEdgeType && d.data == theory.edgeTypes(d.typ).defaultData) {
+ JsonNull
+ } else {
+ d.data
+ }),
"annotation" -> d.annotation).noEmpty
- def fromJson(json: Json, theory: Theory) : DirEdge = {
+
+ def fromJson(json: Json, theory: Theory): DirEdge = {
val data = json.getOrElse("data", theory.defaultEdgeData).asObject
val annotation = (json ? "annotation").asObject
DirEdge(data, annotation, theory)
@@ -77,17 +89,22 @@ object DirEdge {
}
/**
- * Companion object for the UndirEdge class. Contains methods to convert
- * to/from JSON
- *
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- */
+ * Companion object for the UndirEdge class. Contains methods to convert
+ * to/from JSON
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ */
object UndirEdge {
- def toJson(d: EData, theory: Theory) = JsonObject(
- "data" -> (if (d.data == theory.edgeTypes(d.typ).defaultData) JsonNull else d.data),
+ def toJson(d: EData, theory: Theory): JsonObject = JsonObject(
+ "data" -> (if (d.typ == theory.defaultEdgeType && d.data == theory.edgeTypes(d.typ).defaultData) {
+ JsonNull
+ } else {
+ d.data
+ }),
"annotation" -> d.annotation).noEmpty
- def fromJson(json: Json, theory: Theory) : UndirEdge = {
+
+ def fromJson(json: Json, theory: Theory): UndirEdge = {
val data = json.getOrElse("data", theory.defaultEdgeData).asObject
val annotation = (json ? "annotation").asObject
UndirEdge(data, annotation, theory)
diff --git a/scala/src/main/scala/quanto/data/GData.scala b/scala/src/main/scala/quanto/data/GData.scala
index 3f4cb887..250d989d 100644
--- a/scala/src/main/scala/quanto/data/GData.scala
+++ b/scala/src/main/scala/quanto/data/GData.scala
@@ -3,12 +3,13 @@ package quanto.data
import quanto.util.json.JsonObject
/**
- * A class which represents graph data
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/GData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- */
+ * A class which represents graph data
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/GData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ */
case class GData(
- data: JsonObject = JsonObject(),
- annotation: JsonObject = JsonObject(),
- theory: Theory = Theory.DefaultTheory
-) extends GraphElementData
+ data: JsonObject = JsonObject(),
+ annotation: JsonObject = JsonObject(),
+ theory: Theory = Theory.DefaultTheory
+ ) extends GraphElementData
diff --git a/scala/src/main/scala/quanto/data/Graph.scala b/scala/src/main/scala/quanto/data/Graph.scala
index ec812d23..91c67098 100644
--- a/scala/src/main/scala/quanto/data/Graph.scala
+++ b/scala/src/main/scala/quanto/data/Graph.scala
@@ -1,81 +1,113 @@
package quanto.data
-import Names._
+
import quanto.cosy.AdjMat
+import quanto.data.Names._
+import quanto.data.Theory.ValueType
+import quanto.util.json.JsonValues._
import quanto.util.json._
-import math.sqrt
-import JsonValues._
-import collection.mutable.ArrayBuffer
-import quanto.util._
-import java.awt.datatransfer.{DataFlavor, Transferable}
-import scala.annotation.tailrec
+import scala.collection.mutable.ArrayBuffer
+import scala.math.sqrt
+import scala.util.Random
class GraphException(msg: String, cause: Throwable = null) extends Exception(msg, cause)
+
class PluggingException(msg: String) extends GraphException(msg)
case class GraphSearchContext(exploredV: Set[VName], exploredE: Set[EName])
class GraphLoadException(message: String, cause: Throwable = null)
-extends GraphException(message, cause)
+ extends GraphException(message, cause)
sealed abstract class BBOp {
def bb: BBName
+
def shortName: String
+
//override def toString: String = shortName
}
-case class BBExpand(bb: BBName, mp: GraphMap) extends BBOp { def shortName = "E(" + bb + ")" }
-case class BBCopy(bb: BBName, mp: GraphMap) extends BBOp { def shortName = "C(" + bb + ")" }
-case class BBDrop(bb: BBName) extends BBOp { def shortName = "D(" + bb + ")" }
-case class BBKill(bb: BBName) extends BBOp { def shortName = "K(" + bb + ")" }
+case class BBExpand(bb: BBName, mp: GraphMap) extends BBOp {
+ def shortName: String = "E(" + bb + ")"
+}
+
+case class BBCopy(bb: BBName, mp: GraphMap) extends BBOp {
+ def shortName: String = "C(" + bb + ")"
+}
+
+case class BBDrop(bb: BBName) extends BBOp {
+ def shortName: String = "D(" + bb + ")"
+}
+
+case class BBKill(bb: BBName) extends BBOp {
+ def shortName: String = "K(" + bb + ")"
+}
case class Graph(
- data: GData = GData(),
- vdata: Map[VName,VData] = Map[VName,VData](),
- edata: Map[EName,EData] = Map[EName,EData](),
- source: PFun[EName,VName] = PFun[EName,VName](),
- target: PFun[EName,VName] = PFun[EName,VName](),
- bbdata: Map[BBName,BBData] = Map[BBName,BBData](),
- inBBox: BinRel[VName,BBName] = BinRel[VName,BBName](),
- bboxParent: PFun[BBName,BBName] = PFun[BBName,BBName]()) extends Ordered[Graph]
-{
- def isInput (v: VName): Boolean = vdata(v).isWireVertex && inEdges(v).isEmpty && outEdges(v).size == 1
- def isOutput(v: VName): Boolean = vdata(v).isWireVertex && outEdges(v).isEmpty && inEdges(v).size == 1
- def isInternal(v: VName): Boolean = vdata(v).isWireVertex && outEdges(v).size == 1 && inEdges(v).size == 1
- def isBoundary(vn: VName): Boolean =
- vdata(vn).isWireVertex && (inEdges(vn).size + outEdges(vn).size) <= 1
+ data: GData = GData(),
+ vdata: Map[VName, VData] = Map[VName, VData](),
+ edata: Map[EName, EData] = Map[EName, EData](),
+ source: PFun[EName, VName] = PFun[EName, VName](),
+ target: PFun[EName, VName] = PFun[EName, VName](),
+ bbdata: Map[BBName, BBData] = Map[BBName, BBData](),
+ inBBox: BinRel[VName, BBName] = BinRel[VName, BBName](),
+ bboxParent: PFun[BBName, BBName] = PFun[BBName, BBName]()) extends Ordered[Graph] {
+ lazy val nodesThatAreNotWires: Set[VName] = verts.filterNot(vdata(_).isWireVertex)
+ protected val factory = new Graph(_, _, _, _, _, _, _, _)
+
+ def isInternalWire(v: VName): Boolean = vdata(v).isWireVertex && outEdges(v).size == 1 && inEdges(v).size == 1
+
def isCircle(vn: VName): Boolean =
vdata(vn).isWireVertex && inEdges(vn).size == 1 && inEdges(vn) == outEdges(vn)
- def typeOf(v: VName): String = vdata(v).typ
+ // convenience methods
+ def inEdges(vn: VName): Set[EName] = target.codf(vn)
+
def isAdjacentToBoundary(v: VName): Boolean = adjacentVerts(v).exists(isBoundary)
+
+ def isBoundary(vn: VName): Boolean = isTerminalWire(vn)
+
def isAdjacentToType(v: VName, t: String): Boolean = adjacentVerts(v).exists(typeOf(_) == t)
- def isWireVertex(v: VName) = vdata(v).isWireVertex
- def representsWire(vn: VName) = vdata(vn).isWireVertex &&
+ def typeOf(v: VName): String = vdata(v).typ
+
+ /** Returns a set of vertex names adjacent to vn */
+ def adjacentVerts(vn: VName): Set[VName] = predVerts(vn) union succVerts(vn)
+
+ def predVerts(vn: VName): Set[VName] = inEdges(vn).map(source(_))
+
+ def succVerts(vn: VName): Set[VName] = outEdges(vn).map(target(_))
+
+ def isWireVertex(v: VName): Boolean = vdata(v).isWireVertex
+
+ def numAdjacentBoundaries(vn: VName): Int = adjacentVerts(vn).count(isTerminalWire)
+
+ def representsWire(vn: VName): Boolean = vdata(vn).isWireVertex &&
(predVerts(vn).headOption match {
case None => true
case Some(vn1) => vn == vn1 || !vdata(vn1).isWireVertex
})
- def representsBareWire(vn: VName) =
- isInput(vn) &&
+ def representsBareWire(vn: VName): Boolean =
+ isInputWire(vn) &&
(succVerts(vn).headOption match {
case None => false
- case Some(vn1) => isOutput(vn1)
+ case Some(vn1) => isOutputWire(vn1)
})
- def arity(v: VName) = adjacentEdges(v).size
+ def isInputWire(v: VName): Boolean = vdata(v).isWireVertex && inEdges(v).isEmpty && outEdges(v).size == 1
- def verts: Set[VName] = vdata.keySet
- def edges: Set[EName] = edata.keySet
- def bboxes: Set[BBName] = bbdata.keySet
+ def isOutputWire(v: VName): Boolean = vdata(v).isWireVertex && outEdges(v).isEmpty && inEdges(v).size == 1
+
+ def arity(v: VName): Int = adjacentEdges(v).size
+
+ def inputs: Set[VName] = verts.filter(isInputWire)
- def inputs: Set[VName] = verts.filter(isInput)
- def outputs: Set[VName] = verts.filter(isOutput)
- def boundary: Set[VName] = verts.filter(isBoundary)
+ def outputs: Set[VName] = verts.filter(isOutputWire)
+
+ def boundary: Set[VName] = verts.filter(isTerminalWire)
override def hashCode: Int = {
var h = data.hashCode
@@ -89,8 +121,6 @@ case class Graph(
h
}
- def canEqual(other: Any): Boolean = other.isInstanceOf[Graph]
-
override def equals(other: Any): Boolean = other match {
case that: Graph => (that canEqual this) &&
vdata == that.vdata &&
@@ -103,61 +133,72 @@ case class Graph(
case _ => false
}
+ def canEqual(other: Any): Boolean = other.isInstanceOf[Graph]
- protected val factory = new Graph(_,_,_,_,_,_,_,_)
+ def vars: Set[String] = vdata.values.foldLeft(Set.empty[String]) {
+ case (vs, d: NodeV) =>
+ vs ++ d.phaseData.vars
+ case (vs, _) => vs
+ }
- def copy(data: GData = this.data,
- vdata: Map[VName,VData] = this.vdata,
- edata: Map[EName,EData] = this.edata,
- source: PFun[EName,VName] = this.source,
- target: PFun[EName,VName] = this.target,
- bbdata: Map[BBName,BBData] = this.bbdata,
- inBBox: BinRel[VName,BBName] = this.inBBox,
- bboxParent: PFun[BBName,BBName] = this.bboxParent): Graph =
- factory(data,vdata,edata,source,target,bbdata,inBBox,bboxParent)
+ // The set of nodes and wire-boundaires at ends of wires attached to vn
+ def adjacentNodesAndBoundaries(vn: VName): Set[VName] = adjacentEdges(vn).flatMap(e => edgeEndPoints(e)._1) - vn
- // convenience methods
- def inEdges(vn: VName): Set[EName] = target.codf(vn)
- def outEdges(vn: VName): Set[EName] = source.codf(vn)
- def predVerts(vn: VName): Set[VName] = inEdges(vn).map(source(_))
- def succVerts(vn: VName): Set[VName] = outEdges(vn).map(target(_))
- def contents(bbn: BBName): Set[VName] = inBBox.codf(bbn)
- def bboxesContaining(vn: VName): Set[BBName] = inBBox.domf(vn)
+ /**
+ * Traverse the wire nodes to find the non-wire endpoints of an edge
+ *
+ * @param edge Edge to focus on
+ * @return endPoints, edgesBetweenThem, wireNodesBetweenThem
+ */
+ def edgeEndPoints(edge: EName): (Set[VName], Set[EName], Set[VName]) = {
+ // Given an edge find its two end points,
+ // where wire nodes of arity 2 do not count as endpoints
+ // If given a loop then returns an empty set
+ var edgesVisited: Set[EName] = Set()
+ var nodesVisited: Set[VName] = Set()
- def vars: Set[String] = vdata.values.foldLeft(Set.empty[String]) {
- case (vs, d: NodeV) =>
- vs ++ d.angle.vars
- case (vs,_) => vs
- }
+ var endPointsFound: Set[VName] = Set()
- /** Returns a set of vertex names adjacent to vn */
- def adjacentVerts(vn: VName): Set[VName] = predVerts(vn) union succVerts(vn)
+ // slightly different to arity, since arity counts self-loops as one edge not two
+ def inPlusOut(vertex: VName): Int = {
+ source.codf(vertex).size + target.codf(vertex).size
+ }
+
+ def considerNode(v2: VName) {
+ if (!nodesVisited.contains(v2)) {
+ if (!vdata(v2).isWireVertex || inPlusOut(v2) != 2) {
+ endPointsFound += v2
+ } else {
+ nodesVisited += v2
+ adjacentEdges(v2).foreach(considerEdge)
+ }
+ }
+ }
+
+ def considerEdge(e2: EName) {
+ if (!edgesVisited.contains(e2)) {
+ edgesVisited += e2
+ considerNode(source(e2))
+ considerNode(target(e2))
+ }
+ }
+
+ considerEdge(edge)
+ (endPointsFound, edgesVisited, nodesVisited)
+ }
/** Returns a set of vertex names adjacent to, and including, vset */
def extendToAdjacentVerts(vset: Set[VName]): Set[VName] =
- vset.foldRight(Set[VName]()) { (v,vs) => (vs union adjacentVerts(v)) + v }
+ vset.foldRight(Set[VName]()) { (v, vs) => (vs union adjacentVerts(v)) + v }
/** Returns a set of vertex names adjacent to, but not including, vset */
def adjacentVerts(vset: Set[VName]): Set[VName] =
- vset.foldRight(Set[VName]()) { (v,vs) => vs union adjacentVerts(v) } -- vset
-
- /** Returns a set of edge names adjacent to vn */
- def adjacentEdges(vn: VName): Set[EName] = source.codf(vn) union target.codf(vn)
+ vset.foldRight(Set[VName]()) { (v, vs) => vs union adjacentVerts(v) } -- vset
/** Returns a set of edge names adjacent to any vertex in vset */
def adjacentEdges(vset: Set[VName]): Set[EName] =
- vset.foldRight(Set[EName]()) { (v,es) => es union adjacentEdges(v) }
-
- /** Returns a set of edge names which connect v1 to v2 or vice versa */
- def edgesBetween(v1: VName, v2: VName): Set[EName] = {
- if (v1 == v2) {
- source.codf(v1) intersect target.codf(v1)
- }
- else {
- adjacentEdges(v1) intersect adjacentEdges(v2)
- }
- }
+ vset.foldRight(Set[EName]()) { (v, es) => es union adjacentEdges(v) }
/**
* If "e" is not a self-loop, get the vertex connected
@@ -170,6 +211,7 @@ case class Graph(
/**
* Get the other edge connected to this wire vertex, if there is one
+ *
* @param w a wire vertex
* @param e an edge
* @return an edge, optionally
@@ -180,43 +222,49 @@ case class Graph(
else throw new GraphException("Wire: " + w + " is not connected to edge: " + e)
}
+ /** Returns a set of edge names adjacent to vn */
+ def adjacentEdges(vn: VName): Set[EName] = source.codf(vn) union target.codf(vn)
+
/**
- * Partition of all edges into sets, s.t. they connect the same two vertices
- * regardless of edge direction
- */
- def edgePartition() : List[Set[EName]] = {
- var res : List[Set[EName]] = List()
- for ((v1,_) <- vdata; (v2,_) <- vdata if v1 <= v2) {
- val edgeSet = edgesBetween(v1,v2)
+ * Partition of all edges into sets, s.t. they connect the same two vertices
+ * regardless of edge direction
+ */
+ def edgePartition: List[Set[EName]] = {
+ var res: List[Set[EName]] = List()
+ for ((v1, _) <- vdata; (v2, _) <- vdata if v1 <= v2) {
+ val edgeSet = edgesBetween(v1, v2)
if (edgeSet.nonEmpty) res = edgeSet :: res
}
res
}
+ /** Returns a set of edge names which connect v1 to v2 or vice versa */
+ def edgesBetween(v1: VName, v2: VName): Set[EName] = {
+ if (v1 == v2) {
+ source.codf(v1) intersect target.codf(v1)
+ }
+ else {
+ adjacentEdges(v1) intersect adjacentEdges(v2)
+ }
+ }
+
def isBBoxed(v: VName): Boolean = inBBox.domf(v).nonEmpty
/// by song
// to compute whether two vertices are in the same bbox.
- def isInSameBBox(v1:VName, v2:VName): Boolean = (inBBox.domf(v1) & inBBox.domf(v2)).nonEmpty
-
- def addVertex(vn: VName, data: VData): Graph = {
- if (vdata contains vn)
- throw new DuplicateVertexNameException(vn)
-
- copy(vdata = vdata + (vn -> data))
- }
+ def isInSameBBox(v1: VName, v2: VName): Boolean = (inBBox.domf(v1) & inBBox.domf(v2)).nonEmpty
/**
- * @return A new graph where all vertices have coordinates which align to a
- * grid
- */
+ * @return A new graph where all vertices have coordinates which align to a
+ * grid
+ */
def snapToGrid(): Graph = {
- def roundCoord(d : Double) = {
+ def roundCoord(d: Double) = {
math.rint(d * 4.0) / 4.0 // rounds to .25
}
- val snapped_vdata = vdata.mapValues {vd =>
+ val snapped_vdata = vdata.mapValues { vd =>
val coord = vd.coord
vd.withCoord(roundCoord(coord._1), roundCoord(coord._2))
}
@@ -228,19 +276,11 @@ case class Graph(
(addVertex(vn, data), vn)
}
- def addEdge(en: EName, data: EData, vns: (VName, VName)): Graph = {
- if (edata contains en)
- throw new DuplicateEdgeNameException(en)
- if (!vdata.contains(vns._1))
- throw new GraphException("Edge: " + en + " has no endpoint: " + vns._1 + " in graph")
- if (!vdata.contains(vns._1))
- throw new GraphException("Edge: " + en + " has no endpoint: " + vns._2 + " in graph")
+ def addVertex(vn: VName, data: VData): Graph = {
+ if (vdata contains vn)
+ throw new DuplicateVertexNameException(vn)
- copy(
- edata = edata + (en -> data),
- source = source + (en -> vns._1),
- target = target + (en -> vns._2)
- )
+ copy(vdata = vdata + (vn -> data))
}
def newEdge(data: EData, vns: (VName, VName)): (Graph, EName) = {
@@ -254,7 +294,7 @@ case class Graph(
val g1 = copy(
bbdata = bbdata + (bbn -> data),
- inBBox = contents.foldLeft(inBBox){ (x,v) => x + (v -> bbn) }
+ inBBox = contents.foldLeft(inBBox) { (x, v) => x + (v -> bbn) }
)
parent match {
@@ -265,10 +305,11 @@ case class Graph(
/**
* A list of bbox parents, with the closest ancestor first
+ *
* @param bb a bbox
* @return
*/
- def bboxParentList(bb : BBName): List[BBName] =
+ def bboxParentList(bb: BBName): List[BBName] =
bboxParent.get(bb) match {
case Some(bb1) => bb1 :: bboxParentList(bb1)
case None => List()
@@ -276,7 +317,8 @@ case class Graph(
/**
* The set of all parents for a given bbox
- * @param bb
+ *
+ * @param bb Bounding Box
* @return
*/
def bboxParents(bb: BBName): Set[BBName] =
@@ -285,16 +327,15 @@ case class Graph(
case None => Set()
}
- def bboxChildren(bb: BBName) : Set[BBName] =
+ def bboxChildren(bb: BBName): Set[BBName] =
bboxParent.codf(bb)
-
def addToBBox(v: VName, bb: BBName): Graph = {
- copy(inBBox = inBBox + (v,bb))
+ copy(inBBox = inBBox + (v, bb))
}
def addToBBoxes(v: VName, bboxes: Set[BBName]): Graph = {
- copy(inBBox = bboxes.foldRight(inBBox) { (bb,mp) => mp + (v,bb) })
+ copy(inBBox = bboxes.foldRight(inBBox) { (bb, mp) => mp + (v, bb) })
}
/** Replace the contents of a bang box with new ones
@@ -306,11 +347,11 @@ case class Graph(
var inBB = inBBox
for (bb1 <- updateBB) {
- oldContents.foreach {v => inBB -= (v -> bb1) }
- newContents.foreach {v => inBB += (v -> bb1) }
+ oldContents.foreach { v => inBB -= (v -> bb1) }
+ newContents.foreach { v => inBB += (v -> bb1) }
}
- copy( inBBox = inBB )
+ copy(inBBox = inBB)
}
/** Change bbox parent. All contents will be removed from old parents and added to
@@ -334,17 +375,16 @@ case class Graph(
}
-
for (bbp <- oldParents) {
- cont.foreach {v => inBB -= (v -> bbp) }
+ cont.foreach { v => inBB -= (v -> bbp) }
}
for (bbp <- newParents) {
- cont.foreach {v => inBB += (v -> bbp) }
+ cont.foreach { v => inBB += (v -> bbp) }
}
- copy( inBBox = inBB , bboxParent = bbP )
+ copy(inBBox = inBB, bboxParent = bbP)
}
def newBBox(data: BBData, contents: Set[VName] = Set[VName](), parent: Option[BBName] = None): (Graph, BBName) = {
@@ -360,15 +400,7 @@ case class Graph(
)
}
- def deleteEdge(en: EName): Graph = {
- copy(
- edata = edata - en,
- source = source.unmapDom(en),
- target = target.unmapDom(en)
- )
- }
-
- def deleteEdges(es: Set[EName]): Graph = es.foldRight(this) { (e,g) => g.deleteEdge(e) }
+ def deleteEdges(es: Set[EName]): Graph = es.foldRight(this) { (e, g) => g.deleteEdge(e) }
def safeDeleteVertex(vn: VName): Graph = {
if (source.codf(vn).nonEmpty || target.codf(vn).nonEmpty)
@@ -378,54 +410,146 @@ case class Graph(
copy(vdata = vdata - vn, inBBox = inBBox.unmapDom(vn))
}
- def deleteVertex(vn: VName): Graph = {
+ def deleteVertices(vs: Set[VName]): Graph = vs.foldRight(this) { (v, g) => g.deleteVertex(v) }
+
+ def cutVertex(vertexName: VName): (Graph, Set[VName], Set[VName]) = cutVertex(vertexName, verts.filter(isBoundary))
+ /**
+ * Delete a vertex, but leave edges dangling if they were attached to another non-boundary node,
+ * removes boundaries adjacent to the cut vertex,
+ * dangling edges have a newly created boundary at one end.
+ *
+ * If trying to remove a boundary, consider just using deleteVertex instead.
+ *
+ * @param vertexName Vertex Name
+ * @param removingBoundaries If any of these are neighbours then remove whole-cloth
+ * @return (Cut graph, new boundaries, removed boundaries)
+ */
+ def cutVertex(vertexName: VName, removingBoundaries: Set[VName]): (Graph, Set[VName], Set[VName]) = {
var g = this
- for (e <- source.codf(vn)) g = g.deleteEdge(e)
- for (e <- target.codf(vn)) g = g.deleteEdge(e)
- g.copy(vdata = vdata - vn, inBBox = inBBox.unmapDom(vn))
- }
+ var newBoundaries: Set[VName] = Set()
+ var oldBoundaries: Set[VName] = Set()
- def deleteVertices(vs: Set[VName]): Graph = vs.foldRight(this) { (v, g) => g.deleteVertex(v) }
+ def midPoint(v1: VData, v2: VData): (Double, Double) = {
+ val c1 = v1.coord
+ val c2 = v2.coord
+ val x = (c1._1 + c2._1) / 2.0
+ val y = (c1._2 + c2._2) / 2.0
+ (x, y)
+ }
+
+ def breakEdge(g: Graph, e: EName): Graph = {
+ var g2 = g
+ val (ends, edges, wireNodes) = edgeEndPoints(e)
+ if (ends.size == 2) {
+ val joinNode = (ends - vertexName).head
+ if (g2.isTerminalWire(joinNode) && removingBoundaries.contains(joinNode)) {
+ oldBoundaries += joinNode
+ // Delete boundaries that would otherwise float after vertex removal
+ g2 = g2.deleteVertex(joinNode)
+ } else {
+ // Create a boundary where we cut the wire
+ val bName = g.verts.freshWithSuggestion(VName("c-" + vertexName.s + "-b"))
+ newBoundaries += bName
+ val oldSource : VName = g.source(e)
+ val direction = if (oldSource == joinNode) joinNode -> bName else bName -> joinNode
+ g2 = g2.addVertex(bName, WireV()).
+ addEdge(g.edges.freshWithSuggestion(e), g.edata(e), direction)
+ // Add a coordinate to our new boundary
+ val newCoordinate = midPoint(g.vdata(vertexName), g.vdata(joinNode))
+ g2 = g2.updateVData(bName) { vd => vd.withCoord(newCoordinate) }
+ }
+ g2 = g2.deleteEdges(edges)
+ g2 = g2.deleteVertices(wireNodes)
+ } else {
+ // the edge given was part of a loop
+ // should never end up here.
+ }
+ g2
+ }
+
+ g = source.codf(vertexName).foldLeft(g) { (g, e) => breakEdge(g, e) }
+ g = target.codf(vertexName).foldLeft(g) { (g, e) => breakEdge(g, e) }
+ g = g.deleteVertex(vertexName)
+ (g, newBoundaries, oldBoundaries)
+ }
// data updaters
def updateData(f: GData => GData): Graph = copy(data = f(data))
- def updateVData(vn: VName)(f: VData => VData): Graph = copy(vdata = vdata + (vn -> f(vdata(vn))))
+
def updateEData(en: EName)(f: EData => EData): Graph = copy(edata = edata + (en -> f(edata(en))))
+
def updateBBData(bbn: BBName)(f: BBData => BBData): Graph = copy(bbdata = bbdata + (bbn -> f(bbdata(bbn))))
- def rename(vrn: Map[VName,VName], ern: Map[EName, EName], brn: Map[BBName,BBName]): Graph = {
+ def rename(vrn: Map[VName, VName] = Map(), ern: Map[EName, EName] = Map(), brn: Map[BBName, BBName] = Map()): Graph = {
// compute inverses
-// val vrni = vrn.foldLeft(Map[VName,VName]()) { case (mp, (k,v)) => mp + (v -> k) }
-// val erni = ern.foldLeft(Map[EName,EName]()) { case (mp, (k,v)) => mp + (v -> k) }
-// val brni = brn.foldLeft(Map[BBName,BBName]()) { case (mp, (k,v)) => mp + (v -> k) }
-
- val vdata1 = vdata.foldLeft(Map[VName,VData]()) { case (mp, (k,v)) => mp + (vrn(k) -> v)}
- val edata1 = edata.foldLeft(Map[EName,EData]()) { case (mp, (k,v)) => mp + (ern(k) -> v)}
- val bbdata1 = bbdata.foldLeft(Map[BBName,BBData]()) { case (mp, (k,v)) => mp + (brn(k) -> v)}
- val source1 = source.foldLeft(PFun[EName,VName]()) { case (mp, (k,v)) => mp + (ern(k) -> vrn(v))}
- val target1 = target.foldLeft(PFun[EName,VName]()) { case (mp, (k,v)) => mp + (ern(k) -> vrn(v))}
- val inBBox1 = inBBox.foldLeft(BinRel[VName,BBName]()) { case (mp, (k,v)) => mp + (vrn(k) -> brn(v))}
- val bboxParent1 = bboxParent.foldLeft(PFun[BBName,BBName]()) { case (mp, (k,v)) => mp + (brn(k) -> brn(v))}
-
- copy(vdata=vdata1,edata=edata1,source=source1,target=target1,
- bbdata=bbdata1,inBBox=inBBox1,bboxParent=bboxParent1)
+ // val vrni = vrn.foldLeft(Map[VName,VName]()) { case (mp, (k,v)) => mp + (v -> k) }
+ // val erni = ern.foldLeft(Map[EName,EName]()) { case (mp, (k,v)) => mp + (v -> k) }
+ // val brni = brn.foldLeft(Map[BBName,BBName]()) { case (mp, (k,v)) => mp + (v -> k) }
+
+ val vdata1 = vdata.foldLeft(Map[VName, VData]()) { case (mp, (k, v)) => mp + (vrn.getOrElse(k, k) -> v) }
+ val edata1 = edata.foldLeft(Map[EName, EData]()) { case (mp, (k, v)) => mp + (ern.getOrElse(k, k) -> v) }
+ val bbdata1 = bbdata.foldLeft(Map[BBName, BBData]()) { case (mp, (k, v)) => mp + (brn.getOrElse(k, k) -> v) }
+ val source1 = source.foldLeft(PFun[EName, VName]()) { case (mp, (k, v)) => mp + (ern.getOrElse(k, k) -> vrn.getOrElse(v, v)) }
+ val target1 = target.foldLeft(PFun[EName, VName]()) { case (mp, (k, v)) => mp + (ern.getOrElse(k, k) -> vrn.getOrElse(v, v)) }
+ val inBBox1 = inBBox.foldLeft(BinRel[VName, BBName]()) { case (mp, (k, v)) => mp + (vrn.getOrElse(k, k) -> brn.getOrElse(v, v)) }
+ val bboxParent1 = bboxParent.foldLeft(PFun[BBName, BBName]()) { case (mp, (k, v)) => mp + (brn.getOrElse(k, k) -> brn.getOrElse(v, v)) }
+
+ copy(vdata = vdata1, edata = edata1, source = source1, target = target1,
+ bbdata = bbdata1, inBBox = inBBox1, bboxParent = bboxParent1)
}
// get a subgraph consisting of the given vertices and bboxes, with any edges/nesting between them
def fullSubgraph(vs: Set[VName], bbs: Set[BBName]): Graph = {
val es = edges.filter { e => vs.contains(source(e)) && vs.contains(target(e)) }
- val vdata1 = vdata.filter { case (v,_) => vs.contains(v) }
- val edata1 = edata.filter { case (e,_) => es.contains(e) }
- val source1 = source.filter { case (e,_) => es.contains(e) }
- val target1 = target.filter { case (e,_) => es.contains(e) }
- val inBBox1 = inBBox.filter{ case (v,b) => vs.contains(v) && bbs.contains(b) }
- val bbdata1 = bbdata.filter { case (b,_) => bbs.contains(b) }
- val bboxParent1 = bboxParent.filter { case (b1,b2) => bbs.contains(b1) && bbs.contains(b2) }
+ val vdata1 = vdata.filter { case (v, _) => vs.contains(v) }
+ val edata1 = edata.filter { case (e, _) => es.contains(e) }
+ val source1 = source.filter { case (e, _) => es.contains(e) }
+ val target1 = target.filter { case (e, _) => es.contains(e) }
+ val inBBox1 = inBBox.filter { case (v, b) => vs.contains(v) && bbs.contains(b) }
+ val bbdata1 = bbdata.filter { case (b, _) => bbs.contains(b) }
+ val bboxParent1 = bboxParent.filter { case (b1, b2) => bbs.contains(b1) && bbs.contains(b2) }
- copy(data=GData(),vdata=vdata1,edata=edata1,source=source1,target=target1,
- bbdata=bbdata1,inBBox=inBBox1,bboxParent=bboxParent1)
+ copy(data = GData(), vdata = vdata1, edata = edata1, source = source1, target = target1,
+ bbdata = bbdata1, inBBox = inBBox1, bboxParent = bboxParent1)
+ }
+
+ def renameAvoiding(g: Graph): Graph = makeRenaming(g.verts, g.edges, g.bboxes).image(this)
+
+ // form a new graph by merging the given (non-empty) set of vertices into a new vertex with the given name
+ def mergeVertices(vs: Set[VName], newV: VName): Graph = {
+ val rep = vs.head
+ val bboxes = inBBox.directImage(vs)
+
+ val source1 = source.inverseImage(vs).foldRight(source) { (e, mp) => mp + (e -> newV) }
+ val target1 = target.inverseImage(vs).foldRight(target) { (e, mp) => mp + (e -> newV) }
+
+ val g1 = if (verts.contains(newV)) this else addVertex(newV, vdata(rep))
+
+ g1.copy(source = source1, target = target1)
+ .deleteVertices(vs - newV)
+ .addToBBoxes(newV, bboxes)
+ }
+
+ // add g to this graph, plugging 'b' in this graph into 'bg' in g.
+ def plugGraph(g: Graph, b: VName, bg: VName): Graph = {
+ // freshen target graph w.r.t. source
+ val mp = makeRenaming(verts, edges, bboxes)
+ val g1 = mp.image(g)
+ val bg1 = mp.v(bg)
+
+ // re-position g relative to this graph, using the boundaries as a guide
+ val bgcoord = g1.vdata(bg1).coord
+ val bcoord = vdata(b).coord
+ val dx = bcoord._1 - bgcoord._1
+ val dy = bcoord._2 - bgcoord._2
+
+ val g2 = g1.verts.foldLeft(g1) { (g1, v) =>
+ g1.updateVData(v) { d => d.withCoord(d.coord._1 + dx, d.coord._2 + dy) }
+ }
+
+ this.appendGraph(g2).plugBoundaries(b, bg1)
}
def makeRenaming(avoidVerts: Set[VName],
@@ -458,22 +582,26 @@ case class Graph(
mp
}
- def renameAvoiding(g: Graph): Graph = makeRenaming(g.verts, g.edges, g.bboxes).image(this)
+ def edges: Set[EName] = edata.keySet
+
+ def bboxes: Set[BBName] = bbdata.keySet
// append the given graph. note that its names should already be fresh
- def appendGraph(g: Graph): Graph = {
+ def appendGraph(g: Graph, noOverlap : Boolean = true): Graph = {
val coords = verts.map(vdata(_).coord)
// Pick any vertex in g and offset until that vertex is not sitting exactly
// on top of another.
var offset = 0.0
- g.verts.headOption.foreach { v1 =>
- val (x,y) = g.vdata(v1).coord
- while (coords.contains((x + offset, y))) offset += 1.0
+ if(noOverlap) {
+ g.verts.headOption.foreach { v1 =>
+ val (x, y) = g.vdata(v1).coord
+ while (coords.contains((x + offset, y))) offset += 1.0
+ }
}
- val g1 = g.verts.foldLeft(g) { (g1,v) =>
- g1.updateVData(v) { d => d.withCoord (d.coord._1 + offset, d.coord._2) }
+ val g1 = g.verts.foldLeft(g) { (g1, v) =>
+ g1.updateVData(v) { d => d.withCoord(d.coord._1 + offset, d.coord._2) }
}
copy(
@@ -487,65 +615,55 @@ case class Graph(
)
}
+ def updateVData(vn: VName)(f: VData => VData): Graph = copy(vdata = vdata + (vn -> f(vdata(vn))))
+
+ def updateAllVData(f: VData => VData): Graph = {
+ val graph = this
+ graph.verts.foldLeft(graph) { (g, v) =>
+ g.updateVData(v)(f)
+ }
+ }
+
def plugBoundaries(b1: VName, b2: VName): Graph = {
// pull a boundary edge, which we'll inherit the data from
val be = inEdges(b2).headOption.getOrElse(
- outEdges(b2).headOption.getOrElse (
+ outEdges(b2).headOption.getOrElse(
throw new PluggingException("Target boundary is an isolated point.")))
val beData = edata(be)
// figure out who should be the source and target
- val (s,t) = (
+ val (s, t) = (
predVerts(b1).headOption, succVerts(b1).headOption,
predVerts(b2).headOption, succVerts(b2).headOption) match {
- case (None, Some(t1), Some(s1), None) => (s1,t1)
- case (Some(s1), None, None, Some(t1)) => (s1,t1)
- case (Some(s1), None, Some(t1), None) if !beData.isDirected => (s1,t1)
- case (None, Some(s1), None, Some(t1)) if !beData.isDirected => (s1,t1)
+ case (None, Some(t1), Some(s1), None) => (s1, t1)
+ case (Some(s1), None, None, Some(t1)) => (s1, t1)
+ case (Some(s1), None, Some(t1), None) if !beData.isDirected => (s1, t1)
+ case (None, Some(s1), None, Some(t1)) if !beData.isDirected => (s1, t1)
case _ => throw new PluggingException("Bad boundary arity")
}
- this.deleteVertex(b1).deleteVertex(b2).addEdge(be, beData, (s,t))
+ this.deleteVertex(b1).deleteVertex(b2).addEdge(be, beData, (s, t))
}
- // form a new graph by merging the given (non-empty) set of vertices into a new vertex with the given name
- def mergeVertices(vs: Set[VName], newV: VName): Graph = {
- val rep = vs.head
- val bboxes = inBBox.directImage(vs)
-
- val source1 = source.inverseImage(vs).foldRight(source) { (e, mp) => mp + (e -> newV) }
- val target1 = target.inverseImage(vs).foldRight(target) { (e, mp) => mp + (e -> newV) }
-
- val g1 = if (verts.contains(newV)) this else addVertex(newV, vdata(rep))
+ def deleteVertex(vn: VName): Graph = {
+ var g = this
+ for (e <- source.codf(vn)) g = g.deleteEdge(e)
+ for (e <- target.codf(vn)) g = g.deleteEdge(e)
- g1.copy(source = source1, target = target1)
- .deleteVertices(vs - newV)
- .addToBBoxes(newV, bboxes)
+ g.copy(vdata = vdata - vn, inBBox = inBBox.unmapDom(vn))
}
- // add g to this graph, plugging 'b' in this graph into 'bg' in g.
- def plugGraph(g: Graph, b: VName, bg: VName): Graph = {
- // freshen target graph w.r.t. source
- val mp = makeRenaming(verts, edges, bboxes)
- val g1 = mp.image(g)
- val bg1 = mp.v(bg)
-
- // re-position g relative to this graph, using the boundaries as a guide
- val bgcoord = g1.vdata(bg1).coord
- val bcoord = vdata(b).coord
- val dx = bcoord._1 - bgcoord._1
- val dy = bcoord._2 - bgcoord._2
-
- val g2 = g1.verts.foldLeft(g1) { (g1,v) =>
- g1.updateVData(v) { d => d.withCoord (d.coord._1 + dx, d.coord._2 + dy) }
- }
-
- this.appendGraph(g2).plugBoundaries(b, bg1)
+ def deleteEdge(en: EName): Graph = {
+ copy(
+ edata = edata - en,
+ source = source.unmapDom(en),
+ target = target.unmapDom(en)
+ )
}
def compare(that: Graph): Int = {
- val x : Graph = this
- val y : Graph = that
+ val x: Graph = this
+ val y: Graph = that
if (x.verts.size > y.verts.size) {
1
} else {
@@ -572,78 +690,28 @@ case class Graph(
| bboxes: %s,
| nesting: %s
|}""".stripMargin.format(
- data, vdata,
- edata.map(kv => kv._1 -> "(%s => %s)::%s".format(source(kv._1), target(kv._1), kv._2)),
- bbdata.map(kv => kv._1 -> "%s::%s".format(inBBox.codf(kv._1), kv._2)),
- bboxParent.toString
- )
- }
-
- private def dftSuccessors[T](fromV: VName, exploredV: Set[VName], exploredE: Set[EName])(base: T)
- (f: (T, EName, GraphSearchContext) => T): (T, Set[VName], Set[EName]) =
- {
- val nextEs = outEdges(fromV).filter(!exploredE.contains(_))
-
- if (nextEs.nonEmpty) {
- val e = nextEs.min
- val nextV = target(e)
-
- val (base1, exploredV1, exploredE1) =
- dftSuccessors(nextV, exploredV + nextV, exploredE + e)(base)(f)
- val (base2, exploredV2, exploredE2) =
- dftSuccessors(fromV, exploredV1, exploredE1)(base1)(f)
-
- val context = GraphSearchContext(exploredV2, exploredE2)
- (f(base2, e, context), exploredV2, exploredE2)
- } else {
- (base, exploredV, exploredE)
- }
- }
-
- private def dftComponents[T](exploredV: Set[VName], exploredE: Set[EName])(base: T)
- (f: (T, EName, GraphSearchContext) => T) : T =
- {
- val nextVs = vdata.keySet.filter(!exploredV.contains(_))
- val initialVs = nextVs.filter(inEdges(_).isEmpty)
-
- // Try to start with the minimal unexplored vertex with no in-edges. Failing that, start with the
- // minimal unexplored vertex.
- val vOpt = if (initialVs.nonEmpty) Some(initialVs.min)
- else if (nextVs.nonEmpty) Some(nextVs.min)
- else None
-
- vOpt match {
- case Some(v) =>
- val (base1, exploredV1, exploredE1) = dftSuccessors(v, exploredV + v, exploredE)(base)(f)
- dftComponents[T](exploredV1, exploredE1)(base1)(f)
- case None => base
- }
+ data, vdata,
+ edata.map(kv => kv._1 -> "(%s => %s)::%s".format(source(kv._1), target(kv._1), kv._2)),
+ bbdata.map(kv => kv._1 -> "%s::%s".format(inBBox.codf(kv._1), kv._2)),
+ bboxParent.toString
+ )
}
def dft[T](base: T)(f: (T, EName, GraphSearchContext) => T): T =
dftComponents(Set[VName](), Set[EName]())(base)(f)
-
- private def bbDft(bb : BBName, bbSeq : collection.mutable.Buffer[BBName], bbs : collection.mutable.Set[BBName]) {
- for (ch <- bboxParent.codf(bb)) bbDft(ch, bbSeq, bbs)
- if (bbs.contains(bb)) {
- bbs.remove(bb)
- bbSeq += bb
- }
- }
-
def bboxesChildrenFirst: Seq[BBName] = {
val bbSeq = collection.mutable.Buffer[BBName]()
- val bbs = collection.mutable.Set[BBName](bboxes.toSeq : _*)
- while (bbs.nonEmpty) bbDft(bbs.iterator.next(),bbSeq,bbs)
+ val bbs = collection.mutable.Set[BBName](bboxes.toSeq: _*)
+ while (bbs.nonEmpty) bbDft(bbs.iterator.next(), bbSeq, bbs)
- bbSeq.toSeq
+ bbSeq
}
// returns a topo ordering. If graph is a dag, all edges will be consistent with this ordering
def topologicalOrdering: PartialOrdering[VName] = {
val visited = collection.mutable.Set[VName]()
- var ordMap = Map[VName,Int]()
+ var ordMap = Map[VName, Int]()
var max = 0
def visit(v: VName) {
@@ -663,9 +731,9 @@ case class Graph(
case _ => None
}
- def lteq(x: VName, y: VName): Boolean = tryCompare(x,y) match {
- case Some(c) => c != 1
- case None => false
+ def lteq(x: VName, y: VName): Boolean = tryCompare(x, y) match {
+ case Some(c) => c != 1
+ case None => false
}
}
}
@@ -673,14 +741,14 @@ case class Graph(
def dagCopy: Graph = {
// make a copy with no edges
val noEdges = copy(
- edata = Map[EName,EData](),
- source = PFun[EName,VName](),
- target = PFun[EName,VName]()
+ edata = Map[EName, EData](),
+ source = PFun[EName, VName](),
+ target = PFun[EName, VName]()
)
val ord = this.topologicalOrdering
- dft(noEdges) { (graph, e, context) =>
+ dft(noEdges) { (graph, e, _) =>
val s = source(e)
val t = target(e)
@@ -688,20 +756,20 @@ case class Graph(
else {
// reverse back-edges to break cycles
graph.addEdge(e, edata(e),
- if (ord.lteq(s,t)) (s,t) else (t,s))
+ if (ord.lteq(s, t)) (s, t) else (t, s))
}
}
}
def simpleCopy: Graph = {
var g = copy(
- edata = Map[EName,EData](),
- source = PFun[EName,VName](),
- target = PFun[EName,VName]()
+ edata = Map[EName, EData](),
+ source = PFun[EName, VName](),
+ target = PFun[EName, VName]()
)
for (v <- verts; w <- verts) {
outEdges(v).find(target(_) == w) match {
- case Some(e) => g = g.addEdge(e, edata(e), (v,w))
+ case Some(e) => g = g.addEdge(e, edata(e), (v, w))
case None => ()
}
}
@@ -709,6 +777,35 @@ case class Graph(
g
}
+ def verts: Set[VName] = vdata.keySet
+
+ def copy(data: GData = this.data,
+ vdata: Map[VName, VData] = this.vdata,
+ edata: Map[EName, EData] = this.edata,
+ source: PFun[EName, VName] = this.source,
+ target: PFun[EName, VName] = this.target,
+ bbdata: Map[BBName, BBData] = this.bbdata,
+ inBBox: BinRel[VName, BBName] = this.inBBox,
+ bboxParent: PFun[BBName, BBName] = this.bboxParent): Graph =
+ factory(data, vdata, edata, source, target, bbdata, inBBox, bboxParent)
+
+ def outEdges(vn: VName): Set[EName] = source.codf(vn)
+
+ def addEdge(en: EName, data: EData, vns: (VName, VName)): Graph = {
+ if (edata contains en)
+ throw new DuplicateEdgeNameException(en)
+ if (!vdata.contains(vns._1))
+ throw new GraphException("Edge: " + en + " has no endpoint: " + vns._1 + " in graph")
+ if (!vdata.contains(vns._1))
+ throw new GraphException("Edge: " + en + " has no endpoint: " + vns._2 + " in graph")
+
+ copy(
+ edata = edata + (en -> data),
+ source = source + (en -> vns._1),
+ target = target + (en -> vns._2)
+ )
+ }
+
def expandWire(w: VName): (Graph, (VName, VName, EName)) = {
val ed = adjacentEdges(w).headOption match {
case Some(e) => edata(e)
@@ -725,7 +822,7 @@ case class Graph(
outEdges(w).headOption match {
case None => // 'w' is an output, so it should stay an output
g = g.addEdge(newE, ed, newW -> w)
- inEdges(w).headOption.foreach{e =>
+ inEdges(w).headOption.foreach { e =>
g = g.deleteEdge(e).addEdge(e, ed, source(e) -> newW)
}
@@ -770,8 +867,8 @@ case class Graph(
// direct edge in the same direction as the minimum of the two edges connected to w
val endPoints =
- if (target(e1) == w) (source(e1), edgeGetOtherVertex(e2,w))
- else (edgeGetOtherVertex(e2,w), target(e1))
+ if (target(e1) == w) (source(e1), edgeGetOtherVertex(e2, w))
+ else (edgeGetOtherVertex(e2, w), target(e1))
this
.deleteVertex(w)
@@ -780,26 +877,54 @@ case class Graph(
}
/**
- * Put graph in normal form, where each (non-bare) wire has exactly 1 wire vertex
- * @return
+ * Change all wires to boundaries or boundaries to wires depending on how many edges coincide with it
+ *
+ * @return Graph
+ */
+ def coerceWiresAndBoundaries: Graph = {
+ val graph = this
+ graph.verts.foldLeft(graph) { (g, v) =>
+ g.updateVData(v) { d =>
+ if (d.isWireVertex) {
+ d.asInstanceOf[WireV].makeBoundary(
+ graph.isTerminalWire(v)
+ )
+ } else d
+ }
+ }
+ }
+
+ def isTerminalWire(vn: VName): Boolean =
+ vdata(vn).isWireVertex && (inEdges(vn).size + outEdges(vn).size) <= 1
+
+ /**
+ * Put graph in normal form, where each wire has exactly 1 wire vertex
+ *
+ * @return Graph
*/
def normalise: Graph = {
var ch = false
var g = this
for (e <- edges) {
- val s = source(e)
- val t = target(e)
- (vdata(s), vdata(t)) match {
- case (_: NodeV, _: NodeV) =>
- g = g.edgeToWire(e)
- ch = true
- case (_: WireV, _: WireV) if s != t =>
- if (!isBoundary(s) || !isBoundary(t)) {
- g = g.collapseWire(e)
+ if (!ch) {
+ val s = source(e)
+ val t = target(e)
+ (vdata(s), vdata(t)) match {
+ case (_: NodeV, _: NodeV) =>
+ g = g.edgeToWire(e)
ch = true
- }
- case _ => // do nothing
+ case (_: WireV, _: WireV) if s != t =>
+
+ /**
+ * Collapse if between two internal wires, unless going in or out of a bbox
+ */
+ if (!isTerminalWire(s) && !isTerminalWire(t)) {
+ g = g.collapseWire(e)
+ ch = true
+ }
+ case _ => // do nothing
+ }
}
}
@@ -812,10 +937,8 @@ case class Graph(
}
}
- g
-
-// if (ch) g.normalise
-// else g
+ if (ch) g.normalise
+ else g
}
def minimise: Graph = {
@@ -828,6 +951,7 @@ case class Graph(
/**
* make a copy of the given bbox's contents, without copying the bbox itself
+ *
* @param bb the bbox to be expanded
* @return the new graph and a record containing relevant data for replaying the expansion
*/
@@ -839,8 +963,8 @@ case class Graph(
var gfr = mp1.image(g)
// add each expanded vertex to the bboxes that it was already in
- g.verts.foreach{v =>
- ((bboxesContaining(v) -- bboxChildren(bb)) - bb).foreach{ bb1 => gfr = gfr.addToBBox(mp1.v(v), bb1) }
+ g.verts.foreach { v =>
+ ((bboxesContaining(v) -- bboxChildren(bb)) - bb).foreach { bb1 => gfr = gfr.addToBBox(mp1.v(v), bb1) }
}
var g1 = appendGraph(gfr)
@@ -853,21 +977,22 @@ case class Graph(
val e1 = freshE.freshWithSuggestion(e)
freshE = freshE + e1
mp1 = mp1.addEdge(e -> e1)
- g1 = g1.addEdge(e1, edata(e), mp1.v.getOrElse(s,s) -> mp1.v.getOrElse(t,t))
+ g1 = g1.addEdge(e1, edata(e), mp1.v.getOrElse(s, s) -> mp1.v.getOrElse(t, t))
}
- (g1, BBExpand(bb,mp1))
+ (g1, BBExpand(bb, mp1))
}
/**
* make a copy of the given bbox
+ *
* @param bb the bbox to be copied
* @return the new graph and a record containing relevant data for replaying the copy
*/
def copyBBox(bb: BBName, avoidV: Set[VName] = Set(), mp: GraphMap = GraphMap()): (Graph, BBCopy) = {
var (g1, bbe) = expandBBox(bb, avoidV, mp)
val mp1 = if (bbe.mp.bb.domSet contains bb) bbe.mp
- else bbe.mp.copy(bb = bbe.mp.bb + (bb -> g1.bboxes.freshWithSuggestion(bb)))
+ else bbe.mp.copy(bb = bbe.mp.bb + (bb -> g1.bboxes.freshWithSuggestion(bb)))
val bb1 = mp1.bb(bb)
g1 = g1.addBBox(bb1, bbdata(bb), mp1.v.codSet)
@@ -878,6 +1003,7 @@ case class Graph(
/**
* drop the given bbox, keeping the contents intact
+ *
* @param bb the bbox to be dropped
* @return the new graph and a record containing relevant data for replaying the drop
*/
@@ -887,6 +1013,7 @@ case class Graph(
/**
* kill the given bbox, also deleting child nodes and bboxes
+ *
* @param bb the bbox to be dropped
* @return the new graph and a record containing relevant data for replaying the drop
*/
@@ -898,29 +1025,26 @@ case class Graph(
(g1, BBKill(bb))
}
- private def freshMap(mp: PFun[VName, VName], avoid: Set[VName]) = {
- mp
- }
-
/**
* Apply the given bbox operation
- * @param bbop a !-box operation
+ *
+ * @param bbop a !-box operation
* @param avoidV an optional set of (extra) vertices to avoid when copying !-box contents
* @return
*/
def applyBBOp(bbop: BBOp, avoidV: Set[VName] = Set()): Graph = bbop match {
case BBExpand(bb, mp) =>
- val mp1 = GraphMap(v = mp.v.filterKeys(v => verts.contains(v) && isBoundary(v)), bb = mp.bb)
+ val mp1 = GraphMap(v = mp.v.filterKeys(v => verts.contains(v) && isTerminalWire(v)), bb = mp.bb)
expandBBox(bb, avoidV, mp1)._1
case BBCopy(bb, mp) =>
- val mp1 = GraphMap(v = mp.v.filterKeys(v => verts.contains(v) && isBoundary(v)), bb = mp.bb)
+ val mp1 = GraphMap(v = mp.v.filterKeys(v => verts.contains(v) && isTerminalWire(v)), bb = mp.bb)
copyBBox(bb, avoidV, mp1)._1
case BBDrop(bb) => dropBBox(bb)._1
case BBKill(bb) => killBBox(bb)._1
}
- def freeVars: Set[String] = vdata.foldRight(Set[String]()) {
- case ((_,d: NodeV), s) => s union d.angle.vars
+ def freeVars: Set[(ValueType, String)] = vdata.foldRight(Set[(ValueType, String)]()) {
+ case ((_, d: NodeV), s) => s union d.phaseData.varsWithType
case (_, s) => s
}
@@ -930,58 +1054,117 @@ case class Graph(
contents(bb).forall(v => bboxesContaining(v).size > 1)
}
- def toJson(theory: Theory) : Json = Graph.toJson(this, theory)
-}
+ def contents(bbn: BBName): Set[VName] = inBBox.codf(bbn)
-object Graph {
-// val Flavor = new DataFlavor(Graph.getClass, "X-quantoderive/qgraph; class=;")
-// class GraphPacket(graph: Graph, val theory: Theory) extends Transferable {
-// def getTransferData(f: DataFlavor) = this
-// def isDataFlavorSupported(f: DataFlavor) = { f == Graph.Flavor }
-// def getTransferDataFlavors = Array(Graph.Flavor)
-// }
+ def bboxesContaining(vn: VName): Set[BBName] = inBBox.domf(vn)
- implicit def qGraphAndNameToQGraph[N <: Name[N]](t: (Graph, Name[N])) : Graph = t._1
+ def toJson(theory: Theory = Theory.DefaultTheory): Json = Graph.toJson(this, theory)
- def apply(theory: Theory): Graph = Graph(data = GData(theory = theory))
+ private def dftSuccessors[T](fromV: VName, exploredV: Set[VName], exploredE: Set[EName])(base: T)
+ (f: (T, EName, GraphSearchContext) => T): (T, Set[VName], Set[EName]) = {
+ val nextEs = outEdges(fromV).filter(!exploredE.contains(_))
+ if (nextEs.nonEmpty) {
+ val e = nextEs.min
+ val nextV = target(e)
- def fromJson(s: String, thy: Theory): Graph =
- try { fromJson(Json.parse(s), thy) }
- catch { case e:JsonParseException => throw new GraphLoadException("Error parsing JSON", e) }
+ val (base1, exploredV1, exploredE1) =
+ dftSuccessors(nextV, exploredV + nextV, exploredE + e)(base)(f)
+ val (base2, exploredV2, exploredE2) =
+ dftSuccessors(fromV, exploredV1, exploredE1)(base1)(f)
+
+ val context = GraphSearchContext(exploredV2, exploredE2)
+ (f(base2, e, context), exploredV2, exploredE2)
+ } else {
+ (base, exploredV, exploredE)
+ }
+ }
+
+ private def dftComponents[T](exploredV: Set[VName], exploredE: Set[EName])(base: T)
+ (f: (T, EName, GraphSearchContext) => T): T = {
+ val nextVs = vdata.keySet.filter(!exploredV.contains(_))
+ val initialVs = nextVs.filter(inEdges(_).isEmpty)
+ // Try to start with the minimal unexplored vertex with no in-edges. Failing that, start with the
+ // minimal unexplored vertex.
+ val vOpt = if (initialVs.nonEmpty) Some(initialVs.min)
+ else if (nextVs.nonEmpty) Some(nextVs.min)
+ else None
+
+ vOpt match {
+ case Some(v) =>
+ val (base1, exploredV1, exploredE1) = dftSuccessors(v, exploredV + v, exploredE)(base)(f)
+ dftComponents[T](exploredV1, exploredE1)(base1)(f)
+ case None => base
+ }
+ }
+
+ private def bbDft(bb: BBName, bbSeq: collection.mutable.Buffer[BBName], bbs: collection.mutable.Set[BBName]) {
+ for (ch <- bboxParent.codf(bb)) bbDft(ch, bbSeq, bbs)
+ if (bbs.contains(bb)) {
+ bbs.remove(bb)
+ bbSeq += bb
+ }
+ }
+
+ private def freshMap(mp: PFun[VName, VName], avoid: Set[VName]) = {
+ mp
+ }
+}
+
+object Graph {
+ // val Flavor = new DataFlavor(Graph.getClass, "X-quantoderive/qgraph; class=;")
+ // class GraphPacket(graph: Graph, val theory: Theory) extends Transferable {
+ // def getTransferData(f: DataFlavor) = this
+ // def isDataFlavorSupported(f: DataFlavor) = { f == Graph.Flavor }
+ // def getTransferDataFlavors = Array(Graph.Flavor)
+ // }
+
+ implicit def qGraphAndNameToQGraph[N <: Name[N]](t: (Graph, Name[N])): Graph = t._1
def fromJson(s: String): Graph = fromJson(s, Theory.DefaultTheory)
+ def fromJson(s: String, thy: Theory): Graph =
+ try {
+ fromJson(Json.parse(s), thy)
+ }
+ catch {
+ case e: JsonParseException => throw new GraphLoadException("Error parsing JSON", e)
+ }
+
def fromJson(json: Json, thy: Theory = Theory.DefaultTheory): Graph = try {
Function.chain[Graph](Seq(
- (json ? "wire_vertices").asObject.foldLeft(_) { (g,v) =>
+ (json ? "wire_vertices").asObject.foldLeft(_) { (g, v) =>
g.addVertex(v._1, WireV.fromJson(v._2, thy))
},
- (json ? "node_vertices").asObject.foldLeft(_) { (g,v) =>
+ (json ? "node_vertices").asObject.foldLeft(_) { (g, v) =>
g.addVertex(v._1, NodeV.fromJson(v._2, thy))
},
- (json ? "dir_edges").asObject.foldLeft(_) { (g,e) =>
+ (json ? "dir_edges").asObject.foldLeft(_) { (g, e) =>
val data = e._2.getOrElse("data", thy.defaultEdgeData).asObject
val annotation = (e._2 ? "annotation").asObject
g.addEdge(e._1, DirEdge(data, annotation, thy),
((e._2 / "src").stringValue, (e._2 / "tgt").stringValue))
},
- (json ? "undir_edges").asObject.foldLeft(_) { (g,e) =>
+ (json ? "undir_edges").asObject.foldLeft(_) { (g, e) =>
val data = e._2.getOrElse("data", thy.defaultEdgeData).asObject
val annotation = (e._2 ? "annotation").asObject
g.addEdge(e._1, UndirEdge(data, annotation, thy), ((e._2 / "src").stringValue, (e._2 / "tgt").stringValue))
},
- (json ? "bang_boxes").asObject.foldLeft(_) { (g,bb) =>
+ (json ? "bang_boxes").asObject.foldLeft(_) { (g, bb) =>
val data = (bb._2 ? "data").asObject
val annotation = (bb._2 ? "annotation").asObject
- val contains = (bb._2 ? "contents").vectorValue map { VName(_) }
- val parent = bb._2.get("parent") map { BBName(_) }
+ val contains = (bb._2 ? "contents").vectorValue map {
+ VName(_)
+ }
+ val parent = bb._2.get("parent") map {
+ BBName(_)
+ }
g.addBBox(bb._1, BBData(data, annotation), contains.toSet, parent)
}
@@ -999,14 +1182,12 @@ object Graph {
}
def toJson(graph: Graph, thy: Theory = Theory.DefaultTheory): Json = {
- val (wireVertices, nodeVertices) = graph.vdata.foldLeft((JsonObject(), JsonObject()))
- {
- case ((objW,objN), (v,w: WireV)) => (objW + (v.toString -> w.toJson), objN)
- case ((objW,objN), (v,n: NodeV)) => (objW, objN + (v.toString -> n.toJson))
+ val (wireVertices, nodeVertices) = graph.vdata.foldLeft((JsonObject(), JsonObject())) {
+ case ((objW, objN), (v, w: WireV)) => (objW + (v.toString -> w.toJson), objN)
+ case ((objW, objN), (v, n: NodeV)) => (objW, objN + (v.toString -> n.toJson))
}
- val (dirEdges, undirEdges) = graph.edata.foldLeft((JsonObject(), JsonObject()))
- { case ((objD,objU), (e,d)) =>
+ val (dirEdges, undirEdges) = graph.edata.foldLeft((JsonObject(), JsonObject())) { case ((objD, objU), (e, d)) =>
val entry = e.toString -> (d.toJson + ("src" -> graph.source(e).toString, "tgt" -> graph.target(e).toString))
if (d.isDirected) (objD + entry, objU) else (objD, objU + entry)
}
@@ -1014,11 +1195,12 @@ object Graph {
val bangBoxes = graph.bbdata.foldLeft(JsonObject()) { case (obj, (bb, d)) =>
obj + (bb.toString ->
JsonObject(
- "contents" -> JsonArray(graph.contents(bb)),
- "parent" -> (graph.bboxParent.get(bb) match {
+ "contents" -> JsonArray(graph.contents(bb)),
+ "parent" -> (graph.bboxParent.get(bb) match {
case Some(p) => JsonString(p.toString)
- case None => JsonNull }),
- "data" -> d.data,
+ case None => JsonNull
+ }),
+ "data" -> d.data,
"annotation" -> d.annotation
).noEmpty)
}
@@ -1026,18 +1208,18 @@ object Graph {
JsonObject(
"wire_vertices" -> wireVertices.asObjectOrKeyArray,
"node_vertices" -> nodeVertices.asObjectOrKeyArray,
- "dir_edges" -> dirEdges,
- "undir_edges" -> undirEdges,
- "bang_boxes" -> bangBoxes,
- "data" -> graph.data.data,
- "annotation" -> graph.data.annotation
+ "dir_edges" -> dirEdges,
+ "undir_edges" -> undirEdges,
+ "bang_boxes" -> bangBoxes,
+ "data" -> graph.data.data,
+ "annotation" -> graph.data.annotation
).noEmpty
}
def random(nverts: Int, nedges: Int, nbboxes: Int = 0): Graph = {
val rand = new util.Random
var randomGraph = Graph()
- for (i <- 1 to nverts) {
+ for (_ <- 1 to nverts) {
val p = (rand.nextDouble * 6.0 - 3.0, rand.nextDouble * 6.0 - 3.0)
if (rand.nextBoolean()) randomGraph = randomGraph.newVertex(NodeV(p))
else randomGraph = randomGraph.newVertex(WireV(p))
@@ -1046,21 +1228,21 @@ object Graph {
if (nverts != 0) {
val sources = new ArrayBuffer[VName](randomGraph.vdata.keys.size)
val targets = new ArrayBuffer[VName](randomGraph.vdata.keys.size)
- randomGraph.vdata.keys.foreach{k => sources += k; targets += k}
- for(j <- 1 to nedges if sources.nonEmpty && targets.nonEmpty) {
- val (si,ti) = (rand.nextInt(sources.size), rand.nextInt(targets.size))
+ randomGraph.vdata.keys.foreach { k => sources += k; targets += k }
+ for (_ <- 1 to nedges if sources.nonEmpty && targets.nonEmpty) {
+ val (si, ti) = (rand.nextInt(sources.size), rand.nextInt(targets.size))
val s = sources(si)
val t = targets(ti)
if (randomGraph.vdata(s).isWireVertex) sources -= s
if (randomGraph.vdata(t).isWireVertex) targets -= t
- randomGraph = randomGraph.newEdge(DirEdge(), (s,t))
+ randomGraph = randomGraph.newEdge(DirEdge(), (s, t))
}
val varray = randomGraph.vdata.keys.toArray
- for (i <- 1 to nbboxes) {
- val randomVSet = (1 to sqrt(nverts).toInt).foldLeft(Set[VName]()) { (s,_) =>
+ for (_ <- 1 to nbboxes) {
+ val randomVSet = (1 to sqrt(nverts).toInt).foldLeft(Set[VName]()) { (s, _) =>
s + varray(rand.nextInt(varray.length))
}
@@ -1074,7 +1256,7 @@ object Graph {
def randomDag(nverts: Int, nedges: Int): Graph = {
val rand = new util.Random
var randomGraph = Graph()
- for (i <- 1 to nverts) {
+ for (_ <- 1 to nverts) {
val p = (rand.nextDouble * 6.0 - 3.0, rand.nextDouble * 6.0 - 3.0)
randomGraph = randomGraph.newVertex(NodeV(p))
}
@@ -1082,17 +1264,39 @@ object Graph {
// must have at least two verts to add edges since no self-loops allowed
if (nverts > 1)
- for(j <- 1 to nedges) {
+ for (_ <- 1 to nedges) {
val x = rand.nextInt(varray.length)
val y = rand.nextInt(varray.length - 1)
val s = varray(x)
- val t = varray(if (y >= x) y+1 else y)
- randomGraph = randomGraph.newEdge(DirEdge(), if (s <= t) (s,t) else (t,s))
+ val t = varray(if (y >= x) y + 1 else y)
+ randomGraph = randomGraph.newEdge(DirEdge(), if (s <= t) (s, t) else (t, s))
}
randomGraph
}
+ def variablesUsed(theory: Theory, graph: Graph): Set[String] = {
+ graph.vdata.foldLeft(Set[String]()) { (names, vnd) => {
+ val nodeDataType = theory.vertexTypes(vnd._2.typ).value.typ
+ val phases = CompositeExpression.parseKnowingTypes(vnd._2.data.toString(), nodeDataType)
+ val compositeExpression = CompositeExpression(nodeDataType, phases)
+ names union compositeExpression.vars
+ }
+ }
+ }
+
+ def variablesUsedWithType(theory: Theory, graph: Graph): Set[(ValueType, String)] = {
+ graph.vdata.foldLeft(Set[(ValueType, String)]()) { (names, vnd) =>
+ vnd._2 match {
+ case node: NodeV =>
+ val nodeDataType = theory.vertexTypes(node.typ).value.typ
+ val phases = CompositeExpression.parseKnowingTypes(node.data ? "value", nodeDataType)
+ val compositeExpression = CompositeExpression(nodeDataType, phases)
+ names union compositeExpression.varsWithType
+ case _ => names
+ }
+ }
+ }
def fromAdjMat(amat: AdjMat, rdata: Vector[NodeV], gdata: Vector[NodeV]): Graph = {
val thy =
@@ -1100,18 +1304,31 @@ object Graph {
else if (rdata.nonEmpty) rdata(0).theory
else throw new GraphException("Must give at least one piece of node data")
+
var g = Graph(thy)
- for (i <- 0 until amat.numBoundaries) g = g.addVertex(VName("v"+i), WireV(theory = thy))
+ for (i <- 0 until amat.numBoundaries) g = g.addVertex(VName("v" + i), WireV(theory = thy).withCoord(i,0))
var i = amat.numBoundaries
- for (t <- 0 until amat.numRedTypes; v <- 0 until amat.red(t)) {
- g = g.addVertex(VName("v" + i), rdata(t))
+ def red(i: Int): Int = if (amat.red.size > i) {
+ amat.red(i)
+ } else {
+ 0
+ }
+
+ def green(i: Int): Int = if (amat.green.size > i) {
+ amat.green(i)
+ } else {
+ 0
+ }
+
+ for (t <- 0 until amat.numRedTypes; _ <- 0 until red(t)) {
+ g = g.addVertex(VName("v" + i), rdata(t).withCoord(i, math.sin(i + t)))
i += 1
}
- for (t <- 0 until amat.numGreenTypes; v <- 0 until amat.green(t)) {
- g = g.addVertex(VName("v" + i), gdata(t))
+ for (t <- 0 until amat.numGreenTypes; _ <- 0 until green(t)) {
+ g = g.addVertex(VName("v" + i), gdata(t).withCoord(i, math.sin(i + t)))
i += 1
}
@@ -1125,4 +1342,6 @@ object Graph {
g
}
+
+ def apply(theory: Theory): Graph = Graph(data = GData(theory = theory))
}
diff --git a/scala/src/main/scala/quanto/data/GraphElementData.scala b/scala/src/main/scala/quanto/data/GraphElementData.scala
index 36f36014..e8212ef4 100644
--- a/scala/src/main/scala/quanto/data/GraphElementData.scala
+++ b/scala/src/main/scala/quanto/data/GraphElementData.scala
@@ -3,15 +3,19 @@ package quanto.data
import quanto.util.json.JsonObject
/**
- * An abstract class providing a general interface for accessing
- * information contained in its different components
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/GraphElementData.scala Source code]]
- * @see Known Subclasses below
- * @author Aleks Kissinger
- */
+ * An abstract class providing a general interface for accessing
+ * information contained in its different components
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/GraphElementData.scala Source code]]
+ * @see Known Subclasses below
+ * @author Aleks Kissinger
+ */
abstract class GraphElementData {
def theory: Theory
+
def data: JsonObject
+
def annotation: JsonObject
+
def toJson: JsonObject = JsonObject("data" -> data, "annotation" -> annotation).noEmpty
}
diff --git a/scala/src/main/scala/quanto/data/GraphMap.scala b/scala/src/main/scala/quanto/data/GraphMap.scala
index 2eb5b148..3d3784d6 100644
--- a/scala/src/main/scala/quanto/data/GraphMap.scala
+++ b/scala/src/main/scala/quanto/data/GraphMap.scala
@@ -32,14 +32,16 @@ case class GraphMap(
bb.codf.forall(_._2.size == 1)
}
- def addVertex(p: (VName,VName)): GraphMap = copy(v = v + p)
- def addEdge(p: (EName,EName)): GraphMap = copy(e = e + p)
- def addBBox(p: (BBName,BBName)): GraphMap = copy(bb = bb + p)
+ def addVertex(p: (VName, VName)): GraphMap = copy(v = v + p)
+
+ def addEdge(p: (EName, EName)): GraphMap = copy(e = e + p)
+
+ def addBBox(p: (BBName, BBName)): GraphMap = copy(bb = bb + p)
def isTotal(g: Graph): Boolean =
- v.domSet == g.verts &&
- e.domSet == g.edges &&
- bb.domSet == g.bboxes
+ v.domSet == g.verts &&
+ e.domSet == g.edges &&
+ bb.domSet == g.bboxes
def image(g: Graph): Graph = g.rename(v.toMap, e.toMap, bb.toMap)
}
diff --git a/scala/src/main/scala/quanto/data/HasGraph.scala b/scala/src/main/scala/quanto/data/HasGraph.scala
index ebefa047..02ce4479 100644
--- a/scala/src/main/scala/quanto/data/HasGraph.scala
+++ b/scala/src/main/scala/quanto/data/HasGraph.scala
@@ -4,18 +4,21 @@ import scala.swing.Publisher
import scala.swing.event.Event
abstract class GraphEvent extends Event
+
case class GraphChanged(sender: HasGraph) extends GraphEvent
// will cause any graph views to invalidate and repaint the graph
case class GraphReplaced(sender: HasGraph, clearSelection: Boolean) extends GraphEvent
trait HasGraph extends Publisher {
- protected def gr: Graph
- protected def gr_=(g: Graph)
+ def graph: Graph = gr
- def graph = gr
def graph_=(g: Graph) {
gr = g
publish(GraphChanged(this))
}
+
+ protected def gr: Graph
+
+ protected def gr_=(g: Graph)
}
diff --git a/scala/src/main/scala/quanto/data/Names.scala b/scala/src/main/scala/quanto/data/Names.scala
index f1c4ac6f..8d697563 100644
--- a/scala/src/main/scala/quanto/data/Names.scala
+++ b/scala/src/main/scala/quanto/data/Names.scala
@@ -3,15 +3,10 @@ package quanto.data
import quanto.util.json.JsonString
import scala.collection._
-import quanto.util.StringNamer
-
-
trait Name[This <: Name[This]] extends Ordered[This] {
val s: String
- protected val mk: String => This
-
val (prefix, suffix) = {
var intIndex = s.length
while (intIndex > 0 && s.charAt(intIndex - 1) >= '0' && s.charAt(intIndex - 1) <= '9') intIndex -= 1
@@ -24,41 +19,64 @@ trait Name[This <: Name[This]] extends Ordered[This] {
if (intIndex == s.length) -1 else s.substring(intIndex, s.length).toInt
)
}
+ protected val mk: String => This
- def compare(that: This) = if (prefix < that.prefix) -1
- else if (prefix > that.prefix) 1
- else suffix compare that.suffix
+ def compare(that: This): Int = if (prefix < that.prefix) -1
+ else if (prefix > that.prefix) 1
+ else suffix compare that.suffix
def succ: This = mk(prefix + (suffix + 1))
- override def toString = s
+ override def toString: String = s
}
-case class GName(s: String) extends Name[GName] { protected val mk = GName(_) }
-case class VName(s: String) extends Name[VName] { protected val mk = VName(_) }
-case class EName(s: String) extends Name[EName] { protected val mk = EName(_) }
-case class BBName(s: String) extends Name[BBName] { protected val mk = BBName(_) }
-case class DSName(s: String) extends Name[DSName] { protected val mk = DSName(_) }
+case class GName(s: String) extends Name[GName] {
+ protected val mk = GName(_)
+}
+
+case class VName(s: String) extends Name[VName] {
+ protected val mk = VName(_)
+}
+
+case class EName(s: String) extends Name[EName] {
+ protected val mk = EName(_)
+}
+
+case class BBName(s: String) extends Name[BBName] {
+ protected val mk = BBName(_)
+}
+
+case class DSName(s: String) extends Name[DSName] {
+ protected val mk = DSName(_)
+}
class DuplicateNameException[N <: Name[N]](ty: String, val name: N)
extends Exception("Duplicate " + ty + " name: '" + name + "'")
+
class DuplicateVertexNameException(override val name: VName)
extends DuplicateNameException("vertex", name)
+
class DuplicateEdgeNameException(override val name: EName)
extends DuplicateNameException("edge", name)
+
class DuplicateBBoxNameException(override val name: BBName)
extends DuplicateNameException("bang box", name)
object Names {
+
class NameSet[N <: Name[N]](val set: Set[N]) {
- def fresh(implicit default: N) : N = if (set.isEmpty) default else set.max.succ
- def freshWithSuggestion(s : N) : N = { var t = s; while (set.contains(t)) t = t.succ; t }
+ def fresh(implicit default: N): N = if (set.isEmpty) default else set.max.succ
+
+ def freshWithSuggestion(s: N): N = {
+ var t = s; while (set.contains(t)) t = t.succ; t
+ }
}
- class NameMap[N <: Name[N], T](val map: Map[N,T]) {
- def fresh(implicit default: N) : N = if (map.isEmpty) default else map.keys.max.succ
- def freshWithSuggestion(s : N) : N = {
+ class NameMap[N <: Name[N], T](val map: Map[N, T]) {
+ def fresh(implicit default: N): N = if (map.isEmpty) default else map.keys.max.succ
+
+ def freshWithSuggestion(s: N): N = {
val set = map.keySet
var t = s
while (set.contains(t)) t = t.succ
@@ -66,48 +84,60 @@ object Names {
}
}
-// class NamePFun[N <: Name[N], T](val pf: PFun[N,T]) {
-// def fresh(implicit default: N) : N = if (pf.isEmpty) default else pf.dom.max.succ
-// def freshWithSuggestion(s : N) : N = {
-// val set = pf.domSet
-// var t = s
-// while (set.contains(t)) t = t.succ
-// t
-// }
-// }
+ // class NamePFun[N <: Name[N], T](val pf: PFun[N,T]) {
+ // def fresh(implicit default: N) : N = if (pf.isEmpty) default else pf.dom.max.succ
+ // def freshWithSuggestion(s : N) : N = {
+ // val set = pf.domSet
+ // var t = s
+ // while (set.contains(t)) t = t.succ
+ // t
+ // }
+ // }
// TODO: overkill with implicits?
- implicit def setToNameSet[N <: Name[N]](set : Set[N]):NameSet[N] =
+ implicit def setToNameSet[N <: Name[N]](set: Set[N]): NameSet[N] =
new NameSet(set)
- implicit def mapToNameMap[N <: Name[N], T](map : Map[N,T]):NameMap[N,T] =
+
+ implicit def mapToNameMap[N <: Name[N], T](map: Map[N, T]): NameMap[N, T] =
new NameMap(map)
// these support general-purpose string-for-name substitution
- implicit def stringToGName(s: String): GName = GName(s)
- implicit def stringToVName(s: String): VName = VName(s)
- implicit def stringToEName(s: String): EName = EName(s)
+ implicit def stringToGName(s: String): GName = GName(s)
+
+ implicit def stringToVName(s: String): VName = VName(s)
+
+ implicit def stringToEName(s: String): EName = EName(s)
+
implicit def stringToBBName(s: String): BBName = BBName(s)
+
implicit def stringToDSName(s: String): DSName = DSName(s)
implicit def stringSetToGNameSet(set: Set[String]): Set[GName] = set map GName.apply
+
implicit def stringSetToVNameSet(set: Set[String]): Set[VName] = set map VName.apply
+
implicit def stringSetToENameSet(set: Set[String]): Set[EName] = set map EName.apply
+
implicit def stringSetToBBNameSet(set: Set[String]): Set[BBName] = set map BBName.apply
+
implicit def stringSetToDSNameSet(set: Set[String]): Set[DSName] = set map DSName.apply
// edge creation methods take a pair of vertices
- implicit def stringPairToVNamePair(t: (String,String)): (VName, VName) = (VName(t._1), VName(t._2))
+ implicit def stringPairToVNamePair(t: (String, String)): (VName, VName) = (VName(t._1), VName(t._2))
// these can be used to save names into JSON without conversion
implicit def gNameToJsonString(n: GName): JsonString = quanto.util.json.JsonString(n.toString)
+
implicit def vNameToJsonString(n: VName): JsonString = quanto.util.json.JsonString(n.toString)
+
implicit def eNameToJsonString(n: EName): JsonString = quanto.util.json.JsonString(n.toString)
+
implicit def bbNameToJsonString(n: BBName): JsonString = quanto.util.json.JsonString(n.toString)
- implicit val defaultVName = VName("v0")
- implicit val defaultEName = EName("e0")
- implicit val defaultGName = GName("g0")
- implicit val defaultBBName = BBName("bx0")
- implicit val defaultDSName = DSName("0")
+ implicit val defaultVName: VName = VName("v0")
+ implicit val defaultEName: EName = EName("e0")
+ implicit val defaultGName: GName = GName("g0")
+ implicit val defaultBBName: BBName = BBName("bx0")
+ implicit val defaultDSName: DSName = DSName("0")
}
diff --git a/scala/src/main/scala/quanto/data/PFun.scala b/scala/src/main/scala/quanto/data/PFun.scala
index 975660ee..4cc2bfb2 100644
--- a/scala/src/main/scala/quanto/data/PFun.scala
+++ b/scala/src/main/scala/quanto/data/PFun.scala
@@ -1,51 +1,37 @@
package quanto.data
-import collection.immutable.TreeSet
-import scala.collection.{TraversableLike, GenTraversableOnce, mutable, IterableLike}
-import scala.collection.generic.CanBuildFrom
-import Names._
+import scala.collection.immutable.TreeSet
+import scala.collection.{GenTraversableOnce, IterableLike, mutable}
/**
- * Basically a map, but with cached inverse images
- *
- * @tparam A type of domain elements
- * @tparam B type of codomain elements
- *
- * @constructor Create a new instance by specifying the partial function and
- * the inverse image function
- * @param f The partial function
- * @param finv The inverse image function '''f ^-1^ '''
- * @param keyOrd Order on the domain elements (keys)
- *
- * @author Aleks Kissinger
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/PFun.scala Source code]]
- */
+ * Basically a map, but with cached inverse images
+ *
+ * @tparam A type of domain elements
+ * @tparam B type of codomain elements
+ * @constructor Create a new instance by specifying the partial function and
+ * the inverse image function
+ * @param f The partial function
+ * @param finv The inverse image function '''f ^-1^ '''
+ * @param keyOrd Order on the domain elements (keys)
+ * @author Aleks Kissinger
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/PFun.scala Source code]]
+ */
class PFun[A, B]
-(f : Map[A,B], finv: Map[B,TreeSet[A]])
+(f: Map[A, B], finv: Map[B, TreeSet[A]])
(implicit keyOrd: Ordering[A])
- extends BinRel[A,B] with IterableLike[(A,B), PFun[A,B]]
- with GenTraversableOnce[(A,B)] {
+ extends BinRel[A, B] with IterableLike[(A, B), PFun[A, B]]
+ with GenTraversableOnce[(A, B)] {
- def domf = f.mapValues(Set(_))
+ def domf: Map[A, Set[B]] = f.mapValues(Set(_))
- def codf = finv.withDefaultValue(Set[A]())
+ def codf: Map[B, Set[A]] = finv.withDefaultValue(Set[A]())
- def +(kv: (A,B)) : PFun[A,B] = {
- val finv1 =
- (f.get(kv._1) match {
- case Some(oldV) => if (finv(oldV).size == 1) finv - oldV
- else finv + (oldV -> (finv(oldV) - kv._1))
- case None => finv
- }) + (kv._2 -> (finv.getOrElse(kv._2, TreeSet[A]()) + kv._1))
- new PFun(f + kv,finv1)
- }
+ def contains(kv: (A, B)): Boolean = f.get(kv._1).contains(kv._2)
- def contains(kv: (A,B)) = f.get(kv._1).contains(kv._2)
-
- def unmap(kv: (A, B)) = f.get(kv._1) match {
+ def unmap(kv: (A, B)): PFun[A, B] = f.get(kv._1) match {
case Some(v) if v == kv._2 =>
val domSet = finv(v)
- new PFun[A,B](
+ new PFun[A, B](
f - kv._1,
if (domSet.size == 1) finv - kv._2 else finv + (kv._2 -> (domSet - kv._1))
)
@@ -53,85 +39,106 @@ class PFun[A, B]
case None => this
}
- def unmapDom(k: A) :PFun[A,B] = {
- f.get(k) match {
- case None => this // do nothing
- case Some(v) =>
- val domSet = finv(v)
- new PFun[A,B](
- f - k,
- if (domSet.size == 1) finv - v else finv + (v -> (domSet - k))
- )
- }
- }
-
- def unmapCod(v: B) :PFun[A,B] = {
+ def unmapCod(v: B): PFun[A, B] = {
finv.get(v) match {
case None => this // do nothing
case Some(domSet) =>
- new PFun[A,B](
- domSet.foldLeft(f) { _ - _ },
+ new PFun[A, B](
+ domSet.foldLeft(f) {
+ _ - _
+ },
finv - v
)
}
}
- def restrictDom(s: Set[A]): PFun[A,B] = {
- s.foldRight(PFun[A,B]()) { (k,mp) => get(k) match { case Some(v) => mp + (k -> v); case None => mp } }
+ def restrictDom(s: Set[A]): PFun[A, B] = {
+ s.foldRight(PFun[A, B]()) { (k, mp) => get(k) match {
+ case Some(v) => mp + (k -> v);
+ case None => mp
+ }
+ }
}
+ /**
+ * Similar to '''apply''', but returns an option instead
+ *
+ * @param k Domain element where function should be evaluated
+ * @return Optionally returns the function value at '''k'''
+ */
+ def get(k: A): Option[B] = f.get(k)
+
/** Same as '''unmapDom''' */
- def -(k: A) = unmapDom(k)
+ def -(k: A): PFun[A, B] = unmapDom(k)
- /** Creates an iterator (same as f.iterator) */
- def iterator = f.iterator
+ def unmapDom(k: A): PFun[A, B] = {
+ f.get(k) match {
+ case None => this // do nothing
+ case Some(v) =>
+ val domSet = finv(v)
+ new PFun[A, B](
+ f - k,
+ if (domSet.size == 1) finv - v else finv + (v -> (domSet - k))
+ )
+ }
+ }
/** Returns '''f''' */
- def toMap = f
-
- /**
- * Specifies the behaviour for elements of the domain where the
- * function is not defined. Children have the option to override
- * this to give a default value.
- *
- * @param key Domain element
- * @return Alway throws an exception
- * @throws NoSuchElementException Exception indicates which key is not found
- */
- def default(key: A): B =
- throw new NoSuchElementException("key not found: " + key)
-
- /**
- * Similar to '''apply''', but returns an option instead
- *
- * @param k Domain element where function should be evaluated
- * @return Optionally returns the function value at '''k'''
- */
- def get(k: A) = f.get(k)
+ def toMap: Map[A, B] = f
/**
- * Get the value of the function at the specified domain element
- * '''( F(k) )'''
- *
- * @param k Domain element where function should be evaluated
- * @return Function value at '''k''' if it is defined there, otherwise
- * '''default(k)'''
- */
- def apply(k: A) = f.get(k) match {
+ * Get the value of the function at the specified domain element
+ * '''( F(k) )'''
+ *
+ * @param k Domain element where function should be evaluated
+ * @return Function value at '''k''' if it is defined there, otherwise
+ * '''default(k)'''
+ */
+ def apply(k: A): B = f.get(k) match {
case Some(v) => v
case None => default(k)
}
- protected[this] def newBuilder = new mutable.Builder[(A,B),PFun[A,B]] {
- val s = collection.mutable.Buffer[(A,B)]()
- def result() = PFun(s: _*)
- def clear() = s.clear()
- def +=(elem: (A,B)) = { s += elem; this }
+ /**
+ * Specifies the behaviour for elements of the domain where the
+ * function is not defined. Children have the option to override
+ * this to give a default value.
+ *
+ * @param key Domain element
+ * @return Alway throws an exception
+ * @throws NoSuchElementException Exception indicates which key is not found
+ */
+ def default(key: A): B =
+ throw new NoSuchElementException("key not found: " + key)
+
+ def seq: Seq[(A, B)] = iterator.toSeq
+
+ /** Creates an iterator (same as f.iterator) */
+ def iterator: Iterator[(A, B)] = f.iterator
+
+ def ++(pf: PFun[A, B]): PFun[A, B] = pf.foldLeft(this) { case (mp, kv) => mp + kv }
+
+ def +(kv: (A, B)): PFun[A, B] = {
+ val finv1 =
+ (f.get(kv._1) match {
+ case Some(oldV) => if (finv(oldV).size == 1) finv - oldV
+ else finv + (oldV -> (finv(oldV) - kv._1))
+ case None => finv
+ }) + (kv._2 -> (finv.getOrElse(kv._2, TreeSet[A]()) + kv._1))
+ new PFun(f + kv, finv1)
}
- def seq = iterator.toSeq
+ protected[this] def newBuilder : mutable.Builder[(A, B), PFun[A, B]] = new mutable.Builder[(A, B), PFun[A, B]] {
+ val s: mutable.Buffer[(A, B)] = collection.mutable.Buffer[(A, B)]()
+
+ def result() = PFun(s: _*)
- def ++(pf:PFun[A,B]) = pf.foldLeft(this) { case (mp, kv) => mp + kv }
+ def clear(): Unit = s.clear()
+
+ def +=(elem: (A, B)): this.type = {
+ s += elem; this
+ }
+ }
// PFun inherits equality from its member "f"
// override def hashCode = f.hashCode()
@@ -148,17 +155,18 @@ class PFun[A, B]
}
/**
- * Companion object for the PFun class
- *
- * @author Aleks Kissinger
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/PFun.scala Source code]]
- */
+ * Companion object for the PFun class
+ *
+ * @author Aleks Kissinger
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/PFun.scala Source code]]
+ */
object PFun {
/** Create an instance of PFun from a sequence of pairs */
- def apply[A, B](kvs: (A,B)*)(implicit keyOrd: Ordering[A]) : PFun[A,B] = {
- kvs.foldLeft(new PFun[A,B](Map(),Map())){ (pf: PFun[A,B], kv: (A,B)) => pf + kv }
+ def apply[A, B](kvs: (A, B)*)(implicit keyOrd: Ordering[A]): PFun[A, B] = {
+ kvs.foldLeft(new PFun[A, B](Map(), Map())) { (pf: PFun[A, B], kv: (A, B)) => pf + kv }
}
- implicit def mapToPFun[A, B](mp: Map[A,B])(implicit keyOrd: Ordering[A]): PFun[A,B] = PFun(mp.toSeq:_*)
- implicit def pFunToMap[A, B](f: PFun[A,B]): Map[A,B] = f.toMap
+ implicit def mapToPFun[A, B](mp: Map[A, B])(implicit keyOrd: Ordering[A]): PFun[A, B] = PFun(mp.toSeq: _*)
+
+ implicit def pFunToMap[A, B](f: PFun[A, B]): Map[A, B] = f.toMap
}
diff --git a/scala/src/main/scala/quanto/data/PhaseExpression.scala b/scala/src/main/scala/quanto/data/PhaseExpression.scala
new file mode 100644
index 00000000..c8f052cc
--- /dev/null
+++ b/scala/src/main/scala/quanto/data/PhaseExpression.scala
@@ -0,0 +1,473 @@
+package quanto.data
+
+import quanto.data.Theory.ValueType
+import quanto.util.{Rational, UserAlerts}
+
+import scala.util.matching.Regex
+import scala.util.parsing.combinator.RegexParsers
+
+// Phases are just groups with added symbolic variables,
+// Except we also need to be able to do Gaussian elimination on them
+// So need to be able to interpret integers inside the phase group
+// (i.e. needs to know what '1' is)
+// We use the language of Abelian groups
+// For things like Strings you just need to fudge it to allow people to put notes into nodes
+
+
+case class PhaseParseException(msg: String, valueType: ValueType)
+ extends Exception(s"Attempting parse $ValueType threw message $msg")
+
+case class PhaseEvaluationException(msg: String) extends Exception(msg)
+
+
+abstract class RationalWithSymbols[Group <: RationalWithSymbols[Group]](val coefficients: Map[String, Rational], val constant: Rational) {
+
+ val description: ValueType
+ // Any symbolic variables should be in a set,
+ // However since we need these things to give consistent string outputs
+ // we sort the list here, and convert back into a set if needed
+ // to avoid "x + y" != "y + x"
+ val vars: List[String] = coefficients.keySet.toList.sorted
+
+ // The companion object should hold the zero element, but link to it from the class
+ def zero: Group
+
+ // Needs a 1, also held on the companion object
+ def one: Group
+
+ // Addition
+ def +(g: Group): Group
+
+ // Subtraction
+ def -(g: Group): Group
+
+ // Multiplication - only needs to be able to do multiplication by possible variable coefficients
+ def *(r: Rational): Group
+
+ // Substitution of single variable
+ def subst(s: String, e: Group): Group
+
+ // Substitution of map of variables
+ def subst(mp: Map[String, Group]): Group
+
+ def evaluate(mp: Map[String, Double]): Double = {
+ try {
+ constant + coefficients.foldLeft(0.0) { (a, b) => a + (mp(b._1) * Rational.rationalToDouble(b._2)) }
+ } catch {
+ case _: Exception => PhaseEvaluationException("Given arguments do not match those in the coefficient list")
+ 0
+ }
+ }
+
+}
+
+class PhaseExpression(const: Rational,
+ coeff: Map[String, Rational],
+ val modulus: Option[Int],
+ val finiteField: Boolean,
+ override val description: ValueType)
+ extends RationalWithSymbols[PhaseExpression](coeff, const) {
+
+ // Apply the modulus at creation
+ override val constant: Rational = mod(const)
+ // filter out any variables with zero as their coefficient
+ // If restricting to integers then we also apply the modulo to the coefficients
+ // e.g. 2*alpha = false in Bool (yes, that's a horrible mix, but it's the sort of thing we expect to see parsed)
+ override val coefficients: Map[String, Rational] = coeff.map(
+ sr => sr._1 -> (if (finiteField) mod(sr._2) else sr._2)
+ ).filterNot(sr => sr._2.isZero)
+
+ def mod(r: Rational): Rational = if (modulus.nonEmpty) {
+ val reduced = Rational(r.n % (modulus.get * r.d), r.d)
+ // Recall that -1 % 2 is still -1, for reasons.
+ if (reduced < 0) {
+ reduced + modulus.get
+ } else {
+ reduced
+ }
+ } else {
+ r
+ }
+
+ override def equals(that: Any): Boolean = that match {
+ case e: PhaseExpression =>
+ description == e.description && constant == e.constant && coefficients == e.coefficients
+ case _ => false
+ }
+
+ def zero: PhaseExpression = PhaseExpression.zero(description)
+
+ def one: PhaseExpression = PhaseExpression.one(description)
+
+ def subst(mp: Map[String, PhaseExpression]): PhaseExpression =
+ mp.foldLeft(this) { case (e, (v, e1)) => e.subst(v, e1) }
+
+ def subst(v: String, e: PhaseExpression): PhaseExpression = {
+ val c = coefficients.getOrElse(v, Rational(0))
+ this - PhaseExpression(Rational(0), Map(v -> c), description) + (e * c)
+ }
+
+ def -(e: PhaseExpression): PhaseExpression = this + (e * -1)
+
+ def *(i: Int): PhaseExpression = this * Rational(i)
+
+ def +(e: PhaseExpression) = PhaseExpression(mod(constant + e.constant),
+ e.coefficients.foldLeft(coefficients) {
+ case (m, (k, v)) => m + (k -> (v + m.getOrElse(k, Rational(0))))
+ }, description)
+
+ def *(r: Rational): PhaseExpression =
+ PhaseExpression(mod(constant * r), coefficients.mapValues(x => x * r), description)
+
+ override def toString: String = PhaseExpression.toString(description, this)
+
+ override def evaluate(mp: Map[String, Double]): Double = mod(super.evaluate(mp))
+
+ def mod(d: Double): Double = if (modulus.nonEmpty) {
+ d % modulus.get
+ } else {
+ d
+ }
+
+ def as(valueType: ValueType): PhaseExpression = PhaseExpression(constant, coefficients, valueType)
+
+ def convertTo(valueType: ValueType): PhaseExpression = PhaseExpression(this.constant, this.coefficients, valueType)
+}
+
+case class FieldData(modulus: Option[Int], finiteField: Boolean)
+
+
+object PhaseExpression {
+
+ def apply(r: Rational, valueType: ValueType): PhaseExpression = {
+ PhaseExpression(r, Map(), valueType)
+ }
+
+ def parse(s: String, valueType: ValueType): PhaseExpression = {
+ try {
+ valueType match {
+ case ValueType.AngleExpr => AngleExpressionParser.p(s)
+ case ValueType.Boolean => BooleanExpressionParser.p(s)
+ case ValueType.Rational => RationalExpressionParser.p(s)
+ case ValueType.Empty => PhaseExpression(0, Map(), ValueType.Empty)
+ case ValueType.String => StringExpressionParser.p(s)
+ case v => throw PhaseParseException(s"Asked to parse unexpected '$s' for type $v", v)
+ }
+ }catch{
+ case PhaseParseException(m, v) =>
+ UserAlerts.alert(s"Could not parse $s for type $v, regex error: $m", UserAlerts.Elevation.ERROR)
+ PhaseExpression.zero(v)
+ }
+ }
+
+ def one(valueType: ValueType): PhaseExpression = PhaseExpression(1, Map(), valueType)
+
+ def apply(r: Rational, m: Map[String, Rational], valueType: ValueType): PhaseExpression = {
+ val fData = fieldData(valueType)
+ new PhaseExpression(r, m, fData.modulus, fData.finiteField, valueType)
+ }
+
+ def fieldData(valueType: ValueType): FieldData = valueType match {
+ case ValueType.AngleExpr => FieldData(Some(2), finiteField = false)
+ case ValueType.Boolean => FieldData(Some(2), finiteField = true)
+ case _ => FieldData(None, finiteField = false)
+ }
+
+ def toString(valueType: ValueType, phaseExpression: PhaseExpression): String = {
+ valueType match {
+ case ValueType.AngleExpr =>
+ writeAsAngle(phaseExpression)
+ case ValueType.Boolean =>
+ writeAsBoolean(phaseExpression)
+ case ValueType.String =>
+ writeAsString(phaseExpression)
+ case ValueType.Empty =>
+ ""
+ case ValueType.Rational =>
+ writeAsRational(phaseExpression)
+ }
+ }
+
+ private def writeAsRational(phaseExpression: PhaseExpression): String = {
+ val constant = phaseExpression.constant
+ val coefficients = phaseExpression.coefficients
+ var fst = true
+ var s = ""
+ if (!constant.isZero) {
+
+ fst = false
+ val (n, sgn) =
+ if (constant.n > constant.d && constant.n < 2 * constant.d) (2 * constant.d - constant.n, "-")
+ else (constant.n, "")
+ if (n == 1) s += sgn + "1/" + constant.d
+ else s += sgn + n.toString + "/" + constant.d
+
+ }
+
+ def rStr(c: Rational) : String = writeAsRational(PhaseExpression(c, Map(), ValueType.Rational))
+
+ coefficients.keys.toList.sorted.foreach { variableName =>
+ val c = coefficients(variableName)
+ if (!c.isZero) {
+ if (fst) {
+ fst = false
+ s = s + (if (c == Rational(1)) "" else rStr(c) + " ") + variableName
+ } else {
+ if (c < Rational(0)) {
+ s = s + " - " + (if (c == Rational(-1)) "" else rStr(c * -1) + " ") + variableName
+ } else {
+ s = s + " + " + (if (c == Rational(1)) "" else rStr(c) + " ") + variableName
+ }
+ }
+ }
+ }
+ if (phaseExpression == PhaseExpression.zero(ValueType.Rational)) s = "0"
+
+ s
+ }
+
+ private def writeAsString(phaseExpression: PhaseExpression): String = {
+ val constant = phaseExpression.constant
+ val vars = phaseExpression.vars
+ if (vars.nonEmpty) {
+ vars.mkString("", ", ", "")
+ } else {
+ ""
+ }
+ }
+
+ private def writeAsBoolean(phaseExpression: PhaseExpression): String = {
+ val constant = phaseExpression.constant
+ val vars = phaseExpression.vars
+ val t = "\\True"
+ val f = "\\False"
+ if (vars.nonEmpty) {
+ vars.mkString(if (constant == Rational(1)) {
+ s"$t + "
+ } else {
+ ""
+ }, " + ", "")
+ } else {
+ if (constant == Rational(1)) t else f
+ }
+ }
+
+ private def writeAsAngle(phaseExpression: PhaseExpression): String = {
+ val constant = phaseExpression.constant
+ val coefficients = phaseExpression.coefficients
+ var fst = true
+ var s = ""
+ if (!constant.isZero) {
+
+ fst = false
+ if (constant.isOne) s += "\\pi"
+ else {
+ val (n, sgn) =
+ if (constant.n > constant.d && constant.n < 2 * constant.d) (2 * constant.d - constant.n, "-")
+ else (constant.n, "")
+ if (n == 1) s += sgn + "\\pi/" + constant.d
+ else s += sgn + n.toString + "\\pi/" + constant.d
+ }
+ }
+
+ coefficients.keys.toList.sorted.foreach { variableName =>
+ val c = coefficients(variableName)
+ if (!c.isZero) {
+ if (fst) {
+ fst = false
+ s = s + (if (c == Rational(1)) "" else c.toString + " ") + variableName
+ } else {
+ if (c < Rational(0)) {
+ s = s + " - " + (if (c == Rational(-1)) "" else (c * -1).toString + " ") + variableName
+ } else {
+ s = s + " + " + (if (c == Rational(1)) "" else c.toString + " ") + variableName
+ }
+ }
+ }
+ }
+ if (phaseExpression == PhaseExpression.zero(ValueType.AngleExpr)) s = "0"
+
+ s
+ }
+
+ def zero(valueType: ValueType): PhaseExpression = PhaseExpression(0, Map(), valueType)
+
+ private object AngleExpressionParser extends CommonParser(ValueType.AngleExpr) {
+
+ def angleExpression(r: Rational): PhaseExpression = PhaseExpression(r, ValueType.AngleExpr)
+
+ def angleExpression(r: Rational, m: Map[String, Rational]): PhaseExpression = PhaseExpression(r, m, ValueType.AngleExpr)
+
+ def INT_OPT: Parser[Int] = INT.? ^^ {
+ _.getOrElse(1)
+ }
+
+ def PI: Parser[Unit] = """\\?[pP][iI]""".r ^^ { _ => Unit }
+
+
+ def coeff: Parser[Rational] =
+ INT ~ "/" ~ INT ^^ { case n ~ _ ~ d => Rational(n, d) } |
+ "(" ~ coeff ~ ")" ^^ { case _ ~ c ~ _ => c } |
+ INT ^^ { n => Rational(n) }
+
+
+ def frac: Parser[PhaseExpression] =
+ INT_OPT ~ "*".? ~ PI ~ "/" ~ INT ^^ { case n ~ _ ~ _ ~ _ ~ d => angleExpression(Rational(n, d)) } |
+ INT_OPT ~ "*".? ~ SYMBOL ~ "/" ~ INT ^^ {
+ case n ~ _ ~ x ~ _ ~ d => angleExpression(Rational(0), Map(x -> Rational(n, d)))
+ }
+
+ def term: Parser[PhaseExpression] =
+ frac |
+ "-" ~ term ^^ { case _ ~ t => t * -1 } |
+ coeff ~ "*".? ~ PI ^^ { case c ~ _ ~ _ => angleExpression(c) } |
+ PI ^^ { _ => one } |
+ coeff ~ "*".? ~ SYMBOL ^^ { case c ~ _ ~ x => angleExpression(Rational(0), Map(x -> c)) } |
+ SYMBOL ~ "*" ~ coeff ^^ { case x ~ _ ~ c => angleExpression(Rational(0), Map(x -> c)) } |
+ SYMBOL ^^ { x => angleExpression(Rational(0), Map(x -> Rational(1))) } |
+ coeff ^^ angleExpression |
+ "(" ~ expr ~ ")" ^^ { case _ ~ t ~ _ => t }
+
+ def term1: Parser[PhaseExpression] =
+ "+" ~ term ^^ { case _ ~ t => t } |
+ "-" ~ term ^^ { case _ ~ t => t * -1 }
+
+ def terms: Parser[PhaseExpression] =
+ term1 ~ terms ^^ { case s ~ t => s + t } |
+ term1
+
+ def expr: Parser[PhaseExpression] =
+ term ~ terms ^^ { case s ~ t => s + t } |
+ term |
+ "" ^^ { _ => zero }
+
+ }
+
+
+ private object RationalExpressionParser extends CommonParser(ValueType.AngleExpr) {
+ def rationalExpression(r: Rational): PhaseExpression = PhaseExpression(r, ValueType.Rational)
+
+ def rationalExpression(r: Rational, m: Map[String, Rational]): PhaseExpression =
+ PhaseExpression(r, m, ValueType.Rational)
+
+ def INT_OPT: Parser[Int] = INT.? ^^ {
+ _.getOrElse(1)
+ }
+
+ def coeff: Parser[Rational] =
+ INT ~ "/" ~ INT ^^ { case n ~ _ ~ d => Rational(n, d) } |
+ "(" ~ coeff ~ ")" ^^ { case _ ~ c ~ _ => c } |
+ INT ^^ { n => Rational(n) }
+
+ def frac: Parser[PhaseExpression] =
+ INT_OPT ~ "*".? ~ SYMBOL ~ "/" ~ INT ^^ {
+ case n ~ _ ~ x ~ _ ~ d => rationalExpression(Rational(0), Map(x -> Rational(n, d)))
+ }
+
+ def term: Parser[PhaseExpression] =
+ frac |
+ "-" ~ term ^^ { case _ ~ t => t * -1 } |
+ coeff ~ "*".? ~ SYMBOL ^^ { case c ~ _ ~ x => rationalExpression(Rational(0), Map(x -> c)) } |
+ SYMBOL ^^ { x => rationalExpression(Rational(0), Map(x -> Rational(1))) } |
+ SYMBOL ~ "*" ~ coeff ^^ { case x ~ _ ~ c => rationalExpression(Rational(0), Map(x -> c)) } |
+ coeff ^^ rationalExpression |
+ "(" ~ expr ~ ")" ^^ { case _ ~ t ~ _ => t }
+
+ def term1: Parser[PhaseExpression] =
+ "+" ~ term ^^ { case _ ~ t => t } |
+ "-" ~ term ^^ { case _ ~ t => t * -1 }
+
+ def terms: Parser[PhaseExpression] =
+ term1 ~ terms ^^ { case s ~ t => s + t } |
+ term1
+
+ def expr: Parser[PhaseExpression] =
+ term ~ terms ^^ { case s ~ t => s + t } |
+ term |
+ "" ^^ { _ => zero }
+
+ }
+
+
+ private object StringExpressionParser {
+ def zero: PhaseExpression = PhaseExpression.zero(ValueType.String)
+
+ def p(s: String): PhaseExpression = s match {
+ case "" => zero
+ case t => PhaseExpression(0, Map(t -> Rational(1,1)), ValueType.String)
+ }
+ }
+
+ private abstract class CommonParser(T: ValueType) extends RegexParsers{
+
+ val zero: PhaseExpression = PhaseExpression.zero(T)
+ val one: PhaseExpression = PhaseExpression.one(T)
+
+ def INT: Parser[Int] =
+ """[0-9]+""".r ^^ {
+ _.toInt
+ }
+
+ def SYMBOL: Parser[String] =
+ """[\\a-zA-Z_][a-zA-Z0-9_']*""".r ^^ {
+ _.toString
+ }
+
+ final override def skipWhitespace = true
+
+ def expr: Parser[PhaseExpression]
+
+ def p(s: String): PhaseExpression = parseAll(expr, s) match {
+ case Success(e, _) => e
+ case Failure(msg, _) => throw PhaseParseException(msg, T)
+ case Error(msg, _) => throw PhaseParseException(msg, T)
+ }
+ }
+
+ private object BooleanExpressionParser extends CommonParser(ValueType.Boolean) {
+
+ def BooleanExpression(i: Int): PhaseExpression = PhaseExpression(Rational(i), ValueType.Boolean)
+
+ def BooleanExpression(i: Int, m: Map[String, Rational]): PhaseExpression =
+ PhaseExpression(Rational(i), m, ValueType.Boolean)
+
+ def TRUE: Parser[Int] = """\\?[Tt](rue|RUE)?""".r ^^ { _ => 1 }
+
+ def FALSE: Parser[Int] = """\\?[Ff](alse|ALSE)?""".r ^^ { _ => 0 }
+
+ def coeff: Parser[Int] =
+ TRUE | FALSE | INT
+
+ def INT_OPT: Parser[Int] = INT.? ^^ {
+ _.getOrElse(0)
+ }
+
+
+ def term: Parser[PhaseExpression] =
+ "-" ~ term ^^ { case _ ~ t => t } |
+ TRUE ^^ { _ => one } |
+ FALSE ^^ { _ => zero } |
+ coeff ~ "*" ~ coeff ^^ { case c ~ _ ~ x => BooleanExpression(x * c) } |
+ coeff ~ "*".? ~ SYMBOL ^^ { case c ~ _ ~ x => BooleanExpression(0, Map(x -> 1)) * c } |
+ SYMBOL ^^ { x => BooleanExpression(0, Map(x -> 1)) } |
+ coeff ^^ { x => BooleanExpression(x) } |
+ "(" ~ expr ~ ")" ^^ { case _ ~ t ~ _ => t }
+
+ def term1: Parser[PhaseExpression] =
+ "+" ~ term ^^ { case _ ~ t => t } |
+ "-" ~ term ^^ { case _ ~ t => t * -1 }
+
+ def terms: Parser[PhaseExpression] =
+ term1 ~ terms ^^ { case s ~ t => s + t } |
+ term1
+
+ def expr: Parser[PhaseExpression] =
+ term ~ terms ^^ { case s ~ t => s + t } |
+ term |
+ terms |
+ "" ^^ { _ => zero }
+
+ }
+
+
+}
diff --git a/scala/src/main/scala/quanto/data/Project.scala b/scala/src/main/scala/quanto/data/Project.scala
index 28371797..fc294409 100644
--- a/scala/src/main/scala/quanto/data/Project.scala
+++ b/scala/src/main/scala/quanto/data/Project.scala
@@ -1,46 +1,78 @@
package quanto.data
+import java.io.File
+
import quanto.rewrite.Simproc
+import quanto.util.FileHelper
import quanto.util.json._
-import java.io.File
class ProjectLoadException(message: String, cause: Throwable) extends Exception(message, cause)
-case class Project(theory: Theory, rootFolder: String, name : String) {
- def rules: Vector[String] = rulesInPath(rootFolder)
+case class Project(theory: Theory, projectFile: File, name: String) {
+ require(!projectFile.isDirectory)
+
+ FileHelper.printJson(projectFile.getAbsolutePath, Project.toJson(this))
+
+ val rootFolder: String = projectFile.getParent
+ private val rootFolderFile = projectFile.getParentFile
+
var simprocs: Map[String, Simproc] = Map()
+ var lastRunPythonFilePath: Option[String] = None
+
+ def relativePath(f: File): String = {
+ rootFolderFile.toURI.relativize(f.toURI).getPath
+ }
- private def rulesInPath(p: String): Vector[String] = {
- val f = new File(p)
- if (f.isDirectory) f.listFiles().toVector.flatMap(f => rulesInPath(f.getPath))
- else if (f.getPath.endsWith(".qrule")) {
- val fname = f.getPath
- Vector(fname.substring(rootFolder.length + 1, fname.length - 6))
+ //Scans the given folder for filenames that end in the given extension
+ def filesEndingIn(ext: String, path: String = rootFolderFile.getAbsolutePath): List[String] = {
+ val f = new File(path)
+ if (f.isDirectory) f.listFiles().toList.flatMap(f => filesEndingIn(ext, f.getPath))
+ else if (f.getPath.endsWith(ext)) {
+ val fileName = relativePath(f)
+ List(fileName.substring(0, fileName.length - ext.length))
}
- else Vector()
+ else List()
}
}
object Project {
- def fromTheoryOrProjectFile(theoryOrProjectFile: String, rootFolder: String = "", name : String = "") : Project = {
- println(s"Asked to load prjoect from: $theoryOrProjectFile")
+ def fromJson(json: Json, projectFile: File): Project = try {
+ Project(
+ Theory.fromJson(json / "theory"),
+ projectFile,
+ (json / "name").stringValue
+ )
+ } catch {
+ // First try loading as though old format
+ case _: Exception =>
+ try {
+ println("Attempting to update project file from old format.")
+ Project(Theory.fromJson(json / "theory"), projectFile, projectFile.getName)
+ }
+ catch {
+ case e: Exception => throw new ProjectLoadException("Error loading project", e)
+ }
+ }
+
+ def fromTheoryOrProjectFile(theoryOrProjectFile: File, rootFolder: File = new File("."), name: String = "main"): Project = {
+ println(s"Asked to load project from: $theoryOrProjectFile")
val theory: Theory = Theory.fromJson({
- val extension: String = theoryOrProjectFile.replaceAll(".*\\.", "")
+ val extension: String = theoryOrProjectFile.getAbsolutePath.replaceAll(".*\\.", "")
try {
extension match {
case "qtheory" =>
- Json.parse(new File(theoryOrProjectFile))
+ Json.parse(theoryOrProjectFile)
case "qproject" =>
- Json.parse(new File(theoryOrProjectFile)) / "theory"
+ Json.parse(theoryOrProjectFile) / "theory"
case _ =>
- try{
+ try {
val theoryStream = Theory.getClass.getResourceAsStream(theoryOrProjectFile + ".qtheory")
Json.parse(new Json.Input(theoryStream))
- } catch {
- case e: Exception =>
- throw new ProjectLoadException(s"Could not parse the resource $theoryOrProjectFile", e)
- }
+ } catch {
+ case e: Exception =>
+ throw new ProjectLoadException(s"Could not parse the resource $theoryOrProjectFile", e)
+ }
}
} catch {
case e: Exception =>
@@ -48,28 +80,11 @@ object Project {
}
}
)
- new Project(theory, rootFolder, name)
- }
-
- def fromJson(json: Json, rootFolder: String): Project = try {
- Project(
- Theory.fromJson(json / "theory"),
- rootFolder,
- (json / "name").stringValue
- )
- } catch {
- // First try loading as though old format
- case e: Exception =>
- try {
- println("Attempting to update project file from old format.")
- Project.fromTheoryOrProjectFile((json / "theory").stringValue, rootFolder)
- }
- catch {
- case e: Exception => throw new ProjectLoadException("Error loading project", e)
- }
+ val suggestedFilename = rootFolder + "/" + name + ".qproject"
+ new Project(theory, new File(suggestedFilename), name)
}
- def toJson(project: Project) : Json = {
+ def toJson(project: Project): Json = {
JsonObject(
"name" -> project.name,
"theory" -> Theory.toJson(project.theory)
diff --git a/scala/src/main/scala/quanto/data/ResultSet.scala b/scala/src/main/scala/quanto/data/ResultSet.scala
index 444ef48c..a5c00438 100644
--- a/scala/src/main/scala/quanto/data/ResultSet.scala
+++ b/scala/src/main/scala/quanto/data/ResultSet.scala
@@ -1,59 +1,64 @@
package quanto.data
case class ResultLine(rule: RuleDesc, index: Int, total: Int) {
- override def toString = {
- rule.name + (if (rule.inverse) "[inverse]" else "") + " (" + index + "/" + total + ")"
+ override def toString: String = {
+ (if (rule.inverse) "<- " else "-> ") + rule.name + " (" + index + "/" + total + ")"
}
}
-case class ResultSet(rules: Vector[RuleDesc], results: Map[RuleDesc,(Int,Vector[DStep])]) {
- def copy(rules: Vector[RuleDesc] = rules, results: Map[RuleDesc,(Int,Vector[DStep])] = results) =
- ResultSet(rules, results)
-
+case class ResultSet(rules: Vector[RuleDesc], results: Map[RuleDesc, (Int, Vector[DStep])]) {
def currentResult(rule: RuleDesc): Option[DStep] =
- results.get(rule).flatMap { case (i,vec) => if (i == 0) None else Some(vec(i-1)) }
+ results.get(rule).flatMap { case (i, vec) => if (i == 0) None else Some(vec(i - 1)) }
- def resultIndex(rule: RuleDesc) = results(rule)._1
- def setResultIndex(rule: RuleDesc, i: Int) =
- copy(results = results + (rule -> (results(rule) match { case (_, vec) => (i, vec) })))
-
- def nextResult(rule: RuleDesc) = {
+ def nextResult(rule: RuleDesc): ResultSet = {
val i = resultIndex(rule) + 1
if (i <= numResults(rule)) setResultIndex(rule, i)
else this
}
- def previousResult(rule: RuleDesc) = {
+ def setResultIndex(rule: RuleDesc, i: Int): ResultSet =
+ copy(results = results + (rule -> (results(rule) match {
+ case (_, vec) => (i, vec)
+ })))
+
+ def previousResult(rule: RuleDesc): ResultSet = {
val i = resultIndex(rule) - 1
if (i > 0) setResultIndex(rule, i)
else this
}
- def replaceGraph(rule: RuleDesc, i: Int, graph: Graph) = {
+ def replaceGraph(rule: RuleDesc, i: Int, graph: Graph): ResultSet = {
val x = results(rule)
val res = x._2(i - 1).copy(graph = graph)
copy(results = results + (rule -> (x._1, x._2.updated(i - 1, res))))
}
- def graph(rule: RuleDesc, i: Int) = results(rule)._2(i - 1).graph
+ def graph(rule: RuleDesc, i: Int): Graph = results(rule)._2(i - 1).graph
- def numResults(rule: RuleDesc) = results(rule)._2.size
- def +(res: (RuleDesc,DStep)) = {
+ def +(res: (RuleDesc, DStep)): ResultSet = {
val rs = results(res._1)
copy(results = results + (res._1 -> (if (rs._1 == 0) 1 else rs._1, rs._2 :+ res._2)))
}
- def -(rule: RuleDesc) = {
+ def -(rule: RuleDesc): ResultSet = {
copy(rules = rules.filter(_ != rule), results = results - rule)
}
- def resultLines = rules.map { r => ResultLine(r, resultIndex(r), numResults(r)) }
+ def copy(rules: Vector[RuleDesc] = rules, results: Map[RuleDesc, (Int, Vector[DStep])] = results) =
+ ResultSet(rules, results)
+
+ def resultLines: Vector[ResultLine] = rules.map { r => ResultLine(r, resultIndex(r), numResults(r)) }
+
+ def resultIndex(rule: RuleDesc): Int = results(rule)._1
+
+ def numResults(rule: RuleDesc): Int = results(rule)._2.size
}
object ResultSet {
def apply(rules: Vector[RuleDesc]): ResultSet = {
- val results = rules.foldLeft(Map[RuleDesc,(Int,Vector[DStep])]()) {
- case (rs, rule) => rs + (rule -> (0, Vector())) }
+ val results = rules.foldLeft(Map[RuleDesc, (Int, Vector[DStep])]()) {
+ case (rs, rule) => rs + (rule -> (0, Vector()))
+ }
ResultSet(rules, results)
}
}
diff --git a/scala/src/main/scala/quanto/data/Rule.scala b/scala/src/main/scala/quanto/data/Rule.scala
index 1e643918..3ad1fff9 100644
--- a/scala/src/main/scala/quanto/data/Rule.scala
+++ b/scala/src/main/scala/quanto/data/Rule.scala
@@ -11,19 +11,46 @@ case class RuleLoadException(message: String, cause: Throwable = null)
case class Rule(private val _lhs: Graph,
private val _rhs: Graph,
derivation: Option[String] = None,
- description: RuleDesc = RuleDesc("unnamed")) {
+ description: RuleDesc = RuleDesc()) {
+
+ val lhs: Graph = if (description.inverse) _rhs else _lhs
+ val rhs: Graph = if (description.inverse) _lhs else _rhs
+ val name: String = description.name + (if (description.inverse) " inverted" else "")
def inverse: Rule = {
Rule(lhs, rhs, derivation, description.invert)
}
- val lhs: Graph = if (description.inverse) _rhs else _lhs
+ def hasBBoxes: Boolean = lhs.bboxes.nonEmpty || rhs.bboxes.nonEmpty
- val rhs: Graph = if (description.inverse) _lhs else _rhs
+ def map(f: Graph => Graph): Rule = {
+ new Rule(f(lhs), f(rhs))
+ }
- val name: String = description.name + (if (description.inverse) " inverted" else "")
+ def colourSwap(changes: Map[String, String]): Rule = {
+ def safeChanges(s: String) : String = {
+ changes.get(s) match {
+ case Some(t) => t
+ case None => s
+ }
+ }
+ map(graph => {
+ graph.verts.foldLeft(graph) { (g, v) =>
+ g.updateVData(v)(f = {
+ case n: NodeV =>
+ n.copy(data = JsonObject(
+ "type" -> safeChanges((n.data / "type").stringValue),
+ "value" -> (n.data / "value")
+ ))
+ case m =>
+ m
+ }
+ )
+ }
+ })
+ }
- override def toString: String = name + " := "+ _lhs.toString +
+ override def toString: String = name + " := " + _lhs.toString +
(if (description.inverse) {
"<--"
} else {
@@ -36,6 +63,10 @@ case class RuleDesc(name: String = "unnamed", inverse: Boolean = false) {
def invert: RuleDesc = RuleDesc(name, !inverse)
}
+object RuleDesc {
+ implicit def fromString(string: String) : RuleDesc = RuleDesc(string)
+}
+
object Rule {
def fromJson(json: Json, thy: Theory = Theory.DefaultTheory, description: Option[RuleDesc] = None): Rule = try {
Rule(_lhs = Graph.fromJson(json / "lhs", thy),
@@ -46,7 +77,7 @@ object Rule {
},
description = if (description.isDefined) description.get else json.get("description") match {
case Some(JsonString(s)) => RuleDesc(s);
- case _ => RuleDesc("unnamed")
+ case _ => RuleDesc()
})
} catch {
case e: JsonAccessException =>
@@ -69,4 +100,11 @@ object Rule {
case None => obj
}
}
+
+ def namesUsed(rule: Rule, theory: Theory) : Set[String] = {
+
+ val namesUsedInLHS = Graph.variablesUsed(theory, rule.lhs)
+ val namesUsedInRHS = Graph.variablesUsed(theory, rule.rhs)
+ namesUsedInLHS union namesUsedInRHS
+ }
}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/data/Theory.scala b/scala/src/main/scala/quanto/data/Theory.scala
index 87744727..cd495ba7 100644
--- a/scala/src/main/scala/quanto/data/Theory.scala
+++ b/scala/src/main/scala/quanto/data/Theory.scala
@@ -4,96 +4,261 @@ import quanto.util.json._
import JsonValues._
import java.awt.{Color, Shape}
+
/**
- * Exception thrown when theory cannot be created for some reason
- *
- * @author Aleks Kissinger
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/Theory.scala Source code]]
- */
+ * Exception thrown when theory cannot be created for some reason
+ *
+ * @author Aleks Kissinger
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/Theory.scala Source code]]
+ */
class TheoryLoadException(message: String, cause: Throwable = null)
extends Exception(message, cause)
/**
- * A class which represents a theory
- *
- * @author Aleks Kissinger
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/Theory.scala Source code]]
- */
+ * A class which represents a theory
+ *
+ * @author Aleks Kissinger
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/Theory.scala Source code]]
+ */
case class Theory(
- name: String,
- coreName: String,
- vertexTypes: Map[String, Theory.VertexDesc],
- edgeTypes: Map[String, Theory.EdgeDesc] = Map("plain" -> Theory.PlainEdgeDesc),
- defaultVertexType: String,
- defaultEdgeType: String = "plain")
-{
- def defaultVertexData = vertexTypes(defaultVertexType).defaultData
- def defaultEdgeData = edgeTypes(defaultEdgeType).defaultData
- override def toString = coreName
+ name: String,
+ coreName: String,
+ vertexTypes: Map[String, Theory.VertexDesc],
+ edgeTypes: Map[String, Theory.EdgeDesc] = Map("plain" -> Theory.PlainEdgeDesc),
+ defaultVertexType: String,
+ defaultEdgeType: String = "plain") {
+ def defaultVertexData: JsonObject = vertexTypes(defaultVertexType).defaultData
+
+ def defaultEdgeData: JsonObject = edgeTypes(defaultEdgeType).defaultData
+
+ def mixin(that: Theory, overwriteName: Option[String]): Theory = {
+ // New overwrites old!
+ Theory(overwriteName.getOrElse(name),
+ overwriteName.getOrElse(coreName),
+ vertexTypes ++ that.vertexTypes,
+ edgeTypes ++ that.edgeTypes,
+ that.defaultVertexType, that.defaultEdgeType)
+ }
+
+ def mixin(
+ newVertexTypes: Map[String, Theory.VertexDesc] = Map(),
+ newEdgeTypes: Map[String, Theory.EdgeDesc] = Map(),
+ newName: Option[String] = None): Theory = {
+ // Overwrites existing types of the same name
+ Theory(newName.getOrElse(name),
+ newName.getOrElse(coreName),
+ vertexTypes ++ newVertexTypes,
+ edgeTypes ++ newEdgeTypes,
+ defaultVertexType, defaultEdgeType)
+ }
+
+ def copy(name: String = name,
+ coreName: String = coreName,
+ vertexTypes: Map[String, Theory.VertexDesc] = vertexTypes,
+ edgeTypes: Map[String, Theory.EdgeDesc] = edgeTypes,
+ defaultVertexType: String = defaultVertexType,
+ defaultEdgeType: String = defaultEdgeType): Theory = Theory(
+ name, coreName, vertexTypes, edgeTypes, defaultVertexType, defaultEdgeType
+ )
+
+ override def toString: String = coreName
}
/**
- * Companion object for the Theory class. Contains useful methods for
- * converting a Theory to/from JSON object
- *
- * @author Aleks Kissinger
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/Theory.scala Source code]]
- */
+ * Companion object for the Theory class. Contains useful methods for
+ * converting a Theory to/from JSON object
+ *
+ * @author Aleks Kissinger
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/Theory.scala Source code]]
+ */
object Theory {
- private implicit def jsonToColor(json: Json) = json match {
- case JsonArray(Vector(r,g,b,a)) => new Color(r.floatValue,g.floatValue,b.floatValue,a.floatValue)
- case JsonArray(Vector(r,g,b)) => new Color(r.floatValue,g.floatValue,b.floatValue)
- case _ => throw new JsonParseException("Expected array of 3 or 4 doubles, got: " + json)
- }
- private implicit def jsonOptToColorOpt(jopt: Option[Json]) = jopt.map(jsonToColor(_))
- private implicit def colorToJson(c: Color) =
+ private implicit def jsonToColor(json: Json): Color = json match {
+ case JsonArray(Vector(r, g, b, a)) => new Color(r.floatValue, g.floatValue, b.floatValue, a.floatValue)
+ case JsonArray(Vector(r, g, b)) => new Color(r.floatValue, g.floatValue, b.floatValue)
+ case _ => throw new JsonParseException("Expected array of 3 or 4 doubles, got: " + json)
+ }
+
+ private implicit def jsonOptToColorOpt(jopt: Option[Json]): Option[Color] = jopt.map(jsonToColor)
+
+ private implicit def colorToJson(c: Color): JsonArray =
c.getRGBComponents(null) match {
- case Array(r,g,b,1.0) => JsonArray(r,g,b)
- case Array(r,g,b,a) => JsonArray(r,g,b,a)
+ case Array(r, g, b, 1.0) => JsonArray(r, g, b)
+ case Array(r, g, b, a) => JsonArray(r, g, b, a)
+ }
+
+ type ValueType = ValueType.Value
+ type VertexShape = VertexShape.Value
+ type VertexLabelPosition = VertexLabelPosition.Value
+ type EdgeLabelPosition = EdgeLabelPosition.Value
+ val PlainEdgeDesc = EdgeDesc(
+ value = ValueDesc(typ = Vector(ValueType.Empty)),
+ style = EdgeStyleDesc(),
+ defaultData = JsonObject("type" -> "plain")
+ )
+ val DefaultTheory = Theory(
+ name = "String theory",
+ coreName = "string_theory",
+ vertexTypes = Map(
+ "string" -> VertexDesc(
+ value = ValueDesc(
+ typ = Vector(ValueType.String)
+ ),
+ style = VertexStyleDesc(
+ shape = VertexShape.Rectangle,
+ labelPosition = VertexLabelPosition.Inside
+ ),
+ defaultData = JsonObject("type" -> "string", "value" -> "")
+ )
+ ),
+ defaultVertexType = "string"
+ )
+
+ /**
+ * Same as '''fromJson(json : Json)''', but tries to parse a string to a Json object first
+ *
+ * @throws TheoryLoadException Exception thrown when theory cannot be created for some reason
+ */
+ def fromJson(s: String): Theory =
+ try {
+ fromJson(Json.parse(s))
+ }
+ catch {
+ case e: JsonParseException => throw new TheoryLoadException("Error parsing JSON", e)
+ }
+
+ /** Convert the theory to a JSON object */
+ def toJson(thy: Theory): Json = JsonObject(
+ "name" -> thy.name,
+ "core_name" -> thy.coreName,
+ "vertex_types" -> JsonObject(thy.vertexTypes.mapValues(x => x: Json)),
+ "edge_types" -> JsonObject(thy.edgeTypes.mapValues(x => x: Json)),
+ "default_vertex_type" -> thy.defaultVertexType,
+ "default_edge_type" -> thy.defaultEdgeType
+ ).noEmpty
+
+ /**
+ * Load a built-in theory from JSON file
+ *
+ * @param theoryFile name of the .qtheory file, without extension
+ * @return a theory object
+ */
+ def fromFile(theoryFile: String): Theory = {
+ Theory.fromJson(Json.parse(
+ new Json.Input(Theory.getClass.getResourceAsStream(theoryFile + ".qtheory"))))
+ }
+
+ /**
+ * Create a theory instance from a Json object
+ */
+ def fromJson(json: Json): Theory = {
+ try {
+ val name = (json / "name").stringValue
+ val coreName = (json / "core_name").stringValue
+ val vertexTypes = (json / "vertex_types").asObject.mapValues(x => x: VertexDesc)
+ val defaultVertexType = json / "default_vertex_type"
+ if (!vertexTypes.contains(defaultVertexType))
+ throw new TheoryLoadException("Default vertex type: " + defaultVertexType + " not in list.")
+
+ val edgeTypes = json.get("edge_types") match {
+ case Some(et) => et.asObject.mapValues(x => x: EdgeDesc)
+ case None => Map("plain" -> PlainEdgeDesc)
+ }
+ val defaultEdgeType = json.getOrElse("default_edge_type", "plain").stringValue
+
+ if (!edgeTypes.contains(defaultEdgeType))
+ throw new TheoryLoadException("Default edge type: " + defaultEdgeType + " not in list.")
+
+ Theory(
+ name,
+ coreName,
+ vertexTypes,
+ edgeTypes,
+ defaultVertexType,
+ defaultEdgeType
+ )
+ } catch {
+ case e: JsonAccessException => throw new TheoryLoadException("Error reading JSON", e)
}
+ }
+
+ case class ValueDesc(
+ typ: Vector[ValueType] = Vector(ValueType.Empty),
+ enumOptions: Vector[String] = Vector[String](),
+ latexConstants: Boolean = false,
+ validateWithCore: Boolean = false
+ )
+
+ case class VertexStyleDesc(
+ shape: VertexShape,
+ customShape: Option[Shape] = None,
+ strokeWidth: Int = 1,
+ strokeColor: Color = Color.BLACK,
+ fillColor: Color = Color.WHITE,
+ labelPosition: VertexLabelPosition = VertexLabelPosition.Center,
+ labelForegroundColor: Color = Color.BLACK,
+ labelBackgroundColor: Option[Color] = None
+ )
+
+ case class EdgeStyleDesc(
+ strokeColor: Color = Color.BLACK,
+ strokeWidth: Int = 1,
+ labelPosition: EdgeLabelPosition = EdgeLabelPosition.Auto,
+ labelForegroundColor: Color = Color.BLACK,
+ labelBackgroundColor: Option[Color] = None
+ )
+
+ case class VertexDesc(
+ value: ValueDesc,
+ style: VertexStyleDesc,
+ defaultData: JsonObject
+ )
+
+ case class EdgeDesc(
+ value: ValueDesc,
+ style: EdgeStyleDesc,
+ defaultData: JsonObject
+ )
object ValueType extends Enumeration with JsonEnumConversions {
- val String = Value("string")
- val AngleExpr = Value("angle_expr")
- val LongString = Value("long_string")
- val Enum = Value("enum")
- val Empty = Value("empty")
+ val String: ValueType = Value("string")
+ val AngleExpr: ValueType = Value("angle_expr")
+ val Boolean: ValueType = Value("boolean")
+ val Rational: ValueType = Value("rational")
+ val Integer: ValueType = Value("integer")
+ val Long: ValueType = Value("long")
+ val Enum: ValueType = Value("enum")
+ val Empty: ValueType = Value("empty")
}
- type ValueType = ValueType.Value
object VertexShape extends Enumeration with JsonEnumConversions {
- val Circle = Value("circle")
- val Rectangle = Value("rectangle")
- val Custom = Value("custom")
+ val Circle: VertexShape = Value("circle")
+ val Rectangle: VertexShape = Value("rectangle")
+ val Custom: VertexShape = Value("custom")
+
+
+ def fromName(name: String): Option[VertexShape] = this.values.find(v => v.toString == name)
}
- type VertexShape = VertexShape.Value
object VertexLabelPosition extends Enumeration with JsonEnumConversions {
- val Center = Value("center")
- val Inside = Value("inside")
- val Below = Value("below")
+ val Center: VertexLabelPosition = Value("center")
+ val Inside: VertexLabelPosition = Value("inside")
+ val Below: VertexLabelPosition = Value("below")
+
+ def fromName(name: String): Option[VertexLabelPosition] = this.values.find(v => v.toString == name)
}
- type VertexLabelPosition = VertexLabelPosition.Value
object EdgeLabelPosition extends Enumeration with JsonEnumConversions {
- val Center = Value("center")
- val Auto = Value("auto")
+ val Center: EdgeLabelPosition = Value("center")
+ val Auto: EdgeLabelPosition = Value("auto")
}
- type EdgeLabelPosition = EdgeLabelPosition.Value
- case class ValueDesc(
- typ: ValueType = ValueType.Empty,
- enumOptions: Vector[String] = Vector[String](),
- latexConstants: Boolean = false,
- validateWithCore: Boolean = false
- )
object ValueDesc {
implicit def fromJson(json: Json): ValueDesc =
ValueDesc(
- typ = json / "type",
+ typ = CompositeExpression.parseTypes(json / "type"),
enumOptions = (json ? "enum_options").vectorValue.map(_.stringValue),
latexConstants = json.getOrElse("latex_constants", false),
validateWithCore = json.getOrElse("validate_with_core", false)
@@ -101,27 +266,19 @@ object Theory {
implicit def toJson(v: ValueDesc): JsonObject =
JsonObject(
- "type" -> v.typ,
+ "type" -> v.typ.mkString(","),
"enum_options" -> v.enumOptions,
"latex_constants" -> v.latexConstants,
"validate_with_core" -> v.validateWithCore
).noEmpty
}
- case class VertexStyleDesc(
- shape: VertexShape,
- customShape: Option[Shape] = None,
- strokeColor: Color = Color.BLACK,
- fillColor: Color = Color.WHITE,
- labelPosition: VertexLabelPosition = VertexLabelPosition.Center,
- labelForegroundColor: Color = Color.BLACK,
- labelBackgroundColor: Option[Color] = None
- )
// TODO: implement custom shapes
object VertexStyleDesc {
- implicit def fromJson(json: Json) = VertexStyleDesc(
- shape = (json / "shape"),
+ implicit def fromJson(json: Json): VertexStyleDesc = VertexStyleDesc(
+ shape = json / "shape",
customShape = None,
+ strokeWidth = json.getOrElse("stroke_width", 1),
strokeColor = json.getOrElse("stroke_color", Color.BLACK),
fillColor = json.getOrElse("fill_color", Color.WHITE),
labelPosition = (json ? "label").getOrElse("position", VertexLabelPosition.Center),
@@ -129,30 +286,23 @@ object Theory {
labelBackgroundColor = (json ? "label").get("bg_color")
)
- implicit def toJson(v: VertexStyleDesc) =
+ implicit def toJson(v: VertexStyleDesc): JsonObject =
JsonObject(
"shape" -> v.shape,
"custom_shape" -> JsonNull,
+ "stroke_width" -> v.strokeWidth,
"stroke_color" -> v.strokeColor,
"fill_color" -> v.fillColor,
"label" -> JsonObject(
"position" -> v.labelPosition,
"fg_color" -> v.labelForegroundColor,
- "bg_color" -> v.labelBackgroundColor.map(x=>x:Json).getOrElse(JsonNull)
+ "bg_color" -> v.labelBackgroundColor.map(x => x: Json).getOrElse(JsonNull)
).noEmpty
).noEmpty
}
- case class EdgeStyleDesc(
- strokeColor: Color = Color.BLACK,
- strokeWidth: Int = 1,
- labelPosition: EdgeLabelPosition = EdgeLabelPosition.Auto,
- labelForegroundColor: Color = Color.BLACK,
- labelBackgroundColor: Option[Color] = None
- )
-
object EdgeStyleDesc {
- implicit def fromJson(json: Json) = EdgeStyleDesc(
+ implicit def fromJson(json: Json): EdgeStyleDesc = EdgeStyleDesc(
strokeColor = json.getOrElse("stroke_color", Color.BLACK),
strokeWidth = json.getOrElse("stroke_width", 1),
labelPosition = (json ? "label").getOrElse("position", EdgeLabelPosition.Auto),
@@ -160,139 +310,43 @@ object Theory {
labelBackgroundColor = (json ? "label").get("bg_color")
)
- implicit def toJson(v: EdgeStyleDesc) =
+ implicit def toJson(v: EdgeStyleDesc): JsonObject =
JsonObject(
"stroke_color" -> v.strokeColor,
"stroke_width" -> v.strokeWidth,
"label" -> JsonObject(
"position" -> v.labelPosition,
"fg_color" -> v.labelForegroundColor,
- "bg_color" -> v.labelBackgroundColor.map(x=>x:Json).getOrElse(JsonNull)
+ "bg_color" -> v.labelBackgroundColor.map(x => x: Json).getOrElse(JsonNull)
).noEmpty
).noEmpty
}
- case class VertexDesc(
- value: ValueDesc,
- style: VertexStyleDesc,
- defaultData: JsonObject
- )
object VertexDesc {
- implicit def fromJson(json: Json) = VertexDesc(
+ implicit def fromJson(json: Json): VertexDesc = VertexDesc(
value = json / "value",
style = json / "style",
- defaultData = (json / "default_data").asObject
- )
- implicit def toJson(v: VertexDesc) = JsonObject(
+ defaultData = json.getOrElse("default_data", "").asObject)
+
+ implicit def toJson(v: VertexDesc): JsonObject = JsonObject(
"value" -> v.value,
"style" -> v.style,
"default_data" -> v.defaultData
)
}
- case class EdgeDesc(
- value: ValueDesc,
- style: EdgeStyleDesc,
- defaultData: JsonObject
- )
object EdgeDesc {
- implicit def fromJson(json: Json) = EdgeDesc(
- value = (json / "value"),
- style = (json / "style"),
+ implicit def fromJson(json: Json): EdgeDesc = EdgeDesc(
+ value = json / "value",
+ style = json / "style",
defaultData = (json / "default_data").asObject
)
- implicit def toJson(v: EdgeDesc) = JsonObject(
+
+ implicit def toJson(v: EdgeDesc): JsonObject = JsonObject(
"value" -> v.value,
"style" -> v.style,
"default_data" -> v.defaultData
)
}
- /**
- * Same as '''fromJson(json : Json)''', but tries to parse a string to a Json object first
- * @throws TheoryLoadException
- */
- def fromJson(s: String): Theory =
- try { fromJson(Json.parse(s)) }
- catch { case e:JsonParseException => throw new TheoryLoadException("Error parsing JSON", e) }
-
- /**
- * Create a theory instance from a Json object
- * @throws TheoryLoadException
- */
- def fromJson(json: Json): Theory = {
- try {
- val name = (json / "name").stringValue
- val coreName = (json / "core_name").stringValue
- val vertexTypes = (json / "vertex_types").asObject.mapValues(x => x:VertexDesc)
- val defaultVertexType = (json / "default_vertex_type")
- if (!vertexTypes.contains(defaultVertexType))
- throw new TheoryLoadException("Default vertex type: " + defaultVertexType + " not in list.")
-
- val edgeTypes = json.get("edge_types") match {
- case Some(et) => et.asObject.mapValues(x => x:EdgeDesc)
- case None => Map("plain"->PlainEdgeDesc)
- }
- val defaultEdgeType = json.getOrElse("default_edge_type", "plain").stringValue
-
- if (!edgeTypes.contains(defaultEdgeType))
- throw new TheoryLoadException("Default edge type: " + defaultEdgeType + " not in list.")
-
- Theory(
- name,
- coreName,
- vertexTypes,
- edgeTypes,
- defaultVertexType,
- defaultEdgeType
- )
- } catch {
- case e: JsonAccessException => throw new TheoryLoadException("Error reading JSON", e)
- }
- }
-
- /** Convert the theory to a JSON object */
- def toJson(thy: Theory): Json = JsonObject(
- "name" -> thy.name,
- "core_name" -> thy.coreName,
- "vertex_types" -> JsonObject(thy.vertexTypes.mapValues(x => x:Json)),
- "edge_types" -> JsonObject(thy.edgeTypes.mapValues(x => x:Json)),
- "default_vertex_type" -> thy.defaultVertexType,
- "default_edge_type" -> thy.defaultEdgeType
- ).noEmpty
-
- val PlainEdgeDesc = EdgeDesc(
- value = ValueDesc(typ = ValueType.Empty),
- style = EdgeStyleDesc(),
- defaultData = JsonObject("type" -> "plain")
- )
-
- val DefaultTheory = Theory(
- name = "String theory",
- coreName = "string_theory",
- vertexTypes = Map(
- "string" -> VertexDesc(
- value = ValueDesc(
- typ = ValueType.String
- ),
- style = VertexStyleDesc(
- shape = VertexShape.Rectangle,
- labelPosition = VertexLabelPosition.Inside
- ),
- defaultData = JsonObject("type" -> "string", "value" -> "")
- )
- ),
- defaultVertexType = "string"
- )
-
- /**
- * Load a built-in theory from JSON file
- *
- * @param theoryFile name of the .qtheory file, without extension
- * @return a theory object
- */
- def fromFile(theoryFile: String): Theory = {
- Theory.fromJson(Json.parse(
- new Json.Input(Theory.getClass.getResourceAsStream(theoryFile + ".qtheory"))))
- }
}
diff --git a/scala/src/main/scala/quanto/data/VData.scala b/scala/src/main/scala/quanto/data/VData.scala
index 7d825298..d31b860e 100644
--- a/scala/src/main/scala/quanto/data/VData.scala
+++ b/scala/src/main/scala/quanto/data/VData.scala
@@ -1,125 +1,136 @@
package quanto.data
-import quanto.data.Theory.ValueType
import quanto.util.json._
/**
- * An abstract class which provides a general interface for accessing
- * vertex data
- *
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- * @author Aleks Kissinger
- */
+ * An abstract class which provides a general interface for accessing
+ * vertex data
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ * @author Aleks Kissinger
+ */
abstract class VData extends GraphElementData {
- def annotation : JsonObject
+ def annotation: JsonObject
/**
- * Get coordinates of vertex
- * @throws JsonAccessException
- * @return actual coordinates of vertex or (0,0) if none are specified
- */
+ * Get coordinates of vertex
+ *
+ * @return actual coordinates of vertex or (0,0) if none are specified
+ */
def coord: (Double, Double) = annotation.get("coord") match {
- case Some(JsonArray(Vector(x,y))) => (x.doubleValue, y.doubleValue)
+ case Some(JsonArray(Vector(x, y))) => (x.doubleValue, y.doubleValue)
case Some(otherJson) => throw new JsonAccessException("Expected: array with 2 elements", otherJson)
- case None => (0,0)
+ case None => (0, 0)
}
/** Create a copy of the current vertex with the new coordinates */
- def withCoord(c: (Double,Double)): VData
+ def withCoord(c: (Double, Double)): VData
+
def typ: String
def isWireVertex: Boolean
- def isBoundary : Boolean
+
+ def isBoundary: Boolean
}
/**
- * Companion object for the VData class. Contains a method getCoord which has
- * the same behaviour as VData.coord, but is static.
- *
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- * @author Aleks Kissinger
- */
+ * Companion object for the VData class. Contains a method getCoord which has
+ * the same behaviour as VData.coord, but is static.
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ * @author Aleks Kissinger
+ */
object VData {
- def getCoord(annotation: Json): (Double,Double) = annotation.get("coord") match {
- case Some(JsonArray(Vector(x,y))) => (x.doubleValue, y.doubleValue)
+ def getCoord(annotation: Json): (Double, Double) = annotation.get("coord") match {
+ case Some(JsonArray(Vector(x, y))) => (x.doubleValue, y.doubleValue)
case Some(otherJson) => throw new JsonAccessException("Expected: array with 2 elements", otherJson)
- case None => (0,0)
+ case None => (0, 0)
}
}
/**
- * A class which represents node vertex data.
- *
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- */
+ * A class which represents node vertex data.
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ */
case class NodeV(
- data: JsonObject = Theory.DefaultTheory.defaultVertexData,
- annotation: JsonObject = JsonObject(),
- theory: Theory = Theory.DefaultTheory) extends VData
-{
+ data: JsonObject = Theory.DefaultTheory.defaultVertexData,
+ annotation: JsonObject = JsonObject(),
+ theory: Theory = Theory.DefaultTheory) extends VData {
/** Type of the vertex */
- val typ = (data / "type").stringValue
-
-// def label = data.getOrElse("label","").stringValue
- def typeInfo = theory.vertexTypes(typ)
-
+ val typ: String = (data / "type").stringValue
// support input of old-style graphs, where data may be stored at value/pretty
val value: String = data ? "value" match {
- case str : JsonString => str.stringValue
- case obj : JsonObject => obj.getOrElse("pretty", JsonString("")).stringValue
+ case str: JsonString => str.stringValue
+ case obj: JsonObject => obj.getOrElse("pretty", JsonString("")).stringValue
case _ => ""
}
+ def newValue(value: String): NodeV = NodeV(data = JsonObject(
+ "type" -> typ,
+ "value" -> value
+ ), annotation = annotation, theory = theory)
+
+ // if the theory says this node should have a value, try to parse it,
+ // and store it in "phaseData". If it should have a value, but parsing fails, set
+ // it to empty.
+ lazy val (phaseData: CompositeExpression, hasValue: Boolean) =
+ try {
+ val phaseTypes = theory.vertexTypes(typ).value.typ
+ val phaseValues = CompositeExpression.parseKnowingTypes(value, phaseTypes)
+ (CompositeExpression(phaseTypes, phaseValues), true)
+ }
+ catch {
+ // YOU WILL END UP HERE IF THE PARSER WAS HANDED AN UNFAMILIAR VALUETYPE
+ // See uses of PhaseParseException to pinpoint where
+ case _: PhaseParseException => (CompositeExpression(Vector(), Vector()), false)
+ }
+
+ // def label = data.getOrElse("label","").stringValue
+ def typeInfo = theory.vertexTypes(typ)
- // if the theory says this node should have an angle, try to parse it from value,
- // and store it in "angle". If it should have an angle, but parsing fails, set
- // angle to "0".
- val (angle: AngleExpression, hasAngle: Boolean) =
- if (theory.vertexTypes(typ).value.typ == ValueType.AngleExpr)
- try { (AngleExpression.parse(value), true) }
- catch { case _: AngleParseException => (AngleExpression(), true) }
- else (AngleExpression(), false)
+ def withCoord(c: (Double, Double)): NodeV =
+ copy(annotation = annotation + ("coord" -> JsonArray(c._1, c._2)))
- def withCoord(c: (Double,Double)) =
- copy(annotation = annotation + ("coord" -> JsonArray(c._1, c._2)))
+ /** Create a copy of the current vertex with the new value */
+ def withValue(s: String): NodeV =
+ copy(data = data.setPath("$.value", s).asObject)
- /** Create a copy of the current vertex with the new value */
- def withValue(s: String) =
- copy(data = data.setPath("$.value", s).asObject)
+ def withTyp(s: String): NodeV =
+ copy(data = data.setPath("$.type", s).asObject)
- def withTyp(s: String) =
- copy(data = data.setPath("$.type", s).asObject)
+ def isWireVertex = false
- def isWireVertex = false
- def isBoundary = false
+ def isBoundary = false
- override def toJson =
+ override def toJson: JsonObject =
if (data == theory.defaultVertexData)
JsonObject("annotation" -> annotation).noEmpty
else
JsonObject(
"data" -> data,
"annotation" -> annotation).noEmpty
- }
+}
/**
- * Companion object for the NodeV class. Contains methods to convert to/from
- * JSON and a factory method to create instances of NodeV from a pair of
- * coordinates.
- *
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- * @author Aleks Kissinger
- */
+ * Companion object for the NodeV class. Contains methods to convert to/from
+ * JSON and a factory method to create instances of NodeV from a pair of
+ * coordinates.
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ * @author Aleks Kissinger
+ */
object NodeV {
- def apply(coord: (Double,Double)): NodeV = NodeV(annotation = JsonObject("coord" -> JsonArray(coord._1,coord._2)))
+ def apply(coord: (Double, Double)): NodeV = NodeV(annotation = JsonObject("coord" -> JsonArray(coord._1, coord._2)))
- def toJson(d: NodeV, theory: Theory) = JsonObject(
+ def toJson(d: NodeV, theory: Theory): JsonObject = JsonObject(
"data" -> (if (d.data == theory.vertexTypes(d.typ).defaultData) JsonNull else d.data),
"annotation" -> d.annotation).noEmpty
+
def fromJson(json: Json, thy: Theory = Theory.DefaultTheory): NodeV = {
val data = json.getOrElse("data", thy.defaultVertexData).asObject
val annotation = (json ? "annotation").asObject
@@ -133,36 +144,45 @@ object NodeV {
}
/**
- * A class which represents wire vertex data
- *
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- */
+ * A class which represents wire vertex data
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ */
case class WireV(
- data: JsonObject = JsonObject(),
- annotation: JsonObject = JsonObject(),
- theory: Theory = Theory.DefaultTheory) extends VData
-{
+ data: JsonObject = JsonObject(),
+ annotation: JsonObject = JsonObject(),
+ theory: Theory = Theory.DefaultTheory) extends VData {
def typ = "wire"
+
def isWireVertex = true
- def isBoundary = annotation.get("boundary") match { case Some(JsonBool(b)) => b; case _ => false }
- def withCoord(c: (Double,Double)) =
+
+ def isBoundary: Boolean = annotation.get("boundary") match {
+ case Some(JsonBool(b)) => b;
+ case _ => false
+ }
+
+ def withCoord(c: (Double, Double)): WireV =
copy(annotation = annotation + ("coord" -> JsonArray(c._1, c._2)))
+
+ def makeBoundary(b: Boolean): WireV =
+ copy(annotation = annotation + ("boundary" -> JsonBool(b)))
}
/**
- * A companion object for the WireV class. Contains methods to convert to/from
- * JSON and a factory method to create instances of WireV from a pair of
- * coordinates
- *
- * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]]
- * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
- */
+ * A companion object for the WireV class. Contains methods to convert to/from
+ * JSON and a factory method to create instances of WireV from a pair of
+ * coordinates
+ *
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]]
+ * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]]
+ */
object WireV {
- def apply(c: (Double,Double)): WireV = WireV(annotation = JsonObject("coord" -> JsonArray(c._1,c._2)))
+ def apply(c: (Double, Double)): WireV = WireV(annotation = JsonObject("coord" -> JsonArray(c._1, c._2)))
- def toJson(d: NodeV, theory: Theory) = JsonObject(
+ def toJson(d: NodeV, theory: Theory): JsonObject = JsonObject(
"data" -> d.data, "annotation" -> d.annotation).noEmpty
+
def fromJson(json: Json, thy: Theory = Theory.DefaultTheory): WireV =
WireV((json ? "data").asObject, (json ? "annotation").asObject, thy)
}
diff --git a/scala/src/main/scala/quanto/gui/AddRuleDialog.scala b/scala/src/main/scala/quanto/gui/AddRuleDialog.scala
index da3a989b..5ac80839 100644
--- a/scala/src/main/scala/quanto/gui/AddRuleDialog.scala
+++ b/scala/src/main/scala/quanto/gui/AddRuleDialog.scala
@@ -8,82 +8,73 @@ import quanto.util.Globals
import scala.util.matching
import scala.util.matching.Regex
+import quanto.util.UserOptions.{scale, scaleInt}
class AddRuleDialog(project: Project) extends Dialog {
modal = true
- implicit def buttonIsSelected(radButton: RadioButton) : Boolean = radButton.selected
-
+ implicit def buttonIsSelected(radButton: RadioButton): Boolean = radButton.selected
+ var cancelled = false
def result: Seq[RuleDesc] = {
- val unsorted = if (MainPanel.radIncludeForwards) {
- MainPanel.FilteredRuleList.selection.items.map(s => RuleDesc(s))
- } else if (MainPanel.radIncludeInverse) {
- MainPanel.FilteredRuleList.selection.items.map(s => RuleDesc(s, inverse = true))
+ if(!cancelled) {
+ val unsorted = if (MainPanel.radIncludeForwards) {
+ MainPanel.selection.map(s => RuleDesc(s))
+ } else if (MainPanel.radIncludeInverse) {
+ MainPanel.selection.map(s => RuleDesc(s, inverse = true))
+ } else {
+ MainPanel.selection.flatMap(s => Seq(RuleDesc(s), RuleDesc(s, inverse = true)))
+ }
+ unsorted.sortBy(rd => rd.name)
} else {
- MainPanel.FilteredRuleList.selection.items.flatMap(s => Seq(RuleDesc(s), RuleDesc(s, inverse = true)))
+ Seq()
}
- unsorted.sortBy(rd => rd.name)
}
val AddButton = new Button("Add")
val CancelButton = new Button("Cancel")
defaultButton = Some(AddButton)
-// val dir = Files.newDirectoryStream(Paths.get(rootDir), "**/*.qrule")
-// for (p <- dir.asScala) println(p)
+ // val dir = Files.newDirectoryStream(Paths.get(rootDir), "**/*.qrule")
+ // for (p <- dir.asScala) println(p)
- val MainPanel = new BoxPanel(Orientation.Vertical) {
- val Search = new TextField
- val InitialRules : Vector[String] = project.rules.sorted
- var FilteredRuleList : ListView[String] = new ListView[String](InitialRules)
-
- val RulePane = new ScrollPane(FilteredRuleList)
+ val MainPanel = new BoxPanel(Orientation.Vertical) {
+ val FList = new FilteredList(project.filesEndingIn(".qrule"))
+ contents += FList
+ def selection : List[String] = FList.ListComponent.selection.items.toList
val radIncludeForwards = new RadioButton("Forwards")
- radIncludeForwards.selected = true
val radIncludeInverse = new RadioButton("Inverted")
var radIncludeInverseAndForwards= new RadioButton("Both")
+ radIncludeInverseAndForwards.selected = true
val radGroupIncludeInverse = new ButtonGroup(radIncludeForwards, radIncludeInverse, radIncludeInverseAndForwards)
- RulePane.preferredSize = new Dimension(400,200)
- contents += Swing.VStrut(10)
- contents += new BoxPanel(Orientation.Horizontal) {
- contents += (Swing.HStrut(10), new Label("Filter:"), Swing.HStrut(5), Search, Swing.HStrut(10))
- }
contents += Swing.VStrut(5)
+
contents += new BoxPanel(Orientation.Horizontal) {
- contents += (Swing.HStrut(10), RulePane, Swing.HStrut(10))
+ contents += (AddButton, Swing.HStrut(5), CancelButton)
}
+
contents += new BoxPanel(Orientation.Horizontal) {
contents += (Swing.HStrut(10), radIncludeForwards, Swing.HStrut(10))
contents += (Swing.HStrut(10), radIncludeInverse, Swing.HStrut(10))
contents += (Swing.HStrut(10), radIncludeInverseAndForwards, Swing.HStrut(10))
}
- contents += Swing.VStrut(5)
- contents += new BoxPanel(Orientation.Horizontal) {
- contents += (AddButton, Swing.HStrut(5), CancelButton)
- }
contents += Swing.VStrut(10)
+
}
+
contents = MainPanel
- listenTo(AddButton, CancelButton, MainPanel.Search)
+ listenTo(AddButton, CancelButton)
reactions += {
case ButtonClicked(AddButton) =>
+ cancelled = false
close()
case ButtonClicked(CancelButton) =>
- MainPanel.FilteredRuleList.selection.indices.clear()
+ cancelled = true
close()
- case ValueChanged(MainPanel.Search) =>
- try {
- MainPanel.FilteredRuleList.listData = MainPanel.InitialRules.filter(
- s => s.matches(".*" + MainPanel.Search.text + ".*"))
- } catch {
- case e: Exception =>
- //Exceptions here are thrown by inelligable regex from the user
- }
}
}
diff --git a/scala/src/main/scala/quanto/gui/BatchDerivationCreationDocument.scala b/scala/src/main/scala/quanto/gui/BatchDerivationCreationDocument.scala
new file mode 100644
index 00000000..70ecd782
--- /dev/null
+++ b/scala/src/main/scala/quanto/gui/BatchDerivationCreationDocument.scala
@@ -0,0 +1,35 @@
+package quanto.gui
+
+import java.io.File
+
+import quanto.data.{Project, Theory}
+import quanto.util.UserAlerts
+import quanto.util.json.Json
+
+import scala.swing.{Component, Publisher}
+
+//
+// This file is intentionally bare because this panel will spawn jobs, rather than display files
+//
+class BatchDerivationCreationDocument(val parent: Component) extends Document with Publisher {
+ val description = "Batch Derivation"
+ val fileExtension = ""
+
+
+ protected def clearDocument() {
+ }
+
+ protected def saveDocument(f: File) {
+ }
+
+ override def loadDocument(f: File) {
+ }
+
+ override def titleDescription: String = "Batch Derivation"
+
+ override def unsavedChanges: Boolean = false
+
+ override protected def exportDocument(f: File) {
+ }
+
+}
diff --git a/scala/src/main/scala/quanto/gui/BatchDerivationCreatorPanel.scala b/scala/src/main/scala/quanto/gui/BatchDerivationCreatorPanel.scala
new file mode 100644
index 00000000..bcd9e6ba
--- /dev/null
+++ b/scala/src/main/scala/quanto/gui/BatchDerivationCreatorPanel.scala
@@ -0,0 +1,174 @@
+package quanto.gui
+
+import java.awt.BorderLayout
+import java.awt.event.{KeyAdapter, KeyEvent}
+import java.io.File
+import java.util.Calendar
+import java.util.concurrent.TimeUnit
+
+import org.gjt.sp.jedit.{Mode, Registers}
+import org.gjt.sp.jedit.textarea.StandaloneTextArea
+import quanto.rewrite.Simproc
+import quanto.util.UserOptions.scaleInt
+import quanto.util.swing.ToolBar
+import quanto.data.Graph
+
+import scala.swing.event.{ButtonClicked, Event, SelectionChanged}
+import quanto.cosy.SimprocBatch
+import quanto.util.{FileHelper, UserOptions}
+
+import scala.swing.{BorderPanel, BoxPanel, Button, Component, Dimension, GridPanel, Label, Orientation, Publisher, ScrollPane, Swing, TextArea}
+
+case class HaltBatchProcessesEvent() extends Event
+
+//
+// This panel will create batch jobs
+// The corresponding document for HasDocument is essentially empty
+
+class BatchDerivationCreatorPanel extends BorderPanel with HasDocument with Publisher {
+ val document = new BatchDerivationCreationDocument(this)
+ val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask
+ val Toolbar = new ToolBar {
+ //contents
+ }
+ val GraphList = new FilteredList(graphs)
+ val LabelNumSimprocs = new Label()
+ val LabelNumGraphs = new Label()
+ val LabelNumPairs = new Label()
+ val LabelTimeLimit = new Label()
+ val NotesTextBox: TextEditor = new TextEditor(TextEditor.Modes.blank)
+
+
+ val StartButton = new Button("Start")
+ val HaltButton = new Button("Halt ongoing")
+ val BatchDetails: BoxPanel = new BoxPanel(Orientation.Vertical) {
+ contents += VSpace
+ contents += new GridPanel(4, 2) {
+
+ contents += new Label("Simprocs:")
+ contents += LabelNumSimprocs
+ contents += new Label("Graphs:")
+ contents += LabelNumGraphs
+ contents += new Label("Total pairs:")
+ contents += LabelNumPairs
+ contents += new Label("Time limit:")
+ contents += LabelTimeLimit
+
+ maximumSize = new Dimension(scaleInt(200), scaleInt(400))
+ }
+ }
+ val GraphChooser: Component = new BoxPanel(Orientation.Vertical) {
+
+ contents += VSpace
+ contents += header("Graphs to include:")
+ contents += VSpace
+ contents += GraphList
+ }
+ val IgnitionButtons: BoxPanel = new BoxPanel(Orientation.Horizontal) {
+ contents += (HSpace, StartButton, HSpace, HaltButton, HSpace)
+ }
+ val NotesHolder : BoxPanel = new BoxPanel(Orientation.Vertical) {
+ contents += VSpace
+ contents += new Label("Notes:")
+ contents += VSpace
+ contents += NotesTextBox.Component
+ NotesTextBox.Component.maximumSize = new Dimension(scaleInt(600), scaleInt(300))
+ }
+ def MainPanel: BoxPanel = new BoxPanel(Orientation.Vertical) {
+ contents += SimprocChooser
+ contents += GraphChooser
+ contents += NotesHolder
+ contents += BatchDetails
+
+ contents += VSpace
+ contents += IgnitionButtons
+ contents += VSpace
+ }
+ //SimprocList is a var not a val, because it can then be destroyed and recreated when the simprocs in memory change
+ var SimprocList = new FilteredList(simprocs.keys.toList)
+
+ val SimprocChooser: BoxPanel = new BoxPanel(Orientation.Vertical) {
+ // Made def not val because it references the var SimprocList
+ contents += VSpace
+ contents += header("Simprocs to include:")
+ contents += VSpace
+ contents += SimprocList
+ }
+
+ def header(title: String): BoxPanel = new BoxPanel(Orientation.Horizontal) {
+ contents += (HSpace, new Label(title), HSpace)
+ }
+
+ private def HSpace: Component = Swing.HStrut(scaleInt(10))
+
+ private def VSpace: Component = Swing.VStrut(scaleInt(10))
+
+ def TopScrollablePane = new ScrollPane(MainPanel)
+
+ def simprocs: Map[String, Simproc] = QuantoDerive.CurrentProject.map(p => p.simprocs).getOrElse(Map())
+
+ def graphs: List[String] = QuantoDerive.CurrentProject.map(p => p.filesEndingIn(".qgraph")).getOrElse(List())
+
+ def loadGraph(name: String): Graph = {
+ val root: String = QuantoDerive.CurrentProject.map(p => p.rootFolder).getOrElse("")
+ val theory = QuantoDerive.CurrentProject.map(p => p.theory).get
+ FileHelper.readFile[Graph](new File(root + "/" + name + ".qgraph"), json => Graph.fromJson(json, theory))
+ }
+
+ def refreshDataDisplay() {
+ def numSimprocs: Int = simprocSelection.size
+
+ def numGraphs: Int = graphSelection.size
+
+ LabelNumSimprocs.text = numSimprocs.toString
+ LabelNumGraphs.text = numGraphs.toString
+ LabelNumPairs.text = (numGraphs * numSimprocs).toString
+ val totalMilliseconds = numSimprocs * numGraphs * SimprocBatch.timeout
+ LabelTimeLimit.text = s"${TimeUnit.MILLISECONDS.toHours(totalMilliseconds)} hrs ${
+ TimeUnit.MILLISECONDS.toMinutes(totalMilliseconds) -
+ TimeUnit.MINUTES.toMinutes(TimeUnit.MILLISECONDS.toHours(totalMilliseconds))
+ } mins"
+ }
+
+ listenTo(this,
+ StartButton,
+ HaltButton,
+ SimprocList.ListComponent.selection,
+ GraphList.ListComponent.selection,
+ PythonEditPanel)
+
+
+ reactions += {
+ case SimprocsUpdated() =>
+ SimprocChooser.contents -= SimprocList
+ SimprocList = new FilteredList(simprocs.keys.toList)
+ SimprocChooser.contents += SimprocList
+ refreshDataDisplay()
+ listenTo(SimprocList.ListComponent.selection)
+ case ButtonClicked(HaltButton) =>
+ BatchDerivationCreatorPanel.jobID += 1
+ case ButtonClicked(StartButton) =>
+ val batch = SimprocBatch(simprocSelection, graphSelection.map(name => loadGraph(name)), NotesTextBox.getText)
+ batch.run()
+ case SelectionChanged(_) =>
+ refreshDataDisplay()
+ }
+
+ def simprocSelection: List[String] = SimprocList.ListComponent.selection.items.toList
+
+ def graphSelection: List[String] = GraphList.ListComponent.selection.items.toList
+
+ add(TopScrollablePane, BorderPanel.Position.Center)
+
+ add(Toolbar, BorderPanel.Position.North)
+
+ refreshDataDisplay()
+
+}
+
+object BatchDerivationCreatorPanel {
+ // The job ID exists so the user can cancel any currently running batch jobs
+ // Jobs check the current job ID, and if it has increased since the job started the job will stop
+ var jobID = 0
+
+}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/gui/BatchDerivationResultsDocument.scala b/scala/src/main/scala/quanto/gui/BatchDerivationResultsDocument.scala
new file mode 100644
index 00000000..b69eaefd
--- /dev/null
+++ b/scala/src/main/scala/quanto/gui/BatchDerivationResultsDocument.scala
@@ -0,0 +1,52 @@
+package quanto.gui
+
+import java.io.File
+
+import quanto.cosy.{SimprocBatchResult, SimprocLazyBatchResult, SimprocSingleRun}
+import quanto.data.Graph
+
+import scala.swing.{Component, Publisher}
+
+// The document is read-only data generated by a batch run
+// There are two ways to access the document
+// - a standard loadDocument which lazy reads in what is essentially meta-data
+// - a full load that puts all of the results into memory
+class BatchDerivationResultsDocument(val parent: Component) extends Document with Publisher {
+ val description = "Batch Derivation Results"
+ val fileExtension = "qsbr"
+
+ var simprocsUsed: Option[List[String]] = None
+ var resultsCount: Option[Int] = None
+ var allSimprocs: Option[Map[String, String]] = None
+ var results: Option[List[SimprocSingleRun]] = None
+ var timeTaken: Option[Double] = None
+ var notes: Option[String] = None
+
+ override def loadDocument(f: File) = {
+ val lazyResults = quanto.util.FileHelper.readFile[SimprocLazyBatchResult](f, SimprocBatchResult.lazyFromJson)
+ simprocsUsed = Some(lazyResults.selectedSimprocs)
+ notes = Some(lazyResults.notes)
+ resultsCount = Some(lazyResults.resultCount)
+ publish(DocumentChanged(this))
+ }
+
+ def loadFullDocument(f: File): Unit = {
+ val batchResults = quanto.util.FileHelper.readFile[SimprocBatchResult](f, SimprocBatchResult.fromJson)
+ allSimprocs = Some(batchResults.allSimprocs)
+ results = Some(batchResults.singleResults)
+ timeTaken = Some(batchResults.singleResults.map(sr => sr.derivationTimings.map(tt => tt._2).sum).sum)
+ publish(DocumentChanged(this))
+ }
+
+ override def unsavedChanges: Boolean = false
+
+ protected def clearDocument() {
+ }
+
+ protected def saveDocument(f: File) {
+ }
+
+ override protected def exportDocument(f: File) {
+ }
+
+}
diff --git a/scala/src/main/scala/quanto/gui/BatchDerivationResultsPanel.scala b/scala/src/main/scala/quanto/gui/BatchDerivationResultsPanel.scala
new file mode 100644
index 00000000..453fd53c
--- /dev/null
+++ b/scala/src/main/scala/quanto/gui/BatchDerivationResultsPanel.scala
@@ -0,0 +1,118 @@
+package quanto.gui
+
+import java.io.File
+import java.util.Calendar
+import java.util.concurrent.TimeUnit
+
+import quanto.rewrite.Simproc
+import quanto.util.UserOptions.scaleInt
+import quanto.util.swing.ToolBar
+import quanto.data.Graph
+
+import scala.swing.event.{ButtonClicked, Event, SelectionChanged}
+import quanto.cosy.{SimprocBatch, SimprocBatchResult, SimprocSingleRun}
+import quanto.util.FileHelper
+
+import scala.swing.{BorderPanel, BoxPanel, Button, Component, Dimension, GridPanel, Label, Orientation, Publisher, ScrollPane, Swing}
+
+// The document this panel displays has two ways to access the file:
+// A lazy read, used on first pass, that just loads metadata
+// A full read that loads everything into memory
+
+class BatchDerivationResultsPanel()
+ extends BorderPanel with HasDocument with Publisher {
+
+ val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask
+ val Toolbar = new ToolBar {
+ //contents
+ }
+
+ val document: BatchDerivationResultsDocument = new BatchDerivationResultsDocument(this)
+ val LabelNumSimprocs = new Label("Calculating")
+ val LabelSimprocsUsed = new Label("Calculating")
+ val LabelNumGraphs = new Label("Calculating")
+ val LabelNumPairs = new Label("Calculating")
+ val LabelTimeTaken = new Label("Calculating")
+ val SimprocsUsed: BoxPanel = new BoxPanel(Orientation.Vertical) {
+ contents += VSpace
+ contents += new Label("Simprocs used:")
+ contents += VSpace
+ contents += LabelSimprocsUsed
+ contents += VSpace
+ }
+ val BatchDetails: BoxPanel = new BoxPanel(Orientation.Vertical) {
+ contents += VSpace
+ contents += new GridPanel(3, 2) {
+
+ contents += new Label("No. Simprocs:")
+ contents += LabelNumSimprocs
+ contents += new Label("No. Graphs:")
+ contents += LabelNumGraphs
+ contents += new Label("Total pairs:")
+ contents += LabelNumPairs
+ }
+ contents += VSpace
+ maximumSize = new Dimension(scaleInt(400), scaleInt(200))
+ }
+ val NotesTextBox = new Label(document.notes.getOrElse(""))
+ val Notes: BoxPanel = new BoxPanel(Orientation.Vertical) {
+ contents += VSpace
+ contents += new Label("Notes:")
+ contents += VSpace
+ contents += NotesTextBox
+ contents += VSpace
+ NotesTextBox.preferredSize = new Dimension(scaleInt(600), scaleInt(300))
+ }
+ val MainPanel: BoxPanel = new BoxPanel(Orientation.Vertical) {
+ contents += SimprocsUsed
+ contents += BatchDetails
+ contents += Notes
+ }
+ val TopScrollablePane = new ScrollPane(MainPanel)
+
+ def allSimprocs: Option[Map[String, String]] = document.allSimprocs
+
+ def results: Option[List[SimprocSingleRun]] = document.results
+
+ def notes: Option[String] = document.notes
+
+ def refreshData(): Unit = {
+ val numSimprocs: Int = simprocsUsed.getOrElse(List()).size
+ val numPairs: Int = resultsCount.getOrElse(0)
+ LabelNumSimprocs.text = numSimprocs.toString
+ LabelSimprocsUsed.text = simprocsUsed.getOrElse(List()).mkString("\n")
+ LabelNumPairs.text = numPairs.toString
+ LabelNumGraphs.text = (numPairs, numSimprocs) match {
+ case (0, _) => "0"
+ case (n, 0) => n.toString
+ case (a,b) => (a/b).toString
+ }
+ LabelTimeTaken.text = if (timeTaken.nonEmpty) {
+ timeTaken.get.toString
+ } else {
+ "---"
+ }
+ NotesTextBox.text = notes.getOrElse("")
+ }
+
+ listenTo(document)
+ reactions += {
+ case DocumentChanged(d) => refreshData()
+ }
+
+ def simprocsUsed: Option[List[String]] = document.simprocsUsed
+
+ def resultsCount: Option[Int] = document.resultsCount
+
+ def timeTaken: Option[Double] = document.timeTaken
+
+ private def HSpace: Component = Swing.HStrut(scaleInt(10))
+
+
+ private def VSpace: Component = Swing.VStrut(scaleInt(10))
+
+ add(TopScrollablePane, BorderPanel.Position.Center)
+
+ add(Toolbar, BorderPanel.Position.North)
+
+}
diff --git a/scala/src/main/scala/quanto/gui/ClosableTabbedPane.scala b/scala/src/main/scala/quanto/gui/ClosableTabbedPane.scala
index b8633456..4dc98438 100644
--- a/scala/src/main/scala/quanto/gui/ClosableTabbedPane.scala
+++ b/scala/src/main/scala/quanto/gui/ClosableTabbedPane.scala
@@ -1,13 +1,22 @@
package quanto.gui
+import java.awt.event.{MouseAdapter, MouseEvent}
+
import scala.swing._
import javax.swing.border.EmptyBorder
-import javax.swing.{ImageIcon, Icon}
-import java.awt.{Color, BasicStroke, RenderingHints, Graphics}
+import javax.swing.{Icon, ImageIcon, SwingUtilities}
+import java.awt.{BasicStroke, Color, Graphics, Point, RenderingHints}
+
+import quanto.gui.QuantoDerive.popup
+import quanto.util.{UserAlerts, UserOptions}
+
+import scala.swing.event._
class ClosablePage(title0: String, component0: Component, val closeAction: () => Boolean)
-extends TabbedPane.Page(title0, component0) {
- lazy val tabComponent : ClosablePage.TabComponent = { new ClosablePage.TabComponent(this) }
+ extends TabbedPane.Page(title0, component0) {
+ lazy val tabComponent: ClosablePage.TabComponent = {
+ new ClosablePage.TabComponent(this)
+ }
override def title_=(t: String) {
super.title_=(t)
@@ -15,6 +24,9 @@ extends TabbedPane.Page(title0, component0) {
}
}
+case class PageClosed(p: ClosablePage) extends DocumentEvent
+
+
object ClosablePage {
def apply(title: String, component: Component)(closeAction: => Boolean) =
new ClosablePage(title, component, () => closeAction)
@@ -28,38 +40,70 @@ object ClosablePage {
case _: PythonDocumentPage => new ImageIcon(getClass.getResource("text-x-script.png"), "Python Script")
case _ => new ImageIcon(getClass.getResource("text-x-generic.png"), "Document")
}
-
- val titleLabel = new Label(p.title,icon,Alignment.Left)
- titleLabel.border = new EmptyBorder(new Insets(5,5,5,10))
+ val titleLabel = new Label(p.title, icon, Alignment.Left)
+ val closeButton = new Button(Action("") {
+closePage()
+ })
+ titleLabel.border = new EmptyBorder(new Insets(5, 5, 5, 10))
contents += titleLabel
+ def closePage(): Unit = {
+ if (p.closeAction()) {
+ publish(PageClosed(p))
+ } else {
+ }
+ }
+
+ def TabContextMenu(): PopupMenu = new PopupMenu {
+ menu =>
+
+ val CloseJustThis: Action = new Action("Close this tab") {
+
+ menu.contents += new MenuItem(this)
+
+ def apply(): Unit = closePage()
+ }
+
+
+ val CloseOtherTabs: Action = new Action("Close all other tabs") {
+ menu.contents += new MenuItem(this)
+
+ def apply(): Unit = {
+ QuantoDerive.closeAllOrListOfDocuments(Some(QuantoDerive.MainDocumentTabs.documents.toList.filter(q => q != p)))
+ }
+ }
+ }
+
def title = titleLabel.text
+
def title_=(t: String) {
titleLabel.text = t
}
- val closeButton = new Button(Action("") {
- if (p.closeAction()) {
- if (p.parent != null) p.parent.pages -= p
- printf("got successful close")
- } else {
- printf("tried to close")
- }
- })
- closeButton.border = new EmptyBorder(new Insets(0,0,0,0))
+ closeButton.border = new EmptyBorder(new Insets(0, 0, 0, 0))
closeButton.contentAreaFilled = false
closeButton.rolloverEnabled = true
closeButton.icon = new CloseIcon(false)
closeButton.rolloverIcon = new CloseIcon(true)
contents += closeButton
+
+ //listenTo(titleLabel.mouse.clicks)
+ // Once we know how to dispatch all events further down the hierarchy we can use this
+ // Until then, any attempt to pick up mouse events blocks native behaviour
+ reactions += {
+ case e : MouseEvent =>
+ //QuantoDerive.MainDocumentTabs.focus(p.asInstanceOf[DocumentPage])
+ UserAlerts.alert("Should have popped up")
+ if(e.isPopupTrigger) popup(TabContextMenu(), Some(e))
+ peer.dispatchEvent(SwingUtilities.convertMouseEvent(e.getComponent, e, peer))
+ }
}
+
}
// draw a little X
-private class CloseIcon(rollover : Boolean) extends Icon {
- def getIconWidth = 9
- def getIconHeight = 9
- def paintIcon(c : java.awt.Component, g : Graphics, x : Int, y : Int) : Unit = {
+private class CloseIcon(rollover: Boolean) extends Icon {
+ def paintIcon(c: java.awt.Component, g: Graphics, x: Int, y: Int): Unit = {
val g2 = g.asInstanceOf[Graphics2D]
g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON)
val savedStroke = g2.getStroke
@@ -70,17 +114,131 @@ private class CloseIcon(rollover : Boolean) extends Icon {
g2.drawLine(getIconWidth - 3, 2, 2, getIconHeight - 3)
g2.setStroke(savedStroke)
}
+
+ def getIconWidth : Int = UserOptions.scaleInt(9)
+
+ def getIconHeight : Int = UserOptions.scaleInt(9)
}
-class ClosableTabbedPane extends TabbedPane { tabbedPane =>
- def +=(p: ClosablePage) {
- pages += p
- peer.setTabComponentAt(pages.length-1, p.tabComponent.peer)
+// Handler for the TabbedPane object, to distance the swing from the Java
+// TabbedPanes hold Java Components, but we want to interact with Documents
+class DocumentTabs {
+
+ val tabbedPane = new TabbedPane()
+ private var pageIndex: Map[DocumentPage, Int] = Map()
+
+ def size : Int = pageIndex.size
+
+ def component: TabbedPane = tabbedPane
+
+ def +=(p: DocumentPage) {
+ if (!pageIndex.contains(p)) {
+ pages += p
+ tabbedPane.peer.setTabComponentAt(pages.length - 1, p.tabComponent.peer)
+ pageIndex += (p -> (pages.length - 1))
+ }
+ reorderPages()
+ focus(p)
+ }
+
+ def cycle(forward: Boolean = true) : Unit = {
+ if(pageIndex.nonEmpty) {
+ focus((selection.index + (if (forward) 1 else -1) + size) % size)
+ }
+ }
+
+ def focus(index: Int): Unit = {
+ selection.index = index
+ currentFocus match {
+ case Some(p) =>
+ p.document.focusOnNaturalComponent()
+ case None =>
+ }
+ }
+
+ def remove(page: DocumentPage): Unit = {
+ val removedIndex = pageIndex(page)
+ val preFocus = currentFocus
+ remove(removedIndex)
+ pageIndex -= page
+ reorderPages()
+ if (preFocus.nonEmpty && preFocus.get != page) {
+ focus(preFocus.get)
+ } else {
+ // Handled by peer
+ //selection.index -=1
+ }
+ ensureFocusIsValid()
+ }
+
+ private def reorderPages(): Unit = {
+ val indexPages = pageIndex.map(pi => (pi._2, pi._1)).toList.sortBy(_._1)
+ var newPageIndex: Map[DocumentPage, Int] = Map()
+ for (iip <- indexPages.zipWithIndex) {
+ newPageIndex += (iip._1._2 -> iip._2)
+ }
+ pageIndex = newPageIndex
+
+ }
+
+ private def remove(index: Int): Unit = {
+ pages.remove(index)
+ for (page <- pageIndex.keys) {
+ if (pageIndex(page) > index) {
+ pageIndex += (page -> (pageIndex(page) - 1))
+ }
+ }
+ }
+
+ private def ensureFocusIsValid(): Unit = {
+ if (pages.length > 0) {
+ // There is something to focus on
+ if (selection.index >= pages.length) {
+ selection.index = pages.length - 1
+ }
+ if (selection.index < 0) {
+ selection.index = 0
+ }
+ }
+ }
+
+ def selection = tabbedPane.selection
+
+ private def pages = tabbedPane.pages
+
+ def focus(p: DocumentPage): Unit = {
+ try {
+ focus(pageIndex(p))
+ } catch {
+ case e: Exception => selection.index = -1
+ }
}
- def currentContent: Option[Component] =
- if (selection.index == -1) None
+ def currentFocus: Option[DocumentPage] = {
+ pageIndex.find(kv => kv._2 == selection.index).map(_._1)
+ }
+
+ def currentContent: Option[Component] = {
+ if (selection.index == -1 || selection.page == null || selection.page.content == null) None
else Some(selection.page.content)
+ }
+
+ // Add this reaction to the peer, then handle changes at this level.
+ tabbedPane.selection.reactions += {
+ case SelectionChanged(`tabbedPane`) =>
+ publishChanged()
+ }
+
+ def publishChanged(): Unit = {
+ currentFocus.foreach(_.document.focusOnNaturalComponent())
+ }
+
+ def documents: Iterable[DocumentPage] = pageIndex.keys
+
+ def clear(): Unit = {
+ pages.clear()
+ pageIndex = Map()
+ }
}
diff --git a/scala/src/main/scala/quanto/gui/CodeTextArea.java b/scala/src/main/scala/quanto/gui/CodeTextArea.java
new file mode 100644
index 00000000..075fa678
--- /dev/null
+++ b/scala/src/main/scala/quanto/gui/CodeTextArea.java
@@ -0,0 +1,20 @@
+package quanto.gui;
+
+import org.gjt.sp.jedit.*;
+import org.gjt.sp.jedit.textarea.*;
+
+import java.util.Properties;
+
+
+public class CodeTextArea extends StandaloneTextArea {
+ static private IPropertyManager propertyManager;
+ static {
+ final Properties props = new Properties();
+
+ propertyManager = new IPropertyManager() {
+ public String getProperty(String prop) { return props.getProperty(prop); }
+ };
+ }
+
+ public CodeTextArea() { super(propertyManager); }
+}
diff --git a/scala/src/main/scala/quanto/gui/ColourSwapDialog.scala b/scala/src/main/scala/quanto/gui/ColourSwapDialog.scala
new file mode 100644
index 00000000..bceb2212
--- /dev/null
+++ b/scala/src/main/scala/quanto/gui/ColourSwapDialog.scala
@@ -0,0 +1,87 @@
+package quanto.gui
+
+import quanto.data.Theory.ValueType
+
+import scala.swing._
+import scala.swing.event.{ButtonClicked, Key, KeyPressed, ValueChanged}
+import quanto.data._
+import quanto.util.{Globals, UserOptions}
+
+import scala.util.matching
+import scala.util.matching.Regex
+import quanto.util.UserOptions.{scale, scaleInt}
+
+class ColourSwapDialog(theory: Theory) extends Dialog {
+ modal = true
+ title = "Specify new colours"
+ val smallGap: Int = UserOptions.scaleInt(5)
+
+ val types : List[String] = theory.vertexTypes.keys.toList
+
+ val SwapButton = new Button("Swap")
+ val CancelButton = new Button("Cancel")
+ defaultButton = Some(SwapButton)
+
+ def result: Map[String, String] = {
+ if(!cancelled) {
+ types.map(t => t -> comboFields(t).selection.item).toMap
+ } else {
+ types.map(t => t -> t).toMap
+ }
+ }
+
+ var cancelled = false
+
+ val comboFields : Map[String, ComboBox[String]] = types.map(t => {
+ t -> new ComboBox(types)
+ }).toMap
+
+ val MainPanel = new BoxPanel(Orientation.Vertical) {
+
+ contents += Swing.VStrut(smallGap)
+
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(smallGap),
+ new Label("Specify which types are sent to which"),
+ Swing.HStrut(smallGap))
+ }
+ contents += Swing.VStrut(smallGap)
+
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(smallGap),
+ new GridPanel(types.length, 2) {
+ types.foreach(t => {
+ val label = new Label(t)
+ label.horizontalAlignment = Alignment.Left
+ contents += label
+ val tf = comboFields(t)
+ contents += tf
+ tf.selection.item = t
+ })
+ },
+ Swing.HStrut(smallGap))
+ }
+ contents += Swing.VStrut(smallGap)
+
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (SwapButton, Swing.HStrut(smallGap), CancelButton)
+ }
+
+ contents += Swing.VStrut(2*smallGap)
+
+ }
+
+
+ contents = MainPanel
+
+ listenTo(SwapButton, CancelButton)
+
+ reactions += {
+ case ButtonClicked(SwapButton) =>
+ cancelled = false
+ close()
+ case ButtonClicked(CancelButton) =>
+ cancelled = true
+ close()
+ }
+}
diff --git a/scala/src/main/scala/quanto/gui/DerivationController.scala b/scala/src/main/scala/quanto/gui/DerivationController.scala
index 4750368a..7dd5e1e5 100644
--- a/scala/src/main/scala/quanto/gui/DerivationController.scala
+++ b/scala/src/main/scala/quanto/gui/DerivationController.scala
@@ -1,12 +1,15 @@
package quanto.gui
import quanto.data._
+
import scala.swing._
-import scala.swing.event.{SelectionChanged, ButtonClicked, Event}
+import scala.swing.event.{ButtonClicked, Event, SelectionChanged}
import quanto.gui.histview._
import java.awt.Color
+
import quanto.gui.graphview.Highlight
import quanto.layout.DeriveLayout
+import quanto.util.UserAlerts
sealed abstract class DeriveState extends HistNode { def step: Option[DSName] }
case class StepState(s: DSName) extends DeriveState {
@@ -193,7 +196,6 @@ class DerivationController(panel: DerivationPanel) extends Publisher {
replaceDerivation(derivation.deleteHead(s), "")
panel.document.undoStack.commit()
case StepState(s) =>
- // TODO: make deletion undo-able?
if (Dialog.showConfirmation(
title = "Confirm deletion",
message = "This will delete " + derivation.allChildren(s).size +
@@ -205,34 +207,40 @@ class DerivationController(panel: DerivationPanel) extends Publisher {
panel.document.undoStack.start("Delete proof step")
state = HeadState(parentOpt)
replaceDerivation(derivation.deleteStep(s), "")
+ if(parentOpt.nonEmpty && !derivation.isHead(parentOpt.get)){
+ val p = parentOpt.get
+ replaceDerivation(derivation.addHead(p), "")
+ state = HeadState(Some(p))
+ }
+
panel.document.undoStack.commit()
}
case _ => // do nothing on root
}
case ButtonClicked(panel.ExportTheoremButton) =>
- panel.document.file match {
- case Some(f) =>
- val rf = panel.project.rootFolder
- var dname = f.getAbsolutePath
- dname = if (dname.startsWith(rf) && dname.length > rf.length) dname.substring(rf.length + 1, dname.length) else dname
- dname = if (dname.endsWith(".qderive")) dname.substring(0, dname.length - 8) else dname
-
- state.step.map { s =>
- val ruleDoc = new RuleDocument(panel, panel.theory)
- ruleDoc.rule = new Rule(panel.document.root, derivation.steps(s).graph, Some(dname))
- ruleDoc.showSaveAsDialog(Some(panel.project.rootFolder))
- }
+ if(!panel.document.unsavedChanges && panel.document.file.nonEmpty) {
+ val f = panel.document.file.get
+ val rf = panel.project.rootFolder
+ var dname = panel.project.relativePath(f)
+ dname = if (dname.endsWith(".qderive")) dname.substring(0, dname.length - 8) else dname
+
+ state.step.foreach { s =>
+ val ruleDoc = new RuleDocument(panel, panel.theory)
+ val newRule = new Rule(panel.document.root, derivation.steps(s).graph, Some(dname))
+ val page = new RuleDocumentPage(panel.theory)
+ page.document.asInstanceOf[RuleDocument].rule = newRule
+ QuantoDerive.addAndFocusPage(page)
+ }
- case None =>
- Dialog.showMessage(
- title = "Error",
- message = "You must first save this derivation before exporting a theorem",
- messageType = Dialog.Message.Error)
+ } else {
+ UserAlerts.errorBox("You must first save this derivation before exporting a theorem")
}
case SelectionChanged(_) =>
- if (panel.histView.selectedNode != Some(state))
+ if (panel.histView.selectedNode != Some(state)) {
panel.histView.selectedNode.map { st => state = st }
+ panel.histView.ensureIndexIsVisible(panel.histView.selectedIndex())
+ }
}
}
diff --git a/scala/src/main/scala/quanto/gui/DerivationDocument.scala b/scala/src/main/scala/quanto/gui/DerivationDocument.scala
index ea50c2e8..65914fc1 100644
--- a/scala/src/main/scala/quanto/gui/DerivationDocument.scala
+++ b/scala/src/main/scala/quanto/gui/DerivationDocument.scala
@@ -11,7 +11,7 @@ class DerivationDocument(panel: DerivationPanel) extends Document {
val fileExtension = "qderive"
protected def parent = panel
- private var storedDerivation: Derivation = Derivation(panel.theory, Graph(panel.theory))
+ private var storedDerivation: Derivation = Derivation(Graph())
private var _derivation: Derivation = storedDerivation
def derivation = _derivation
def derivation_=(d: Derivation) = {
@@ -49,8 +49,8 @@ class DerivationDocument(panel: DerivationPanel) extends Document {
storedDerivation = _derivation
}
- protected def clearDocument() = {
- _derivation = Derivation(panel.theory, root = Graph(panel.theory))
+ protected def clearDocument() {
+ _derivation = Derivation(root = Graph(panel.theory))
}
def root_=(g: Graph) {
@@ -60,11 +60,12 @@ class DerivationDocument(panel: DerivationPanel) extends Document {
publish(DocumentChanged(this))
}
- def root = rootRef.graph
+ def root: Graph = rootRef.graph
- override protected def exportDocument(f: File) = {
+ override protected def exportDocument(f: File) {
+ previousDir = f
val old_state : DeriveState = parent.controller.state
- var state_opt : Option[DeriveState] = derivation.parent(old_state)
+ var state_opt : Option[DeriveState] = Some(old_state) //derivation.parent(old_state)
/* sequence of states from root leading to current state's parent */
var state_list : List[DeriveState] = List()
@@ -77,29 +78,46 @@ class DerivationDocument(panel: DerivationPanel) extends Document {
state_list = state_list.tail
/* enclose everything into a quote environment */
- printToFile(f, false)(p => {
+ printToFile(f, append=false)(p => {
p.println("\\begin{quote}\\raggedright")
})
+ val baseName = f.getName.substring(0, f.getName.lastIndexOf('.'))
+ val fullBaseName = f.getAbsolutePath.substring(0, f.getAbsolutePath.lastIndexOf('.'))
+
+ var i = 0
+
/* output derivation until current state parent */
state_list.foreach { s =>
parent.controller.state = s
+ val stepFileName = fullBaseName + "-" + i + ".tikz"
+ println(stepFileName)
+ val stepFile = new File(stepFileName)
/* must compute view data to prevent race condition */
parent.LhsView.computeDisplayData()
- parent.LhsView.exportView(f, true)
+ parent.LhsView.exportView(stepFile, append=false)
+
+ printToFile(f)( p => {
+ p.print("\\tikzfig{" + baseName + "-" + i + "}")
- printToFile(f, true)( p => {
- p.println("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
- p.println("=")
- p.println("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
+ // print an equals sign, unless this is the last step
+ if (i < state_list.length - 1) {
+ p.println(" $=$")
+ } else {
+ p.println()
+ }
})
+
+ i += 1
}
+
/* finally output current state and close quote environment */
- parent.controller.state = old_state
- parent.LhsView.computeDisplayData()
- parent.LhsView.exportView(f, true)
+// parent.controller.state = old_state
+// parent.LhsView.computeDisplayData()
+// parent.LhsView.exportView(f, true)
+
printToFile(f, true)( p => {
p.println("\\end{quote}")
})
diff --git a/scala/src/main/scala/quanto/gui/DerivationPanel.scala b/scala/src/main/scala/quanto/gui/DerivationPanel.scala
index a04bc7a4..e70a980f 100644
--- a/scala/src/main/scala/quanto/gui/DerivationPanel.scala
+++ b/scala/src/main/scala/quanto/gui/DerivationPanel.scala
@@ -8,10 +8,13 @@ import scala.swing.event._
import javax.swing.ImageIcon
import quanto.gui.histview.HistView
+case class RequestReRunSimproc() extends Event
+case class SuggestRewriteRule(relativePath: RuleDesc) extends Event
class DerivationPanel(val project: Project)
extends BorderPanel
with HasDocument
+ with Publisher
{
def theory = project.theory
val document = new DerivationDocument(this)
@@ -38,6 +41,8 @@ class DerivationPanel(val project: Project)
val lhsController = new GraphEditController(LhsView, document.undoStack, readOnly = true)
val rhsController = new GraphEditController(RhsView, document.undoStack, readOnly = true)
+ def ReRunLastSimproc() : Unit = publish(RequestReRunSimproc())
+
val RewindButton = new Button() {
icon = new ImageIcon(getClass.getResource("go-first.png"), "First step")
tooltip = "First Step"
@@ -174,42 +179,11 @@ class DerivationPanel(val project: Project)
add(new SplitPane(Orientation.Horizontal, topPane, PreviewScrollPane), BorderPanel.Position.Center)
}
- val SimplifyBuiltInPane = new BorderPanel {
- val Simprocs = new ListView[String]
- val SimprocsScrollPane = new ScrollPane(Simprocs)
- SimprocsScrollPane.preferredSize = new Dimension(400,200)
- val Preview = new GraphView(theory, DummyRef)
- val PreviewScrollPane = new ScrollPane(Preview)
- Preview.zoom = 0.6
-
- val SimplifyButton = new Button {
- icon = new ImageIcon(GraphEditor.getClass.getResource("start.png"))
- preferredSize = toolbarDim
- tooltip = "Start"
- }
-
- val StopButton = new Button {
- icon = new ImageIcon(GraphEditor.getClass.getResource("stop.png"))
- preferredSize = toolbarDim
- tooltip = "Stop"
- }
-
-
- val topPane = new BorderPanel {
- add(SimprocsScrollPane, BorderPanel.Position.Center)
- add(new FlowPanel(FlowPanel.Alignment.Center)(
- SimplifyButton
- ), BorderPanel.Position.South)
- }
-
- add(new SplitPane(Orientation.Horizontal, topPane, PreviewScrollPane), BorderPanel.Position.Center)
- }
val RhsRewritePane = new TabbedPane
RhsRewritePane.pages += new TabbedPane.Page("Rewrite", ManualRewritePane)
RhsRewritePane.pages += new TabbedPane.Page("Simplify", SimplifyPane)
- RhsRewritePane.pages += new TabbedPane.Page("Simplify (built-in)", SimplifyBuiltInPane)
val LhsLabel = new Label("(root)")
val RhsLabel = new Label("(head)")
@@ -246,6 +220,7 @@ class DerivationPanel(val project: Project)
add(GraphViewPanel, BorderPanel.Position.Center)
+ listenTo(document)
listenTo(LhsGraphPane, RhsGraphPane)
listenTo(ManualRewritePane.PreviewScrollPane, SimplifyPane.PreviewScrollPane)
@@ -262,6 +237,8 @@ class DerivationPanel(val project: Project)
case UIElementResized(SimplifyPane.PreviewScrollPane) =>
SimplifyPane.Preview.resizeViewToFit()
SimplifyPane.Preview.repaint()
+ case DocumentRequestingNaturalFocus(_) =>
+ LhsView.requestFocus()
}
// construct the controller last, as it depends on the panel elements already being initialised
@@ -269,6 +246,5 @@ class DerivationPanel(val project: Project)
val rewriteController = new RewriteController(this)
val simplifyController = new SimplifyController(this)
- val simplifyBuiltInController = new SimplifyBuiltInController(this)
// rewriteController.rules = Vector(RuleDesc("axioms/test1", inverse = false), RuleDesc("axioms/test2", inverse = true))
}
diff --git a/scala/src/main/scala/quanto/gui/Document.scala b/scala/src/main/scala/quanto/gui/Document.scala
index d1c53309..f59dc77a 100644
--- a/scala/src/main/scala/quanto/gui/Document.scala
+++ b/scala/src/main/scala/quanto/gui/Document.scala
@@ -1,17 +1,21 @@
package quanto.gui
-import scala.swing.{Component, FileChooser, Dialog, Publisher}
-import java.io.{FileNotFoundException, IOException, File}
+import scala.swing.{Component, Dialog, FileChooser, Publisher}
+import java.io.{File, FileNotFoundException, IOException}
+
import scala.swing.event.Event
import quanto.data._
import quanto.util.json.JsonParseException
import javax.swing.filechooser.FileNameExtensionFilter
import java.util.prefs.Preferences
+import javax.swing.JOptionPane
+
abstract class DocumentEvent extends Event
case class DocumentChanged(sender: Document) extends DocumentEvent
case class DocumentSaved(sender: Document) extends DocumentEvent
case class DocumentReplaced(sender: Document) extends DocumentEvent
+case class DocumentRequestingNaturalFocus(sender: Document) extends DocumentEvent
/**
* For an object connected to a single file. Provides an undo stack, tracks changes, and gives
@@ -87,7 +91,12 @@ abstract class Document extends Publisher {
}
def titleDescription : String = {
- file.map(f => f.getName).getOrElse("Untitled").replaceAll("\\.[^.]*$", "")
+ val base = file.map(f => f.getName).getOrElse("Untitled").replaceAll("\\.[^.]*$", "")
+ if (unsavedChanges) {
+ base + "*"
+ } else {
+ base
+ }
// val name : String = file.map(f => f.getName).getOrElse("Untitled")
// // If there is a description then use in instead of the file extension
// val nameDescription = if (description.length > 0) name.replaceAll("\\.[^.]*$", "") + " " + description else name
@@ -121,36 +130,53 @@ abstract class Document extends Publisher {
* @return true if the document can be closed, false otherwise
* (as per user decission)
*/
- def promptUnsaved() = {
+ def promptUnsaved(): Boolean = {
if (unsavedChanges) {
- val choice = Dialog.showOptions(
- title = "Unsaved changes",
- message = "Do you want to save your changes or discard them?",
- entries = "Save" :: "Discard" :: "Cancel" :: Nil,
- initial = 0
- )
+// val choice = Dialog.showOptions(
+// title = "Unsaved changes",
+// message = "Do you want to save your changes or discard them?",
+// entries = "Save" :: "Discard" :: "Cancel" :: Nil,
+// initial = 0
+// )
+
+ val choice = JOptionPane.showOptionDialog(null,
+ "Do you want to save your changes or discard them?",
+ "Unsaved changes in "+titleDescription,
+ JOptionPane.DEFAULT_OPTION,
+ JOptionPane.WARNING_MESSAGE, null,
+ List("Save", "Discard", "Cancel").toArray,
+ "Save")
// scala swing dialogs implementation is dumb, here's what I found :
// Result(0) = Save, Result(1) = Discard, Result(2) = Cancel
- if (choice == Dialog.Result(0)) trySave()
- else choice == Dialog.Result(1)
+ if (choice == 0) trySave()
+ else choice == 1
} else true
}
- def promptExists(f: File) = {
+ def promptExists(f: File): Boolean = {
if (f.exists()) {
- Dialog.showConfirmation(
- title = "File exists",
- message = "File exists, do you wish to overwrite?") == Dialog.Result.Yes
+ JOptionPane.showConfirmDialog(null,
+ "File exists, do you wish to overwrite?",
+ "File exists",
+ JOptionPane.YES_NO_OPTION) == JOptionPane.YES_OPTION
+// Dialog.showConfirmation(
+// title = "File exists",
+// message = "File exists, do you wish to overwrite?") == Dialog.Result.Yes
} else true
}
+
def errorDialog(action: String, reason: String) {
- Dialog.showMessage(
- title = "Error",
- message = "Cannot " + action + " file (" + reason + ")",
- messageType = Dialog.Message.Error)
+// Dialog.showMessage(
+// title = "Error",
+// message = "Cannot " + action + " file (" + reason + ")",
+// messageType = Dialog.Message.Error)
+ JOptionPane.showMessageDialog(null,
+ "Cannot " + action + " file (" + reason + ")",
+ "Error",
+ JOptionPane.ERROR_MESSAGE)
}
def previousDir_=(f: File) {
@@ -227,6 +253,12 @@ abstract class Document extends Publisher {
case UndoRegistered(_) =>
publish(DocumentChanged(this))
}
+
+ // Publishes a request for the view to focus on whatever is "correct" for the document.
+ // E.g. give focus to the text area if you open a .py file
+ def focusOnNaturalComponent() : Unit = {
+ publish(DocumentRequestingNaturalFocus(this))
+ }
}
trait HasDocument {
diff --git a/scala/src/main/scala/quanto/gui/DocumentPage.scala b/scala/src/main/scala/quanto/gui/DocumentPage.scala
index 7e41ed8c..f31c91cb 100644
--- a/scala/src/main/scala/quanto/gui/DocumentPage.scala
+++ b/scala/src/main/scala/quanto/gui/DocumentPage.scala
@@ -7,7 +7,9 @@ abstract class DocumentPage(component0: Component with HasDocument)
extends ClosablePage(
component0.document.titleDescription,
component0,
- closeAction = () => { component0.document.promptUnsaved() } )
+ closeAction = () => {
+ component0.document.promptUnsaved()}
+)
with Reactor
{
val document = component0.document
@@ -49,3 +51,18 @@ class DerivationDocumentPage(val project: Project)
extends DocumentPage(new DerivationPanel(project)) {
val documentType = "Derivation"
}
+
+class TheoryPage
+ extends DocumentPage(new TheoryEditPanel) {
+ val documentType = "Theory Editor"
+}
+
+class BatchDerivationPage
+ extends DocumentPage(new BatchDerivationCreatorPanel) {
+ val documentType = "Batch Derivation"
+}
+
+class BatchDerivationResultsPage
+ extends DocumentPage(new BatchDerivationResultsPanel) {
+ val documentType = "Batch Derivation Results"
+}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/gui/FileTree.scala b/scala/src/main/scala/quanto/gui/FileTree.scala
index 2f6f4e91..f555e673 100644
--- a/scala/src/main/scala/quanto/gui/FileTree.scala
+++ b/scala/src/main/scala/quanto/gui/FileTree.scala
@@ -15,6 +15,7 @@ import scala.swing.event.Event
abstract class FileTreeEvent extends Event
case class FileOpened(file: File) extends FileTreeEvent
+case class FileContextRequested(file: File, e: Option[MouseEvent]) extends FileTreeEvent
class FileTree extends BorderPanel {
val fileTreeModel = new FileTreeModel
@@ -41,12 +42,10 @@ class FileTree extends BorderPanel {
}
case _ =>
if (e.isPopupTrigger) {
- // TODO: Make actual popup menu
- // cf https://stackoverflow.com/questions/938753/scala-popup-menu
- // Needs access to a swing component to have as parent
+ val row = fileTree.getClosestRowForLocation(e.getX, e.getY)
+ fileTree.setSelectionRow(row)
fileTree.getLastSelectedPathComponent match {
- case FileNode(file) => Desktop.getDesktop.browse(file.toURI)
- case _ =>
+ case FileNode(file) => publish(FileContextRequested(file, Some(e)))
}
}
}
diff --git a/scala/src/main/scala/quanto/gui/FilteredList.scala b/scala/src/main/scala/quanto/gui/FilteredList.scala
new file mode 100644
index 00000000..a748fab3
--- /dev/null
+++ b/scala/src/main/scala/quanto/gui/FilteredList.scala
@@ -0,0 +1,55 @@
+package quanto.gui
+
+import quanto.util.UserOptions.scaleInt
+
+import scala.swing.event.ValueChanged
+import scala.swing.{BoxPanel, Component, Dimension, Label, ListView, Orientation, ScrollPane, Swing, TextField}
+
+// Create a list of strings to select from, with a regex filter box at the top
+// Regex: A text box that contains the regex to filter by
+// listItems: The list to filter on
+// ListComponent: The UI Component that the user selects elements from
+class FilteredList(val options: List[String],
+ baseWidth: Int = 400,
+ baseHeight: Int = 200) extends BoxPanel(Orientation.Vertical) {
+ val Regex = new TextField
+ val listItems: List[String] = options.sorted
+ val ListComponent: ListView[String] = new ListView[String](listItems)
+ private val ScrollContainer = new ScrollPane(ListComponent)
+ ScrollContainer.maximumSize = new Dimension(scaleInt(baseWidth), scaleInt(baseHeight))
+
+ private def VSpace: Component = Swing.VStrut(scaleInt(10))
+
+ private def HSpace: Component = Swing.HStrut(scaleInt(10))
+
+ contents += VSpace
+ val FilterPanel : Component = new BoxPanel(Orientation.Horizontal) {
+ contents += (HSpace, new Label("Filter:"), HSpace, Regex, HSpace)
+ maximumSize = new Dimension(scaleInt(baseWidth), scaleInt(20))
+ }
+
+ contents += FilterPanel
+ contents += VSpace
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (HSpace, ScrollContainer, HSpace)
+ }
+
+ listenTo(Regex)
+ reactions += {
+ case ValueChanged(Regex) =>
+ try {
+ val filtered = listItems.filter(
+ s => s.matches("(?i).*" + Regex.text + ".*"))
+ ListComponent.listData = filtered
+ if(filtered.nonEmpty) {
+ ListComponent.peer.setSelectionInterval(0, filtered.length - 1)
+ }
+ if(Regex.text.isEmpty){
+ ListComponent.peer.clearSelection()
+ }
+ } catch {
+ case e: Exception =>
+ //Exceptions here are thrown by inelligable regex from the user
+ }
+ }
+}
diff --git a/scala/src/main/scala/quanto/gui/GraphDocument.scala b/scala/src/main/scala/quanto/gui/GraphDocument.scala
index 6ebce8d3..1d20820f 100644
--- a/scala/src/main/scala/quanto/gui/GraphDocument.scala
+++ b/scala/src/main/scala/quanto/gui/GraphDocument.scala
@@ -23,6 +23,13 @@ class GraphDocument(val parent: Component, theory: Theory) extends Document with
// resetDocumentInfo()
// }
+ def replaceJson(json: Json) {
+ graph = Graph.fromJson(json, theory)
+ publish(GraphReplaced(this, clearSelection = true))
+ publish(DocumentReplaced(this))
+ publish(DocumentChanged(this))
+ }
+
protected def loadDocument(f: File) {
val json = Json.parse(f)
storedGraph = Graph.fromJson(json, theory)
diff --git a/scala/src/main/scala/quanto/gui/GraphEditController.scala b/scala/src/main/scala/quanto/gui/GraphEditController.scala
index ae00b3c3..772cfb39 100644
--- a/scala/src/main/scala/quanto/gui/GraphEditController.scala
+++ b/scala/src/main/scala/quanto/gui/GraphEditController.scala
@@ -1,19 +1,21 @@
package quanto.gui
-import graphview._
-import swing._
-import swing.event._
-import Key.Modifier
+import java.awt.Toolkit
+import java.awt.datatransfer._
+import java.awt.event.{ActionEvent, ActionListener}
+import java.util.Calendar
+
+import quanto.data.Names._
import quanto.data._
-import Names._
+import quanto.gui.graphview.{BBoxOverlay, EdgeOverlay, _}
import quanto.layout.ForceLayout
-import quanto.util.json._
import quanto.layout.constraint._
-import java.awt.event.{ActionEvent, ActionListener}
-import java.awt.datatransfer._
-import java.awt.Toolkit
-import quanto.gui.graphview.{EdgeOverlay,BBoxOverlay}
-import quanto.util.Globals
+import quanto.util.json._
+import quanto.util.{Globals, UserOptions}
+
+import scala.swing._
+import scala.swing.event.Key.Modifier
+import scala.swing.event._
case class VertexSelectionChanged(graph: Graph, selectedVerts: Set[VName]) extends GraphEvent
@@ -67,6 +69,7 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
val layoutTimer = new javax.swing.Timer(5, new ActionListener {
def actionPerformed(e: ActionEvent) {
if (qLayout.graph != null) {
+ view.requestFocusInWindow()
qLayout.step()
qLayout.updateGraph()
graph = qLayout.graph
@@ -80,6 +83,7 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
val layoutTimer1 = new javax.swing.Timer(5, new ActionListener {
def actionPerformed(e: ActionEvent) {
if (q1Layout.graph != null) {
+ view.requestFocusInWindow()
q1Layout.step()
q1Layout.updateGraph()
graph = q1Layout.graph
@@ -159,6 +163,21 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
undoStack.register("Add Edge") { deleteEdge(e) }
}
+ // Uses the current state of the controller to add an edge
+ private def addEdgeFromController(v1: VName, v2: VName): Unit = {
+ controlsOpt.foreach { c =>
+ val edgeType = theory.edgeTypes(c.EdgeTypeSelect.selection.item).defaultData
+ val theoryJSON : JsonObject = JsonObject(
+ "data" -> edgeType
+ )
+
+ val eData = if (c.EdgeDirected.selected) DirEdge.fromJson(theoryJSON, theory)
+ else UndirEdge.fromJson(theoryJSON, theory)
+
+ addEdge(graph.edges.fresh, eData, (v1, v2))
+ }
+ }
+
private def deleteEdge(e: EName) {
val d = graph.edata(e)
val vs = (graph.source(e), graph.target(e))
@@ -367,6 +386,14 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
view.repaint()
}
+ def selectAll() : Unit = {
+ selectedBBoxes = graph.bboxes
+ selectedVerts = graph.verts
+ selectedEdges = graph.edges
+ view.publish(VertexSelectionChanged(graph, graph.verts))
+ view.repaint()
+ }
+
private def roundIfSnapped(d : Double) = {
if (keepSnapped) math.rint(d / 0.25) * 0.25 else d // rounds to .25
}
@@ -384,6 +411,84 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
replaceGraph(newGraph, "Layout Graph")
}
+ var rDown = false
+
+ def startRelaxGraph(expandNodes : Boolean) {
+ view.requestFocusInWindow()
+ val layout = if (!expandNodes) q1Layout else qLayout
+ if (!rDown) {
+ rDown = true
+ layout.initialize(graph, randomCoords = false)
+ layout.clearLockedVertices()
+ if (!selectedVerts.isEmpty) {
+ graph.verts.foreach { v => if (!selectedVerts.contains(v)) layout.lockVertex(v) }
+ }
+
+ undoStack.start("Relax layout")
+ replaceGraph(graph, "")
+ if (!expandNodes) layoutTimer1.start() else layoutTimer.start()
+ }
+ }
+
+ def endRelaxGraph() {
+ if (rDown) {
+ rDown = false
+ layoutTimer.stop()
+ layoutTimer1.stop()
+
+ replaceGraph(graph, "")
+ undoStack.commit()
+ }
+ }
+
+ def cycleVertexType(vertex: VName, shift: Int = 1, includeWire: Boolean = false): Unit = {
+
+ val currentData = graph.vdata(vertex)
+ val options = theory.vertexTypes.keys.toSeq :+ ""
+ val current = options.indexOf(currentData.typ)
+
+ val next = options((current + shift + options.length) % options.length)
+ val newTyp = if (next == "" & !includeWire)
+ options((current + shift + 1 + options.length) % options.length)
+ else next
+
+ val coords = (currentData.coord._1, currentData.coord._2)
+ // Using replaceGraph because of data lost on cycling
+ replaceGraph(graph.updateVData(vertex)(d => {
+ newTyp match {
+ case "" =>
+ WireV.apply(coords)
+ case _ =>
+ NodeV(theory.vertexTypes(newTyp).defaultData,
+ annotation = d.annotation,
+ theory = theory).withCoord(coords)
+ }
+ }
+ ), "Cycled vertex")
+
+
+ }
+
+
+ def normaliseGraph(): Unit = {
+ replaceGraph(graph.normalise.coerceWiresAndBoundaries, "Normalised graph")
+ }
+
+ def minimiseGraph(): Unit = {
+ view.requestFocusInWindow()
+ replaceGraph(graph.minimise.coerceWiresAndBoundaries, "Minimised graph")
+ }
+
+ def focusOnGraph(): Unit = {
+ view.requestFocusInWindow()
+ view.focusOnGraph()
+ }
+
+ def vertexAt(point: Point) : Option[VName] = view.vertexDisplay find {
+ _._2.pointHit(point)
+ } map {
+ _._1
+ }
view.listenTo(view.mouse.clicks, view.mouse.moves)
view.reactions += {
case MousePressed(_, pt, modifiers, clicks, _) =>
@@ -475,7 +580,25 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
mouseState = box
view.selectionBox = Some(box.rect)
}
-
+ case FreehandTool(_, _) =>
+ undoStack.start("Freehand drag")
+ var vname = graph.verts.freshWithSuggestion(VName("v0"))
+ vertexAt(pt) match {
+ case Some(vertex) => // Already a vertex here, start a path
+ vname = vertex
+ mouseState = FreehandTool(Some(vname), startedWithNew = false)
+ case None => // No vertex here! Create a wire node if enough time and distance have passed
+ val coord = view.trans fromScreen(pt.getX, pt.getY)
+ controlsOpt.foreach { c =>
+ val vertexData = WireV(theory = theory)
+ addVertex(vname, vertexData.withCoord(coord))
+ }
+ mouseState = FreehandTool(Some(vname), startedWithNew = true)
+ }
+ view.edgeOverlay = Some(EdgeOverlay(pt, src = vname, tgt = Some(vname)))
+ view.repaint()
+ case RequestMinimiseGraph() =>
+ case RequestFocusOnGraph() =>
case state => throw new InvalidMouseStateException("MousePressed", state)
}
@@ -486,6 +609,17 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
case AddBoundaryTool() => // do nothing
case AddEdgeTool() => // do nothing
case AddBangBoxTool() => // do nothing
+ case FreehandTool(maybeVName, _) =>
+ val start = maybeVName.get
+ vertexAt(pt) match {
+ case Some(vertex) =>
+ view.edgeOverlay = Some(EdgeOverlay(pt, src = start, tgt = Some(vertex)))
+ case None =>
+ view.edgeOverlay = Some(EdgeOverlay(pt, src = start, tgt = None))
+ }
+ view.repaint()
+ case RequestMinimiseGraph() =>
+ case RequestFocusOnGraph() =>
case SelectionBox(start,_) =>
val box = SelectionBox(start, pt)
mouseState = box
@@ -495,7 +629,7 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
val box = BangSelectionBox(start, pt)
mouseState = box
view.selectionBox = Some(box.rect)
- view.repaint()
+ view.repaint()
case DragVertex(start, prev) =>
shiftVertsNoRegister(selectedVerts, start, prev, pt)
view.repaint()
@@ -510,13 +644,38 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
else None
view.bboxOverlay = Some(BBoxOverlay(pt, startBB, vertexHit, bboxHit))
view.repaint()
+ case state => throw new InvalidMouseStateException("MouseReleased", state)
}
case MouseReleased(_, pt, modifiers, _, _) =>
mouseState match {
- case SelectTool() => // do nothing
- case AddEdgeTool() => // do nothing
- case AddBangBoxTool () => // do nothing
+ case SelectTool() => // do nothing
+ case AddEdgeTool() => // do nothing
+ case AddBangBoxTool() => // do nothing
+ case FreehandTool(maybeVName, startedWithNew) =>
+ view.edgeOverlay = None
+ vertexAt(pt) match {
+ case Some(vertex) => // Vertex here
+
+ if (maybeVName.get == vertex) {
+ // dragged to itself
+ if (!startedWithNew) cycleVertexType(vertex)
+ } else {
+ // dragged to another vertex
+ addEdgeFromController(maybeVName.get, vertex)
+ }
+ case None =>
+ val coord = view.trans fromScreen(pt.getX, pt.getY)
+ val vname = graph.verts.freshWithSuggestion(VName("v0"))
+ controlsOpt.foreach { c =>
+ val vertexData = WireV(theory = theory)
+ addVertex(vname, vertexData.withCoord(coord))
+ } // dragged to another vertex
+ addEdgeFromController(maybeVName.get, vname)
+ }
+ undoStack.commit()
+ mouseState = FreehandTool(None, startedWithNew = false)
+
case SelectionBox(start,_) =>
val oldSelectedVerts = selectedVerts
@@ -615,11 +774,7 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
case DragEdge(startV) =>
val vertexHit = view.vertexDisplay find { _._2.pointHit(pt) } map { _._1 }
vertexHit.map { endV =>
- controlsOpt.map { c =>
- val defaultData = if (c.EdgeDirected.selected) DirEdge.fromJson(theory.defaultEdgeData, theory)
- else UndirEdge.fromJson(theory.defaultEdgeData, theory)
- addEdge(graph.edges.fresh, defaultData, (startV, endV))
- }
+ addEdgeFromController(startV, endV)
}
mouseState = AddEdgeTool()
view.edgeOverlay = None
@@ -647,6 +802,8 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
view.bboxOverlay = None
view.repaint()
+ case RequestMinimiseGraph() =>
+ case RequestFocusOnGraph() =>
case state => throw new InvalidMouseStateException("MouseReleased", state)
}
@@ -657,7 +814,7 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
view.listenTo(view.keys)
view.listenTo(view.mouse.wheel)
- var rDown = false
+
val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask
@@ -671,27 +828,12 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
undoStack.commit()
view.repaint()
}
+ case KeyPressed(_, Key.A, m, _) =>
+ if (Modifier.Control == (m & Modifier.Control)) selectAll()
case KeyPressed(_, Key.R, m, _) =>
- if (!rDown) {
- val layout = if ((m & Modifier.Shift) == Modifier.Shift) q1Layout else qLayout
- rDown = true
- layout.initialize(graph, randomCoords = false)
- layout.clearLockedVertices()
- if (!selectedVerts.isEmpty) {
- graph.verts.foreach { v => if (!selectedVerts.contains(v)) layout.lockVertex(v) }
- }
-
- undoStack.start("Relax layout")
- replaceGraph(graph, "")
- if ((m & Modifier.Shift) == Modifier.Shift) layoutTimer1.start() else layoutTimer.start()
- }
+ startRelaxGraph((m & Modifier.Shift) != Modifier.Shift)
case KeyReleased(_, Key.R, _, _) =>
- rDown = false
- layoutTimer.stop()
- layoutTimer1.stop()
-
- replaceGraph(graph, "")
- undoStack.commit()
+ endRelaxGraph()
case KeyReleased(_, Key.G, _, _) =>
snapToGrid()
//replaceGraph(graph, "")
@@ -705,8 +847,14 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
if ((modifiers & Globals.CommandDownMask) == Globals.CommandDownMask) { cutSubgraph() }
case KeyPressed(_, Key.V, modifiers, _) =>
if (modifiers == 0) {
+
mouseState = AddVertexTool()
- controlsOpt.map { c => c.setMouseState(mouseState) }
+ controlsOpt.foreach { c =>
+ if (c.GraphToolGroup.selected.contains(c.AddVertexButton)) {
+ c.VertexTypeSelect.selection.index = (c.VertexTypeSelect.selection.index + 1) % (theory.vertexTypes.size + 1)
+ }
+ c.setMouseState(mouseState)
+ }
}
else if ((modifiers & Globals.CommandDownMask) == Globals.CommandDownMask) { pasteSubgraph() }
case KeyPressed(_, Key.S, modifiers, _) =>
@@ -717,7 +865,12 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
case KeyPressed(_, Key.E, modifiers, _) =>
if (modifiers == 0) {
mouseState = AddEdgeTool()
- controlsOpt.map { c => c.setMouseState(mouseState) }
+ controlsOpt.foreach { c =>
+ if (c.GraphToolGroup.selected.contains(c.AddEdgeButton)) {
+ c.EdgeTypeSelect.selection.index = (c.EdgeTypeSelect.selection.index + 1) % theory.edgeTypes.size
+ }
+ c.setMouseState(mouseState)
+ }
}
case KeyPressed(_, Key.B, modifiers, _) =>
if (modifiers == 0) {
@@ -738,15 +891,45 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B
if ((modifiers & Globals.CommandDownMask) == Globals.CommandDownMask) {
snapToGrid()
}
- case KeyPressed(_, Key.F, _, _) =>
- undoStack.start("Flip edge direction")
- selectedEdges.foreach { flipEdge }
- undoStack.commit()
- view.repaint()
- case KeyPressed(_, Key.D, _, _) =>
- undoStack.start("Toggle edge directed")
- selectedEdges.foreach { toggleDirected }
- undoStack.commit()
- view.repaint()
+ case KeyPressed(_, Key.M, _, _) =>
+ minimiseGraph()
+ case KeyPressed(_, Key.F, modifiers, _) =>
+ if ((modifiers & Modifier.Shift) != Modifier.Shift) {
+ mouseState = FreehandTool(None, startedWithNew = false)
+ controlsOpt.foreach { c => c.setMouseState(mouseState) }
+ } else {
+ undoStack.start("Flip edge direction")
+ selectedEdges.foreach {
+ flipEdge
+ }
+ undoStack.commit()
+ view.repaint()
+ }
+ case KeyPressed(_, Key.D, modifiers, _) =>
+ // DOn't do this if pressing ctrl: ctrl+d is "Start new derivation"
+ if((modifiers & Modifier.Control) != Modifier.Control) {
+ if (selectedEdges.nonEmpty) {
+ if((modifiers & Modifier.Shift) != Modifier.Shift){
+
+ undoStack.start("Toggle edge directed")
+ selectedEdges.foreach {
+ toggleDirected
+ }
+ }else{
+ undoStack.start("Flip edge direction")
+ selectedEdges.foreach {
+ flipEdge
+ }
+ }
+ undoStack.commit()
+ view.repaint()
+ } else {
+ controlsOpt.foreach { c =>
+ if (c.GraphToolGroup.selected.contains(c.AddEdgeButton)) {
+ c.EdgeDirected.selected = true
+ }
+ }
+ }
+ }
}
}
diff --git a/scala/src/main/scala/quanto/gui/GraphEditPanel.scala b/scala/src/main/scala/quanto/gui/GraphEditPanel.scala
index 3ce570c2..071dbf38 100644
--- a/scala/src/main/scala/quanto/gui/GraphEditPanel.scala
+++ b/scala/src/main/scala/quanto/gui/GraphEditPanel.scala
@@ -2,9 +2,12 @@ package quanto.gui
import graphview.GraphView
import quanto.data._
+
import swing._
import swing.event._
import javax.swing.ImageIcon
+
+import quanto.util.UserAlerts
import quanto.util.swing.ToolBar
case class MouseStateChanged(m : MouseState) extends Event
@@ -51,7 +54,7 @@ class GraphEditControls(theory: Theory) extends Publisher {
}
val AddEdgeButton = new ToggleButton() with ToolButton {
- icon = new ImageIcon(GraphEditor.getClass.getResource("draw-path.png"), "Add Edge")
+ icon = new ImageIcon(GraphEditor.getClass.getResource("add-edge.png"), "Add Edge")
tool = AddEdgeTool()
tooltip = "Add Edge (E)"
}
@@ -62,16 +65,53 @@ class GraphEditControls(theory: Theory) extends Publisher {
tooltip = "Add Bang Box (B)"
}
+ val RelaxButton = new ToggleButton() with ToolButton {
+ icon = new ImageIcon(GraphEditor.getClass.getResource("expand.png"), "Relax graph")
+ tool = RelaxToolDown()
+ tooltip = "Relax graph (R/shift-R)"
+ }
+
+
+ val FocusGraphButton = new ToggleButton() with ToolButton {
+ icon = new ImageIcon(GraphEditor.getClass.getResource("focus.png"), "Resize Viewport")
+ tool = RequestFocusOnGraph()
+ tooltip = "Focus the viewport on the whole graph"
+ }
+
+
+
+ val FreehandButton = new ToggleButton() with ToolButton {
+ icon = new ImageIcon(GraphEditor.getClass.getResource("draw-path.png"), "Freehand drawing")
+ tool = FreehandTool(None, startedWithNew = false)
+ tooltip = "Freehand draw (F)"
+ }
+
+
+ val MinimiseButton = new ToggleButton() with ToolButton {
+ icon = new ImageIcon(GraphEditor.getClass.getResource("normalise.png"), "Minimise")
+ tooltip = "Straighten edges, and convert leaves to boundaries (M)"
+ tool = RequestMinimiseGraph()
+ }
+
val GraphToolGroup = new ButtonGroup(SelectButton,
AddVertexButton,
AddBoundaryButton,
AddEdgeButton,
- AddBangBoxButton)
+ AddBangBoxButton,
+ FreehandButton
+ )
def setMouseState(m : MouseState) {
- val previousTool = GraphToolGroup.selected
+ val previousToolButton = GraphToolGroup.selected
publish(MouseStateChanged(m))
m match {
+ case FreehandTool(_,_) =>
+ VertexTypeLabel.enabled = false
+ VertexTypeSelect.enabled = false
+ EdgeTypeLabel.enabled = false
+ EdgeTypeSelect.enabled = false
+ EdgeDirected.enabled = false
+ GraphToolGroup.select(FreehandButton)
case SelectTool() =>
VertexTypeLabel.enabled = false
VertexTypeSelect.enabled = false
@@ -80,7 +120,7 @@ class GraphEditControls(theory: Theory) extends Publisher {
EdgeDirected.enabled = false
GraphToolGroup.select(SelectButton)
case AddVertexTool() =>
- if(previousTool.nonEmpty && previousTool.get == AddVertexButton){
+ if(previousToolButton.nonEmpty && previousToolButton.get == AddVertexButton){
//VertexTypeSelect.selection.index = (VertexTypeSelect.selection.index + 1) % vertexOptions.size
}
VertexTypeLabel.enabled = true
@@ -90,7 +130,7 @@ class GraphEditControls(theory: Theory) extends Publisher {
EdgeDirected.enabled = false
GraphToolGroup.select(AddVertexButton)
case AddEdgeTool() =>
- if(previousTool.nonEmpty && previousTool.get == AddEdgeButton){
+ if(previousToolButton.nonEmpty && previousToolButton.get == AddEdgeButton){
//EdgeTypeSelect.selection.index = (EdgeTypeSelect.selection.index + 1) % edgeOptions.size
}
VertexTypeLabel.enabled = false
@@ -123,9 +163,38 @@ class GraphEditControls(theory: Theory) extends Publisher {
setMouseState(t.tool)
}
+ // These are all different, because they need to not take focus away from the view
+ listenTo(RelaxButton.mouse.clicks, MinimiseButton.mouse.clicks, FocusGraphButton.mouse.clicks)
+
+ reactions += {
+ case MousePressed(RelaxButton,_,_,_,_) =>
+ RelaxButton.selected = false
+ publish(MouseStateChanged(RelaxToolDown()))
+ case MouseReleased(RelaxButton,_,_,_,_) =>
+ RelaxButton.selected= false
+ publish(MouseStateChanged(RelaxToolUp()))
+ case MousePressed(FocusGraphButton,_,_,_,_) =>
+ FocusGraphButton.selected = false
+ publish(MouseStateChanged(RequestFocusOnGraph()))
+ case MouseReleased(FocusGraphButton,_,_,_,_) =>
+ FocusGraphButton.selected= false
+ publish(MouseStateChanged(RequestFocusOnGraph()))
+ case MousePressed(MinimiseButton,_,_,_,_) =>
+ MinimiseButton.selected = false
+ publish(MouseStateChanged(RequestMinimiseGraph()))
+ case MouseReleased(MinimiseButton,_,_,_,_) =>
+ MinimiseButton.selected= false
+ publish(MouseStateChanged(RequestMinimiseGraph()))
+ }
+
val MainToolBar = new ToolBar {
- contents += (SelectButton, AddVertexButton, AddBoundaryButton, AddEdgeButton, AddBangBoxButton)
+ contents += (SelectButton, AddVertexButton, AddBoundaryButton, AddEdgeButton, AddBangBoxButton, FreehandButton)
}
+ MainToolBar.peer.addSeparator()
+ MainToolBar.contents += FocusGraphButton
+ MainToolBar.contents += RelaxButton
+ MainToolBar.contents += MinimiseButton
+
}
@@ -168,7 +237,15 @@ with HasDocument
case UIElementResized(GraphViewScrollPane) =>
graphView.resizeViewToFit()
graphView.repaint()
+ case MouseStateChanged(RelaxToolDown()) => graphEditController.startRelaxGraph(true)
+ case MouseStateChanged(RelaxToolUp()) => graphEditController.endRelaxGraph()
+ case MouseStateChanged(RequestMinimiseGraph()) => graphEditController.minimiseGraph()
+ case MouseStateChanged(RequestFocusOnGraph()) => graphEditController.focusOnGraph()
case MouseStateChanged(m) =>
+ if (graphEditController.rDown) graphEditController.endRelaxGraph()
graphEditController.mouseState = m
- }
+ case DocumentRequestingNaturalFocus(d) =>
+ graphView.requestFocus()
+ }
+
}
diff --git a/scala/src/main/scala/quanto/gui/JsonConsole.scala b/scala/src/main/scala/quanto/gui/JsonConsole.scala
index d48d651a..2086949a 100644
--- a/scala/src/main/scala/quanto/gui/JsonConsole.scala
+++ b/scala/src/main/scala/quanto/gui/JsonConsole.scala
@@ -94,7 +94,7 @@ object JsonConsole extends SimpleSwingApplication {
border = new LineBorder(Color.WHITE, 1)
}
- def componentFor(list: ListView[_], isSelected: Boolean,
+ override def componentFor(list: ListView[_ <: CoreOutputItem], isSelected: Boolean,
focused: Boolean, a: CoreOutputItem, index: Int): Component =
{
@@ -114,6 +114,7 @@ object JsonConsole extends SimpleSwingApplication {
panel
}
+
}
}
diff --git a/scala/src/main/scala/quanto/gui/MouseState.scala b/scala/src/main/scala/quanto/gui/MouseState.scala
index 6748cdab..35d53fe9 100644
--- a/scala/src/main/scala/quanto/gui/MouseState.scala
+++ b/scala/src/main/scala/quanto/gui/MouseState.scala
@@ -61,4 +61,14 @@ case class DragEdge(startVertex: VName) extends MouseState
case class AddBangBoxTool() extends MouseState
/** A nesting edge is being dragged from the bang box corner */
-case class DragBangBoxNesting(startBBox: BBName) extends MouseState
\ No newline at end of file
+case class DragBangBoxNesting(startBBox: BBName) extends MouseState
+
+/** The relax tool is pressed down or released again */
+case class RelaxToolDown() extends MouseState
+case class RelaxToolUp() extends MouseState
+
+// The freehand tool is being used
+case class FreehandTool(start: Option[VName], startedWithNew: Boolean) extends MouseState
+
+case class RequestMinimiseGraph() extends MouseState
+case class RequestFocusOnGraph() extends MouseState
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/gui/NewProjectDialog.scala b/scala/src/main/scala/quanto/gui/NewProjectDialog.scala
index d947f634..c762def1 100644
--- a/scala/src/main/scala/quanto/gui/NewProjectDialog.scala
+++ b/scala/src/main/scala/quanto/gui/NewProjectDialog.scala
@@ -15,7 +15,12 @@ class NewProjectDialog extends Dialog {
val NameField = new TextField()
val ProjectLocationField = new TextField(System.getProperty("user.home"))
val BrowseProjectButton = new Button("...")
- val TheoryChoiceDropdown = new ComboBox(Seq[String]("ZX", "ZW", "From existing project", "From .qtheory file"))
+ val TheoryChoiceDropdown = new ComboBox(Seq[String](
+ "ZX",
+ "ZW",
+ "From existing project",
+ "From .qtheory file",
+ "plain"))
val TheoryLocationField = new TextField("")
val BrowseTheoryButton = new Button("...")
val theoryName = new TextField("")
@@ -102,7 +107,7 @@ class NewProjectDialog extends Dialog {
fileChoiceFilter = filter
}
- disableFileChoosers("red_green")
+ disableFileChoosers("ZX")
reactions += {
case ButtonClicked(CreateButton) =>
@@ -111,11 +116,11 @@ class NewProjectDialog extends Dialog {
val path = ProjectLocationField.text
val folder = new File(path + "/" + name)
if (name.isEmpty) {
- UserAlerts.errorbox("Please enter a name for your project.")
+ UserAlerts.errorBox("Please enter a name for your project.")
} else if (folder.exists()) {
- UserAlerts.errorbox("That folder is already in use.")
+ UserAlerts.errorBox("That folder is already in use.")
} else if (theory.isEmpty) {
- UserAlerts.errorbox("Please choose a theory.")
+ UserAlerts.errorBox("Please choose a theory.")
} else {
result = Some((theory, name, path))
close()
@@ -143,9 +148,11 @@ class NewProjectDialog extends Dialog {
case SelectionChanged(TheoryChoiceDropdown) =>
TheoryChoiceDropdown.selection.item match {
case "ZX" =>
- disableFileChoosers("red_green")
+ disableFileChoosers("ZX")
case "ZW" =>
- disableFileChoosers("black_white")
+ disableFileChoosers("ZW")
+ case "plain" =>
+ disableFileChoosers("plain")
case "From existing project" =>
enableFileChoosers("Project files", "qproject")
case "From .qtheory file" =>
diff --git a/scala/src/main/scala/quanto/gui/PythonEditPanel.scala b/scala/src/main/scala/quanto/gui/PythonEditPanel.scala
index 3208443c..3d4bb3b4 100644
--- a/scala/src/main/scala/quanto/gui/PythonEditPanel.scala
+++ b/scala/src/main/scala/quanto/gui/PythonEditPanel.scala
@@ -11,50 +11,35 @@ import java.awt.event.{KeyAdapter, KeyEvent}
import javax.swing.ImageIcon
import quanto.util.swing.ToolBar
-import quanto.util.UserAlerts.{alert, Elevation, SelfAlertingProcess}
+import quanto.util.UserAlerts.{Elevation, SelfAlertingProcess, alert}
-import scala.swing.event.ButtonClicked
+import scala.swing.event.{ButtonClicked, Event}
import quanto.util._
import java.io.{File, PrintStream}
-class PythonEditPanel extends BorderPanel with HasDocument {
- val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask
-
- val pyMode = new Mode("Python")
-
- //val modeXml =
- // if (Globals.isBundle) new File("python.xml").getAbsolutePath
- // else getClass.getResource("python.xml").getPath
- pyMode.setProperty("file", QuantoDerive.pythonModeFile)
- //println(sml.getProperty("file"))
- val code = StandaloneTextArea.createTextArea()
- code.setFont(UserOptions.font)
- //mlCode.setFont(new Font("Menlo", Font.PLAIN, 14))
+import quanto.rewrite.Simproc
- val buf = new JEditBuffer1
- buf.setMode(pyMode)
+case class SimprocsUpdated() extends Event
- var execThread : Thread = null
+class PythonEditPanel extends BorderPanel with HasDocument {
+ val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask
- code.setBuffer(buf)
+ val CodeArea : TextEditor = new TextEditor(TextEditor.Modes.python)
- code.addKeyListener(new KeyAdapter {
- override def keyPressed(e: KeyEvent) {
- if (e.getModifiers == CommandMask) e.getKeyChar match {
- case 'x' => Registers.cut(code, '$')
- case 'c' => Registers.copy(code, '$')
- case 'v' => Registers.paste(code, '$')
- case _ =>
- }
- }
- })
+ // Inject python to expose some relevant variables
+ def documentName: String = document.file.map(
+ // Assume the user has a project loaded, otherwise shouldn't be able to access GUI
+ f => QuantoDerive.CurrentProject.get.relativePath(f)
+ ).getOrElse("Unsaved File")
- val document = new CodeDocument("Python Script", "py", this, code)
+ // Now run the python along with the header
+ def code : String = CodeArea.getText
+ val document = new CodeDocument("Python Script", "py", this, CodeArea.TextArea)
+ listenTo(document)
- val textPanel = new BorderPanel {
- peer.add(code, BorderLayout.CENTER)
- }
+ var execThread : Thread = null
+ val textPanel = CodeArea.Component
val RunButton = new Button() {
icon = new ImageIcon(GraphEditor.getClass.getResource("start.png"), "Run scala code")
@@ -87,20 +72,40 @@ class PythonEditPanel extends BorderPanel with HasDocument {
listenTo(RunButton, InterruptButton)
+ def allSimprocs : Map[String, Simproc] = QuantoDerive.CurrentProject.map(p => p.simprocs).getOrElse(Map())
reactions += {
+ case DocumentRequestingNaturalFocus(_) =>
+ CodeArea.TextArea.requestFocus()
case ButtonClicked(RunButton) =>
if (execThread == null) {
- val processReporting = new SelfAlertingProcess("Python from source")
+ val processReporting = new SelfAlertingProcess(s"Python $documentName")
execThread = new Thread(new Runnable {
def run() {
try {
val python = new PythonInterpreter
+
+ QuantoDerive.CurrentProject.foreach(project => project.lastRunPythonFilePath = Some(documentName))
+ def simprocsFromThisFile = allSimprocs.filter(kv => kv._2.sourceFile == documentName).keys
+ // unregister any simprocs previously linked to this file
+ simprocsFromThisFile.foreach(simprocName => QuantoDerive.CurrentProject.foreach(
+ p => p.simprocs -= simprocName
+ ))
+
QuantoDerive.CurrentProject.foreach(pr => python.getSystemState.path.add(pr.rootFolder))
+ outputTextArea.text = ""
python.set("output", output)
-
- //python.set("output", output)
- python.exec(code.getBuffer.getText)
+ python.exec(code)
+
+ // Tell the user which simprocs are linked to this file
+ alert(s"Simprocs registered to $documentName: " +
+ simprocsFromThisFile.mkString(", ")
+ )
+ // Link this python to those simprocs
+ simprocsFromThisFile.foreach(simprocName => QuantoDerive.CurrentProject.foreach(
+ p => p.simprocs(simprocName).sourceCode = code
+ ))
+ PythonEditPanel.publishUpdate()
processReporting.finish()
} catch {
case e : Throwable =>
@@ -124,3 +129,9 @@ class PythonEditPanel extends BorderPanel with HasDocument {
}
}
}
+
+object PythonEditPanel extends Publisher {
+ def publishUpdate() : Unit = {
+ publish(SimprocsUpdated())
+ }
+}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/gui/QuantoDerive.scala b/scala/src/main/scala/quanto/gui/QuantoDerive.scala
index 1e4ac6c4..3ccf26e7 100644
--- a/scala/src/main/scala/quanto/gui/QuantoDerive.scala
+++ b/scala/src/main/scala/quanto/gui/QuantoDerive.scala
@@ -5,13 +5,16 @@ import org.python.util.PythonInterpreter
import scala.io.Source
import scala.swing._
-import scala.swing.event.{Key, SelectionChanged}
-import javax.swing.{KeyStroke, UIManager}
+import scala.swing.event.{Key, KeyPressed, SelectionChanged}
+import javax.swing.{JOptionPane, KeyStroke, SwingUtilities, UIManager}
import java.awt.event.KeyEvent
+import java.awt.Frame
+import java.awt.event.{KeyEvent, MouseAdapter, MouseEvent}
-import quanto.util.json.{Json, JsonString}
+import quanto.util.json.{Json, JsonObject, JsonString}
import quanto.data._
import java.io.{File, FilenameFilter, IOException, PrintWriter}
+
import javax.swing.plaf.metal.MetalLookAndFeel
import java.util.prefs.Preferences
@@ -25,44 +28,69 @@ import akka.actor.PoisonPill
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext
import ExecutionContext.Implicits.global
-import java.awt.{Color, Window}
-import javax.swing.SwingUtilities
+import java.awt.{Color, Desktop, Window}
+import java.lang.NullPointerException
+
+import javax.imageio.ImageIO
+import javax.swing.filechooser.FileNameExtensionFilter
+import quanto.gui.QuantoDerive.FileMenu.mnemonic
+import quanto.gui.FileOpened
+import quanto.util._
-import quanto.util.{Globals, UserAlerts, UserOptions, WebHelper}
+class NoProjectException extends Exception("No project open.")
object QuantoDerive extends SimpleSwingApplication {
+ override def main(args: Array[String]): Unit = {
+ super.main(args)
+ if (args.length > 0) {
+ alert("Loading project from commandline")
+ val arg = args(0)
+ loadProject(arg)
+ }
+ if (args.length > 1) {
+ alert("Opening file from commandline")
+ val fname = args(1)
+ ProjectFileTree.publish(FileOpened(new File(fname)))
+ }
+ }
+
+ System.setProperty("apple.eawt.quitStrategy", "CLOSE_ALL_WINDOWS")
val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask
val actorSystem = ActorSystem("QuantoDerive")
//val core = actorSystem.actorOf(Props { new Core }, "core")
implicit val timeout = Timeout(1.day)
- // copy python mode xml into a temp file, as jEdit component can't handle JAR resources
- lazy val pythonModeFile = {
- val f = File.createTempFile("python", "xml")
- f.deleteOnExit()
- val pr = new PrintWriter(f)
- Source.fromInputStream(getClass.getResourceAsStream("python.xml")).foreach(pr.print)
- pr.close()
- f.getCanonicalPath
- }
-
// pre-initialise jython, so its zippy when the user clicks "run" in a script
new Thread(new Runnable { def run() { new PythonInterpreter() }}).start()
- println(new File(".").getAbsolutePath)
+ UserAlerts.alert("Working directory: " + new File(".").getAbsolutePath)
+
+ // Dialogs in in scala.swing seem to be broken since updated scala to 2.12, so
+ // we're using the javax.swing versions instead
+ def error(msg: String) =
+ UserAlerts.errorBox(msg)
- def error(msg: String) = Dialog.showMessage(
- title = "Error", message = msg, messageType = Dialog.Message.Error)
+ def alert(msg: String) =
+ UserAlerts.alert(msg)
+ def warn(msg: String) =
+ UserAlerts.alert(msg, UserAlerts.Elevation.WARNING)
+
+ def uiScale(i : Int) : Int = UserOptions.scaleInt(i)
+
+ //Dialog.showMessage(title = "Error", message = msg, messageType = Dialog.Message.Error)
val prefs = Preferences.userRoot().node(this.getClass.getName)
try {
UIManager.setLookAndFeel(new MetalLookAndFeel) // tabs in OSX PLAF look bad
- UserOptions.uiScale = prefs.getDouble("uiScale", 1.0)
+ UserOptions.uiScale = UserOptions.uiScale // Initiliases all the UI options
} catch {
- case e: Exception => e.printStackTrace()
+ case e: Exception => {
+ UserAlerts.alert("Could not load UI preferences on startup.")
+ e.printStackTrace()
+ }
}
def unloadProject() {
@@ -70,22 +98,44 @@ object QuantoDerive extends SimpleSwingApplication {
ProjectFileTree.root = None
}
- def loadProject(projectLocation: String) : Option[Project] = {
- UserAlerts.alert(s"Opening project: $projectLocation")
- val projectFile = new File(projectLocation + "/main.qproject")
+ def updateProjectFile(projectFile: File): Unit = {
+ if (CurrentProject.nonEmpty) {
+ val project = CurrentProject.get
+ try {
+ if (projectFile.exists) {
+ val parsedInput = Json.parse(projectFile)
+ if (Project.toJson(project).toString != parsedInput.toString) {
+ Project.toJson(project).writeTo(project.projectFile)
+ UserAlerts.alert(s"Updated project file", UserAlerts.Elevation.DEBUG)
+ }
+ }
+ } catch {
+ case e: Exception =>
+ throw new ProjectLoadException("Error loading project", e)
+ }
+ }
+ }
+
+ def loadProject(projectFileLocation: String) : Option[Project] = {
+ alert(s"Opening project: $projectFileLocation")
+
+ val projectFile = if(new File(projectFileLocation).isDirectory){
+ new File(projectFileLocation + "/main.qproject")
+ } else {
+ new File(projectFileLocation)
+ }
try {
if (projectFile.exists) {
val parsedInput = Json.parse(projectFile)
- val project = Project.fromJson(parsedInput, projectLocation)
+ val project = Project.fromJson(parsedInput, new File(projectFileLocation))
// Old .qproject files had links rather than embedded theories
- if (Project.toJson(project).toString != parsedInput.toString) {
- UserAlerts.alert("Updating out of date .qproject file")
- Project.toJson(project).writeTo(new File(projectLocation + "/main.qproject"))
- }
+ // So update when loading in
CurrentProject = Some(project)
- ProjectFileTree.root = Some(projectLocation)
- prefs.put("lastProjectFolder", projectLocation)
- UserAlerts.alert(s"Successfully loaded project: $projectLocation")
+ updateProjectFile(projectFile)
+ ProjectFileTree.root = Some(project.rootFolder)
+ prefs.put("lastProjectFile", projectFileLocation)
+ UserAlerts.registerLogFile(Some(new File(project.rootFolder + s"/${project.name}_log.txt")))
+ alert(s"Successfully loaded project: $projectFileLocation")
Some(project)
} else {
UserAlerts.alert("Selected project file does not exist", UserAlerts.Elevation.ERROR)
@@ -98,16 +148,16 @@ object QuantoDerive extends SimpleSwingApplication {
unloadProject()
None
} finally {
- FileMenu.updateNewEnabled()
+ refreshAllMenusAndTitle()
}
}
//CurrentProject.map { pr => core ! SetMLWorkingDir(pr.rootFolder) }
val ProjectFileTree = new FileTree
- ProjectFileTree.preferredSize = new Dimension(250,360)
+ ProjectFileTree.preferredSize = new Dimension(uiScale(250), uiScale(360))
ProjectFileTree.filenameFilter = Some(new FilenameFilter {
- val extns = Set("qgraph", "qrule", "qderive", "ML", "py")
+ val extns = Set("qgraph", "qrule", "qderive", "ML", "py", "qsbr")
def accept(parent: File, name: String) = {
val extn = name.lastIndexOf('.') match {
case i if i > 0 => name.substring(i+1) ; case _ => ""}
@@ -120,7 +170,7 @@ object QuantoDerive extends SimpleSwingApplication {
})
- var CurrentProject : Option[Project] = prefs.get("lastProjectFolder", null) match {
+ var CurrentProject : Option[Project] = prefs.get("lastProjectFile", null) match {
case path : String =>
try {
loadProject(path)
@@ -132,6 +182,28 @@ object QuantoDerive extends SimpleSwingApplication {
case _ => None
}
+ // Access via even is preferred, as then we can pinpoint where to put the popup
+ def popup(menu: PopupMenu, e: Option[MouseEvent]) : Unit = {
+ if (e.nonEmpty){
+ val componentBounds = e.get.getComponent.getBounds
+ val shift : Int = UserOptions.scaleInt(5)
+ popup(menu, e.get.getX + componentBounds.x + shift, e.get.getY + componentBounds.y+ shift)
+ } else {
+ popup(menu, 0, 0)
+ }
+ }
+
+ def popup(menu: PopupMenu, x: Int, y: Int) : Unit = {
+ menu.show(Main, x, y)
+ }
+
+ def addAndFocusPage(d : DocumentPage): Unit = {
+ MainDocumentTabs += d
+ listenTo(d.tabComponent)
+ MainDocumentTabs.focus(d)
+ d.document.publish(DocumentChanged(d.document))
+ d.document.focusOnNaturalComponent()
+ }
listenTo(quanto.util.UserOptions.OptionsChanged)
reactions += {
@@ -144,16 +216,12 @@ object QuantoDerive extends SimpleSwingApplication {
}
}
- val MainTabbedPane = new ClosableTabbedPane
+ val MainDocumentTabs = new DocumentTabs
- def currentDocument: Option[HasDocument] =
- MainTabbedPane.currentContent match {
- case Some(doc: HasDocument) => Some(doc)
- case _ => None
- }
+ def currentDocument: Option[DocumentPage] = MainDocumentTabs.currentFocus
def currentGraphController: Option[GraphEditController] =
- MainTabbedPane.currentContent match {
+ MainDocumentTabs.currentContent match {
case Some(p: GraphEditPanel) => Some(p.graphEditController)
case Some(p: RuleEditPanel) => Some(p.focusedController)
case _ => None
@@ -185,17 +253,18 @@ object QuantoDerive extends SimpleSwingApplication {
def histView = _histView
object LeftSplit extends SplitPane {
+ resizeWeight = 0.5
orientation = Orientation.Horizontal
contents_=(ProjectFileTree, HistViewSlot)
}
object Split extends SplitPane {
orientation = Orientation.Vertical
- contents_=(LeftSplit, MainTabbedPane)
+ contents_=(LeftSplit, MainDocumentTabs.component)
}
def hasUnsaved =
- MainTabbedPane.pages.exists { p => p.content match {
+ MainDocumentTabs.documents.exists { p => p.content match {
case c : HasDocument => c.document.unsavedChanges
case _ => false
}}
@@ -207,8 +276,8 @@ object QuantoDerive extends SimpleSwingApplication {
* (depends on user choice)
*/
def trySaveAll() = {
- MainTabbedPane.pages.forall { p =>
- MainTabbedPane.selection.index = p.index // focus a pane before saving
+ MainDocumentTabs.documents.forall { p =>
+ MainDocumentTabs.focus(p) // focus a pane before saving
p.content match {
case c : HasDocument => c.document.trySave()
case _ => false
@@ -218,43 +287,79 @@ object QuantoDerive extends SimpleSwingApplication {
/**
* Show a dialog (when necessary) asking the user if the program should quit
+ *@param specific : Specify a list to close, or None to close all
* @return true if the program should quit, false otherwise
*/
- def closeAllDocuments() = {
+ def closeAllOrListOfDocuments(specific: Option[List[DocumentPage]] = None) : Boolean = {
if (hasUnsaved) {
- val choice = Dialog.showOptions(
- title = "Confirm quit",
- message = "Some documents have unsaved changes.\nDo you want to save your changes or discard them?",
- entries = "Save" :: "Discard" :: "Cancel" :: Nil,
- initial = 0
- )
+// val choice = Dialog.showOptions(
+// title = "Confirm quit",
+// message = "Some documents have unsaved changes.\nDo you want to save your changes or discard them?",
+// entries = "Save" :: "Discard" :: "Cancel" :: Nil,
+// initial = 0
+// )
+
+ val choice = JOptionPane.showOptionDialog(null,
+ "Do you want to save your changes or discard them?",
+ "Unsaved changes",
+ JOptionPane.DEFAULT_OPTION,
+ JOptionPane.WARNING_MESSAGE, null,
+ List("Save", "Discard", "Cancel").toArray,
+ "Save")
+
// scala swing dialogs implementation is dumb, here's what I found :
// Result(0) = Save, Result(1) = Discard, Result(2) = Cancel
- if (choice == Dialog.Result(2)) false
- else if (choice == Dialog.Result(1)) {
- MainTabbedPane.pages.clear()
+ if (choice == 2) false
+ else if (choice == 1) {
+ if(specific.nonEmpty){
+ for(page <- specific.get) {MainDocumentTabs.remove(page)}
+ } else {
+ MainDocumentTabs.clear()
+ }
true
}
else {
val b = trySaveAll()
- if (b) MainTabbedPane.pages.clear()
+ if (b) {
+ if (specific.nonEmpty) {
+ for (page <- specific.get) {
+ MainDocumentTabs.remove(page)
+ }
+ } else {
+ MainDocumentTabs.clear()
+ }
+ }
b
}
}
else {
- MainTabbedPane.pages.clear()
+ if(specific.nonEmpty){
+ for(page <- specific.get) {MainDocumentTabs.remove(page)}
+ } else {
+ MainDocumentTabs.clear()
+ }
true
}
}
- def quitQuanto() = {
- if (closeAllDocuments()) {
+ def quitQuanto(): Boolean = {
+ val close = closeAllOrListOfDocuments()
+ if (close) {
try {
//core ! StopCore
//core ! PoisonPill
} catch {
case e : Exception => e.printStackTrace()
}
+ val rect = _mainframe.peer.getBounds()
+ val isFullScreen : Boolean = _mainframe.peer.getExtendedState() == Frame.MAXIMIZED_BOTH
+ prefs.putBoolean("fullscreen", isFullScreen)
+ if (!isFullScreen) {
+ prefs.putInt("locationx",rect.x)
+ prefs.putInt("locationy",rect.y)
+ prefs.putInt("screenwidth",rect.width)
+ prefs.putInt("screenheight",rect.height)
+ }
true
} else {
false
@@ -274,6 +379,79 @@ object QuantoDerive extends SimpleSwingApplication {
// }
+ def FolderContextMenu(folder: File) : PopupMenu = new PopupMenu { //Context menu for project folders
+ menu =>
+
+ val OpenLocationAction: Action = new Action("Open Folder") {
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.L
+ }
+
+ def apply() {
+ Desktop.getDesktop.browse(folder.toURI)
+ }
+ }
+
+ }
+
+ def FileContextMenu(file: File): PopupMenu = new PopupMenu { //Context menu for project files
+ menu =>
+
+ val OpenLocationAction: Action = new Action("Open File Location") {
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.L
+ }
+
+ def apply() {
+ Desktop.getDesktop.browse(file.getParentFile.toURI)
+ }
+ }
+
+
+ val DeleteFile: Action = new Action("Delete File") {
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.D
+ }
+
+ def apply() {
+ file.delete()
+ }
+ }
+
+ (FileHelper.extension(file), MainDocumentTabs.currentContent) match {
+ case ("qrule", Some(dp: DerivationPanel)) =>
+ val AddToRewrites : Action = new Action("Add to current derivation") {
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.R
+ }
+
+ def apply() {
+ alert(s"Publishing request for rule")
+ if(CurrentProject.nonEmpty){
+ val project = CurrentProject.get
+ val relativePath = project.relativePath(file)
+ val ruleDesc = RuleDesc(relativePath.substring(0, relativePath.length-".qrule".length))
+ dp.publish(SuggestRewriteRule(ruleDesc))
+ }
+ }
+ }
+ case _ =>
+ }
+ }
+
+
+ def newGraph(): GraphDocument = {
+ CurrentProject match {
+ case Some(project) =>
+ val page = new GraphDocumentPage(project.theory)
+ addAndFocusPage(page)
+ page.document.asInstanceOf[GraphDocument]
+ case None =>
+ throw new NoProjectException
+ }
+ }
+
+
object FileMenu extends Menu("File") { menu =>
mnemonic = Key.F
@@ -281,11 +459,7 @@ object QuantoDerive extends SimpleSwingApplication {
accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_N, CommandMask))
menu.contents += new MenuItem(this) { mnemonic = Key.G }
def apply() {
- CurrentProject.foreach{ project =>
- val page = new GraphDocumentPage(project.theory)
- MainTabbedPane += page
- MainTabbedPane.selection.index = page.index
- }
+ newGraph()
}
}
@@ -295,8 +469,7 @@ object QuantoDerive extends SimpleSwingApplication {
def apply() {
CurrentProject.foreach{ project =>
val page = new RuleDocumentPage(project.theory)
- MainTabbedPane += page
- MainTabbedPane.selection.index = page.index
+ addAndFocusPage(page)
}
}
}
@@ -319,33 +492,18 @@ object QuantoDerive extends SimpleSwingApplication {
def apply() {
CurrentProject.foreach{ project =>
val page = new PythonDocumentPage
- MainTabbedPane += page
- MainTabbedPane.selection.index = page.index
+ addAndFocusPage(page)
}
}
}
- def updateNewEnabled() {
- CurrentProject match {
- case Some(_) =>
- NewGraphAction.enabled = true
- NewAxiomAction.enabled = true
- //NewMLAction.enabled = true
- case None =>
- NewGraphAction.enabled = false
- NewAxiomAction.enabled = false
- //NewMLAction.enabled = false
- }
- }
-
- updateNewEnabled()
val SaveAction = new Action("Save") {
accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_S, CommandMask))
enabled = false
menu.contents += new MenuItem(this) { mnemonic = Key.S }
def apply() {
- MainTabbedPane.currentContent match {
+ MainDocumentTabs.currentContent match {
case Some(doc: HasDocument) =>
doc.document.file match {
case Some(_) => doc.document.save()
@@ -361,7 +519,7 @@ object QuantoDerive extends SimpleSwingApplication {
enabled = false
menu.contents += new MenuItem(this) { mnemonic = Key.A }
def apply() {
- MainTabbedPane.currentContent match {
+ MainDocumentTabs.currentContent match {
case Some(doc: HasDocument) =>
doc.document.showSaveAsDialog(CurrentProject.map(_.rootFolder))
case _ =>
@@ -374,9 +532,9 @@ object QuantoDerive extends SimpleSwingApplication {
enabled = false
menu.contents += new MenuItem(this) { mnemonic = Key.V }
def apply() {
- val selection = MainTabbedPane.selection.index
+ val selection = MainDocumentTabs.selection.index
trySaveAll()
- MainTabbedPane.selection.index = selection
+ MainDocumentTabs.selection.index = selection
}
}
@@ -386,7 +544,7 @@ object QuantoDerive extends SimpleSwingApplication {
menu.contents += new MenuItem(this) { mnemonic = Key.N }
def apply() {
- if (closeAllDocuments()) {
+ if (closeAllOrListOfDocuments()) {
val d = new NewProjectDialog()
d.centerOnScreen()
d.open()
@@ -395,11 +553,11 @@ object QuantoDerive extends SimpleSwingApplication {
println("got: " + (theoryFile, name, path))
val folder = new File(path + "/" + name)
if (name.isEmpty) {
- UserAlerts.errorbox("Please enter a name for your project.")
+ error("Please enter a name for your project.")
} else if (folder.exists()) {
- UserAlerts.errorbox("That folder is already in use.")
+ error("That folder is already in use.")
} else if (theoryFile.isEmpty) {
- UserAlerts.errorbox("Please enter a theory file.")
+ error("Please enter a theory file.")
} else {
folder.mkdirs()
new File(folder.getPath + "/graphs").mkdir()
@@ -407,12 +565,13 @@ object QuantoDerive extends SimpleSwingApplication {
new File(folder.getPath + "/theorems").mkdir()
new File(folder.getPath + "/derivations").mkdir()
new File(folder.getPath + "/simprocs").mkdir()
+ val projectFile = new File(folder.getPath + "/" + name + ".qproject")
val rootFolder = folder.getAbsolutePath
- val proj = Project.fromTheoryOrProjectFile(theoryFile, rootFolder, name)
- Project.toJson(proj).writeTo(new File(folder.getPath + "/main.qproject"))
- loadProject(folder.getPath)
+ val proj = Project.fromTheoryOrProjectFile(new File(theoryFile), new File(rootFolder), name)
+ Project.toJson(proj).writeTo(projectFile)
+ loadProject(projectFile.getAbsolutePath)
//core ! SetMLWorkingDir(rootFolder)
- updateNewEnabled()
+ refreshAllMenusAndTitle()
}
case None =>
}
@@ -423,16 +582,16 @@ object QuantoDerive extends SimpleSwingApplication {
val OpenProjectAction = new Action("Open Project...") {
menu.contents += new MenuItem(this) { mnemonic = Key.O }
def apply() {
- if (closeAllDocuments()) {
+ if (closeAllOrListOfDocuments()) {
val chooser = new FileChooser()
- chooser.fileSelectionMode = FileChooser.SelectionMode.DirectoriesOnly
+ chooser.fileFilter = new FileNameExtensionFilter("Quantomatic Project File (*.qproject)", "qproject")
+ chooser.fileSelectionMode = FileChooser.SelectionMode.FilesOnly
chooser.showOpenDialog(Split) match {
case FileChooser.Result.Approve =>
- val rootFolder = chooser.selectedFile.toString
- val projectFile = new File(rootFolder + "/main.qproject")
+ val projectFile = new File(chooser.selectedFile.toString)
if (projectFile.exists) {
try {
- loadProject(rootFolder)
+ loadProject(chooser.selectedFile.toString)
//core ! SetMLWorkingDir(rootFolder)
} catch {
case _: ProjectLoadException =>
@@ -441,10 +600,10 @@ object QuantoDerive extends SimpleSwingApplication {
error("Unexpected error when opening project")
e.printStackTrace()
} finally {
- updateNewEnabled()
+ refreshAllMenusAndTitle()
}
} else {
- error("Folder does not contain a Quantomatic project")
+ error(s"Folder does not contain a Quantomatic project: $projectFile")
}
case _ =>
}
@@ -455,14 +614,15 @@ object QuantoDerive extends SimpleSwingApplication {
val CloseProjectAction = new Action("Close Project") {
menu.contents += new MenuItem(this) { mnemonic = Key.C }
def apply() {
- if (closeAllDocuments()) {
+ if (closeAllOrListOfDocuments()) {
ProjectFileTree.root = None
CurrentProject = None
- updateNewEnabled()
+ refreshAllMenusAndTitle()
}
}
}
+
menu.contents += new Separator()
val QuitAction = new Action("Quit") {
@@ -489,15 +649,17 @@ object QuantoDerive extends SimpleSwingApplication {
case Some(doc) =>
enabled = doc.document.undoStack.canUndo
title = "Undo " + doc.document.undoStack.undoActionName.getOrElse("")
+ //listenTo(doc.document)
case None =>
enabled = false
title = "Undo"
}
- listenTo(MainTabbedPane.selection)
+ listenTo(MainDocumentTabs.selection)
reactions += {
- case DocumentChanged(_) => updateUndoCommand()
+ case DocumentChanged(_) =>
+ updateUndoCommand()
case SelectionChanged(_) =>
currentDocument.foreach { doc => listenTo(doc.document) }
updateUndoCommand()
@@ -524,7 +686,7 @@ object QuantoDerive extends SimpleSwingApplication {
title = "Redo"
}
- listenTo(MainTabbedPane.selection)
+ listenTo(MainDocumentTabs.selection)
reactions += {
case DocumentChanged(_) => updateRedoCommand()
@@ -554,13 +716,6 @@ object QuantoDerive extends SimpleSwingApplication {
def apply() { currentGraphController.foreach(_.pasteSubgraph()) }
}
- contents += new Separator
-
- val SnapToGridAction = new Action("Snap to grid") {
- menu.contents += new MenuItem(this) { mnemonic = Key.S }
- //accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_G, CommandMask))
- def apply() { currentGraphController.foreach(_.snapToGrid()) }
- }
// val LayoutAction = new Action("Layout Graph") with Reactor {
// accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_L, CommandMask))
@@ -575,53 +730,346 @@ object QuantoDerive extends SimpleSwingApplication {
// contents += new MenuItem(LayoutAction) { mnemonic = Key.L }
}
- val DeriveMenu = new Menu("Derive") { menu =>
+ val RuleMenu = new Menu("Rule") {
+ menu =>
+ mnemonic = Key.R
+
+ val ColourSwapRule = new Action("Colour Swap") {
+ accelerator = None
+ enabled = true
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.C
+ }
+
+ def apply(): Unit = (CurrentProject, MainDocumentTabs.currentContent) match {
+ case (Some(project), Some(doc: HasDocument)) =>
+ doc.document match {
+ case ruleDoc: RuleDocument =>
+ val dialog = new ColourSwapDialog(project.theory)
+ dialog.centerOnScreen()
+ dialog.open()
+ dialog.result
+ if (dialog.result != project.theory.vertexTypes.keys.map(k => k -> k).toMap) {
+ val map = dialog.result
+ warn(map.mkString("Mapping types: ",", ",""))
+ val page = new RuleDocumentPage(project.theory)
+ page.document.asInstanceOf[RuleDocument].rule = ruleDoc.rule.colourSwap(map)
+ addAndFocusPage(page)
+ }
+ case _ =>
+ warn("Trying to colour swap a rule but no rule is active")
+
+ }
+ case _ => // no project and/or document open, do nothing
+ }
+ }
+
+ val InvertRule = new Action("Invert Rule") {
+ accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_I, CommandMask))
+ enabled = true
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.I
+ }
+
+ def apply() = (CurrentProject, MainDocumentTabs.currentContent) match {
+ case (Some(project), Some(doc: HasDocument)) =>
+ doc.document match {
+ case (ruleDoc: RuleDocument) =>
+ ruleDoc.rule = ruleDoc.rule.inverse
+
+ case _ =>
+ warn("WARNING: Invert rule called with no rule active")
+ }
+ case _ => // no project and/or document open, do nothing
+ }
+ }
+ visible = false
+ }
+
+ val TheoryMenu = new Menu("Theory") {
+ menu =>
+
+ mnemonic = Key.P
+
+
+ val EditTheoryAction = new Action("Alter Theory") {
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.T
+ }
+
+ def apply() {
+ CurrentProject.foreach { project =>
+ val page = MainDocumentTabs.documents.find(tp => tp.title == "Theory Editor") match {
+ case Some(p) => p
+ case None =>
+ val p = new TheoryPage()
+ listenTo(p.document)
+ p.title = "Theory Editor"
+ addAndFocusPage(p)
+ p
+ }
+ MainDocumentTabs.focus(page)
+ }
+ }
+ }
+
+
+ val BatchDerivationAction = new Action("Batch Derivation") {
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.B
+ }
+
+ def apply() {
+ CurrentProject.foreach { project =>
+ val page = MainDocumentTabs.documents.find(tp => tp.title == "Batch Derivation") match {
+ case Some(p) => p
+ case None =>
+ val p = new BatchDerivationPage()
+ listenTo(p.document)
+ p.title = "Batch Derivation"
+ addAndFocusPage(p)
+ p
+ }
+ MainDocumentTabs.focus(page)
+ }
+ }
+ }
+
+ visible = true
+ enabled = CurrentProject.nonEmpty
+
+ }
+
+ val GraphMenu = new Menu("Graph") {
+ menu =>
+ mnemonic = Key.G
+
val StartDerivation = new Action("Start derivation") {
accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_D, CommandMask))
enabled = false
- menu.contents += new MenuItem(this) { mnemonic = Key.D }
- def apply() = (CurrentProject, MainTabbedPane.currentContent) match {
- case (Some(project), Some(doc: HasDocument)) =>
- doc.document match {
- case (graphDoc: GraphDocument) =>
- val page = new DerivationDocumentPage(project)
- MainTabbedPane += page
- MainTabbedPane.selection.index = page.index
- page.document.asInstanceOf[DerivationDocument].root = graphDoc.graph
-
- case _ =>
- System.err.println("WARNING: Start derivation called with no graph active")
- }
- case _ => // no project and/or document open, do nothing
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.D
}
+
+ def apply() = (CurrentProject, MainDocumentTabs.currentContent) match {
+ case (Some(project), Some(doc: HasDocument)) =>
+ doc.document match {
+ case (graphDoc: GraphDocument) =>
+ val page = new DerivationDocumentPage(project)
+ addAndFocusPage(page)
+ page.document.asInstanceOf[DerivationDocument].root = graphDoc.graph
+
+ case _ =>
+ warn("WARNING: Start derivation called with no graph active")
+ }
+ case _ => // no project and/or document open, do nothing
+ }
+ }
+
+ val StartRule = new Action("Make into axiom") {
+ accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_R, CommandMask))
+ enabled = false
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.X
+ }
+
+ def apply() = (CurrentProject, MainDocumentTabs.currentContent) match {
+ case (Some(project), Some(doc: HasDocument)) =>
+ doc.document match {
+ case (graphDoc: GraphDocument) =>
+ val page = new RuleDocumentPage(project.theory)
+ page.document.asInstanceOf[RuleDocument].lhsRef.graph = graphDoc.graph
+ addAndFocusPage(page)
+ case _ =>
+ warn("WARNING: Start rule called with no graph active")
+ }
+ case _ => // no project and/or document open, do nothing
+ }
+ }
+
+
+ val ExtractGraph = new Action("Extract selection to new graph") {
+ enabled = true
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.N
+ }
+
+ def apply() = (CurrentProject, MainDocumentTabs.currentContent) match {
+ case (Some(project), Some(gep: GraphEditPanel)) =>
+ gep.document match {
+ case (graphDoc: GraphDocument) =>
+ val newPage = new GraphDocumentPage(project.theory)
+ val vertSelection = gep.graphEditController.selectedVerts
+ if(vertSelection.nonEmpty) {
+ val inverseSelection = gep.graphEditController.graph.verts -- vertSelection
+ val snippedGraph = inverseSelection.foldLeft(graphDoc.graph) {
+ (g, v) => g.cutVertex(v, g.verts.filter(g.isBoundary))._1
+ }
+ newPage.document.asInstanceOf[GraphDocument].graph = snippedGraph
+ addAndFocusPage(newPage)
+ }
+ case _ =>
+ warn("WARNING: Extract selection with no graph active")
+ }
+ case (Some(project), Some(rep: RuleEditPanel)) =>
+ rep.document match {
+ case (ruleDoc: RuleDocument) =>
+ val newPage = new GraphDocumentPage(project.theory)
+ val vertSelection = rep.focusedController.selectedVerts
+ if(vertSelection.nonEmpty) {
+ val inverseSelection = rep.focusedController.graph.verts -- vertSelection
+ val snippedGraph = inverseSelection.foldLeft(rep.focusedController.graph) {
+ (g, v) => g.cutVertex(v, g.verts.filter(g.isBoundary))._1
+ }
+ newPage.document.asInstanceOf[GraphDocument].graph = snippedGraph
+ addAndFocusPage(newPage)
+ }
+ case _ =>
+ warn("WARNING: Extract selection with no graph active")
+ }
+ case _ => // no project and/or document open, do nothing
+ }
+ }
+
+
+ val SnapToGrid = new Action("Snap to grid") {
+ enabled = false
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.S}
+
+ def apply() = {
+ currentGraphController.foreach(gc => gc.snapToGrid())
+ }
+
+ }
+
+ val MinimiseGraph = new Action("Minimise") {
+ enabled = false
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.M}
+
+ def apply() = {
+ currentGraphController.foreach(gc => gc.minimiseGraph())
+ }
+
}
+ visible = false
+ }
+
+
+ val DeriveMenu = new Menu("Derivation") {
+ menu =>
+ mnemonic = Key.D
+
val LayoutDerivation = new Action("Layout derivation") {
-// accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_L, CommandMask))
+ // accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_L, CommandMask))
enabled = false
- menu.contents += new MenuItem(this) { mnemonic = Key.L }
- def apply() = (CurrentProject, MainTabbedPane.currentContent) match {
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.L
+ }
+
+ def apply() = (CurrentProject, MainDocumentTabs.currentContent) match {
case (Some(project), Some(derivePanel: DerivationPanel)) =>
derivePanel.controller.layoutDerivation()
case _ => // no project and/or derivation open, do nothing
}
}
+
+
+ val ViewGraph = new Action("Extract to new graph") {
+ enabled = true
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.E
+ }
+
+ def apply() = (CurrentProject, MainDocumentTabs.currentContent) match {
+ case (Some(project), Some(dp: DerivationPanel)) =>
+ dp.document match {
+ case (derivationDoc: DerivationDocument) =>
+ val page = new GraphDocumentPage(project.theory)
+ val graph = dp.lhsController.graph
+ page.document.asInstanceOf[GraphDocument].graph = graph
+ addAndFocusPage(page)
+ }
+ case _ =>
+ warn("WARNING: Extract selection with no graph active")
+ }
+ }
+
+
+
+ val ReRunLastSimproc = new Action("Re-run last simproc") {
+ accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_SPACE, CommandMask))
+ enabled = true
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.R
+ }
+
+ def apply() = (CurrentProject, MainDocumentTabs.currentContent) match {
+ case (Some(project), Some(dp: DerivationPanel)) =>
+ dp.ReRunLastSimproc()
+ case _ =>
+ warn("WARNING: Re-run simproc called with no derivation active")
+ }
+ }
+
+
+ visible = false
}
+
val WindowMenu = new Menu("Window") { menu =>
+ mnemonic = Key.W
+
val CloseAction = new Action("Close tab") {
accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_W, CommandMask))
enabled = false
menu.contents += new MenuItem(this) { mnemonic = Key.C }
def apply() {
- MainTabbedPane.currentContent match {
- case Some(doc: HasDocument) =>
- if (doc.document.promptUnsaved()) MainTabbedPane.pages.remove(MainTabbedPane.selection.index)
+ MainDocumentTabs.currentFocus match {
+ case Some(page: DocumentPage) =>
+ if (page.document.promptUnsaved()) MainDocumentTabs.remove(MainDocumentTabs.currentFocus.get)
case _ =>
}
}
}
+ val NextTabAction = new Action("Next tab") {
+ accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_RIGHT, CommandMask))
+ enabled = false
+ menu.contents += new MenuItem(this) {mnemonic = Key.N}
+
+ def apply(): Unit ={
+ MainDocumentTabs.cycle()
+ }
+ }
+
+ val PreviousTabAction = new Action("Previous tab") {
+ accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_LEFT, CommandMask))
+ enabled = false
+ menu.contents += new MenuItem(this) {mnemonic = Key.P}
+
+ def apply(): Unit ={
+ MainDocumentTabs.cycle(forward = false)
+ }
+ }
+
+ val CloseAllAction = new Action("Close all tabs") {
+ accelerator = None
+ enabled = false
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.A
+ }
+
+ def apply() {
+ MainDocumentTabs.documents.foreach(page => {
+ MainDocumentTabs.focus(page)
+ if (page.document.promptUnsaved()) MainDocumentTabs.remove(page)
+ })
+ }
+ }
+
contents += new Separator
val IncreaseUIScaling = new Action("Increase UI scaling") {
@@ -640,7 +1088,9 @@ object QuantoDerive extends SimpleSwingApplication {
}
val HelpMenu = new Menu("Help") { menu =>
- val CloseAction = new Action("Quantomatic website") {
+ mnemonic = Key.H
+
+ val WebsiteAction = new Action("Quantomatic website") {
menu.contents += new MenuItem(this) { mnemonic = Key.Q }
def apply() {
WebHelper.openWebpage("https://quantomatic.github.io/")
@@ -650,19 +1100,30 @@ object QuantoDerive extends SimpleSwingApplication {
val SimprocAPIAction = new Action("Simproc API") {
menu.contents += new MenuItem(this) { mnemonic = Key.S }
def apply() {
- WebHelper.openWebpage("https://quantomatic.github.io/SimprocAPI.html")
+ WebHelper.openWebpage("https://quantomatic.github.io#SimprocAPI")
}
}
+ //private val project = getClass.getPackage
+ // val version = project.getImplementationVersion()
+ // TODO: Implement versioning
+ val UpdateAction = new Action(s"Get latest version") {
+ menu.contents += new MenuItem(this) { mnemonic = Key.V }
+ def apply() {
+ WebHelper.openWebpage("https://bintray.com/quantomatic/quantomatic/quantomatic/bleeding-edge")
+ }
+ }
}
val ExportMenu = new Menu("Export") { menu =>
+ mnemonic = Key.X
+
val ExportAction = new Action("Export to LaTeX") {
accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_E, CommandMask))
enabled = false
menu.contents += new MenuItem(this) { mnemonic = Key.E }
def apply() {
- MainTabbedPane.currentContent match {
+ MainDocumentTabs.currentContent match {
case Some(doc: HasDocument) =>
if (doc.document.unsavedChanges) {
Dialog.showMessage(title = "Unsaved Changes",
@@ -677,55 +1138,25 @@ object QuantoDerive extends SimpleSwingApplication {
}
}
+ val StatusBar = new StatusBar()
- val UserMessage = new Label(UserAlerts.latestMessage.toString)
- val ConsoleProgress = new ProgressBar
- val ConsoleProgressLabel = new Label(" ")
- val StatusBar = new GridPanel(1, 2) {
- contents += new FlowPanel(FlowPanel.Alignment.Left)(UserMessage)
- contents += new FlowPanel(FlowPanel.Alignment.Right)(ConsoleProgressLabel, ConsoleProgress)
- }
- ConsoleProgress.preferredSize = ConsoleProgressSize //Currently doesn't respond to UI scaling
- def ConsoleProgressSize: Dimension = new Dimension(UserOptions.scaleInt(100), UserOptions.scaleInt(15))
+ listenTo(ProjectFileTree, MainDocumentTabs.selection)
-
- listenTo(UserAlerts.AlertPublisher)
reactions += {
- case UserAlerts.UserAlertEvent(alert: UserAlerts.Alert) =>
- UserMessage.text = alert.toString
- UserMessage.foreground = alert.color
- case UserAlerts.UserProcessUpdate(_) =>
- UserAlerts.leastCompleteProcess match {
- case Some(process) => if (process.determinate) {
- ConsoleProgress.indeterminate = false
- ConsoleProgress.value = process.value
- } else {
- ConsoleProgress.indeterminate = true
- }
- case _ => ConsoleProgress.value = 100
- }
- val ongoing = UserAlerts.ongoingProcesses.filter(op => op.value < 100)
- ongoing.count(_ => true) match {
- case 0 => ConsoleProgressLabel.text = " " //keep non-empty so the progressbar stays in line with text
- case 1 => ConsoleProgressLabel.text = ongoing.head.name
- case n => ConsoleProgressLabel.text = n.toString + " processes ongoing"
+ case PageClosed(p : DocumentPage) =>
+ MainDocumentTabs.remove(p)
+ case FileContextRequested(file, e) =>
+ if(file.isDirectory){
+ popup(FolderContextMenu(file), e)
+ } else {
+ popup(FileContextMenu(file), e)
}
- }
-
- val Main = new BorderPanel {
- add(Split, BorderPanel.Position.Center)
- add(StatusBar, BorderPanel.Position.South)
- }
-
- listenTo(ProjectFileTree, MainTabbedPane.selection)
-
- reactions += {
case FileOpened(file) =>
CurrentProject match {
case Some(project) =>
- val existingPage = MainTabbedPane.pages.find { p =>
+ val existingPage = MainDocumentTabs.documents.find { p =>
p.content match {
case doc : HasDocument => doc.document.file.exists(_.getPath == file.getPath)
case _ => false
@@ -734,7 +1165,7 @@ object QuantoDerive extends SimpleSwingApplication {
existingPage match {
case Some(p) =>
- MainTabbedPane.selection.index = p.index
+ MainDocumentTabs.focus(p)
case None =>
val extn = file.getName.lastIndexOf('.') match {
case i if i > 0 => file.getName.substring(i+1) ; case _ => ""}
@@ -745,15 +1176,15 @@ object QuantoDerive extends SimpleSwingApplication {
case "qderive" => Some(new DerivationDocumentPage(project))
case "py" => Some(new PythonDocumentPage)
case "ML" => Some(new MLDocumentPage)
+ case "qsbr" => Some(new BatchDerivationResultsPage)
case _ => None
}
- pageOpt.map{ page =>
- MainTabbedPane += page
- MainTabbedPane.selection.index = page.index
+ pageOpt.foreach{ page =>
+ addAndFocusPage(page)
if (!page.document.load(file)) {
- MainTabbedPane.pages -= page
+ MainDocumentTabs.remove(page)
}
}
}
@@ -761,25 +1192,67 @@ object QuantoDerive extends SimpleSwingApplication {
}
case SelectionChanged(_) =>
+ refreshAllMenus()
+ }
+
+// val versionResp = core ? Call("!!", "system", "version")
+// versionResp.onSuccess { case Success(JsonString(version)) =>
+// Swing.onEDT { CoreStatus.text = "OK"; CoreStatus.foreground = new Color(0,150,0) }
+// }
+
+ private def refreshAllMenusAndTitle(): Unit = {
+ refreshAllMenus()
+ refreshTitle()
+ }
+
+ private def refreshAllMenus(): Unit = {
+ try {
FileMenu.SaveAction.enabled = false
FileMenu.SaveAsAction.enabled = false
FileMenu.SaveAllAction.enabled = false
+ CurrentProject match {
+ case Some(_) =>
+ FileMenu.NewGraphAction.enabled = true
+ FileMenu.NewAxiomAction.enabled = true
+ case None =>
+ FileMenu.NewGraphAction.enabled = false
+ FileMenu.NewAxiomAction.enabled = false
+ }
+ TheoryMenu.visible = true
+ TheoryMenu.enabled = CurrentProject.nonEmpty
+ TheoryMenu.EditTheoryAction.enabled = CurrentProject.nonEmpty
+ TheoryMenu.BatchDerivationAction.enabled = CurrentProject.nonEmpty
EditMenu.CutAction.enabled = false
EditMenu.CopyAction.enabled = false
EditMenu.PasteAction.enabled = false
- EditMenu.SnapToGridAction.enabled = false
- DeriveMenu.StartDerivation.enabled = false
+ RuleMenu.visible = false
+ RuleMenu.InvertRule.enabled = false
+ GraphMenu.visible = false
+ GraphMenu.StartDerivation.enabled = false
+ GraphMenu.SnapToGrid.enabled = false
+ GraphMenu.MinimiseGraph.enabled = false
+ GraphMenu.StartRule.enabled = false
+ GraphMenu.ExtractGraph.enabled = false
+ DeriveMenu.visible = false
DeriveMenu.LayoutDerivation.enabled = false
+ DeriveMenu.ViewGraph.enabled = false
+ DeriveMenu.ReRunLastSimproc.enabled = false
WindowMenu.CloseAction.enabled = false
+ WindowMenu.PreviousTabAction.enabled = false
+ WindowMenu.NextTabAction.enabled = false
+ WindowMenu.CloseAllAction.enabled = false
ExportMenu.ExportAction.enabled = false
histView = None
FileMenu.SaveAction.title = "Save"
FileMenu.SaveAsAction.title = "Save As..."
- MainTabbedPane.currentContent match {
+ MainDocumentTabs.currentContent match {
case Some(content: HasDocument) =>
WindowMenu.CloseAction.enabled = true
+ WindowMenu.CloseAllAction.enabled = true
+ WindowMenu.NextTabAction.enabled = MainDocumentTabs.size > 1
+ WindowMenu.PreviousTabAction.enabled = MainDocumentTabs.size > 1
FileMenu.SaveAction.enabled = true
FileMenu.SaveAsAction.enabled = true
FileMenu.SaveAllAction.enabled = true
@@ -793,46 +1266,76 @@ object QuantoDerive extends SimpleSwingApplication {
EditMenu.CutAction.enabled = true
EditMenu.CopyAction.enabled = true
EditMenu.PasteAction.enabled = true
- EditMenu.SnapToGridAction.enabled = true
- DeriveMenu.StartDerivation.enabled = true
+ GraphMenu.visible = true
+ GraphMenu.StartDerivation.enabled = true
+ GraphMenu.StartRule.enabled = true
+ GraphMenu.SnapToGrid.enabled = true
+ GraphMenu.MinimiseGraph.enabled = true
+ GraphMenu.ExtractGraph.enabled = true
ExportMenu.ExportAction.enabled = true
case panel: RuleEditPanel =>
EditMenu.CutAction.enabled = true
EditMenu.CopyAction.enabled = true
EditMenu.PasteAction.enabled = true
- EditMenu.SnapToGridAction.enabled = true
ExportMenu.ExportAction.enabled = true
+ RuleMenu.visible = true
+ RuleMenu.InvertRule.enabled = true
+ GraphMenu.visible = true
+ GraphMenu.SnapToGrid.enabled = true
+ GraphMenu.MinimiseGraph.enabled = true
+ GraphMenu.ExtractGraph.enabled = true
case panel: DerivationPanel =>
- DeriveMenu.LayoutDerivation.enabled = true
ExportMenu.ExportAction.enabled = true
histView = Some(panel.histView)
+ DeriveMenu.visible = true
+ DeriveMenu.LayoutDerivation.enabled = true
+ DeriveMenu.ViewGraph.enabled = true
+ DeriveMenu.ReRunLastSimproc.enabled = true
case _ => // nothing else enabled for ML
}
case _ => // leave everything disabled
}
+ } catch {
+ case _: NullPointerException =>
+ // Null Pointer Exception thrown when accessing GUI too early
+ }
}
-// val versionResp = core ? Call("!!", "system", "version")
-// versionResp.onSuccess { case Success(JsonString(version)) =>
-// Swing.onEDT { CoreStatus.text = "OK"; CoreStatus.foreground = new Color(0,150,0) }
-// }
+ // The highest level GUI contents
+ val Main = new BorderPanel {
+ add(Split, BorderPanel.Position.Center)
+ add(StatusBar.Status, BorderPanel.Position.South)
+ }
+
+
+ val _mainframe = new MainFrame {
- def top = new MainFrame {
- override def title : String = {
- if (CurrentProject.isEmpty) {"Quantomatic"} else {
+ def refreshTitle() : Unit = {
+ // Setting the iconImage here isn't working on Windows
+ // iconImage = ImageIO.read(getClass.getResource("quantoderive.ico"))
+ title = if (CurrentProject.isEmpty) {"Quantomatic"} else {
CurrentProject.get.name match {
case "" => "Quantomatic"
- case s => "Quantomatic - $s"
+ case s => s"Quantomatic - $s"
}
}
}
contents = Main
- size = new Dimension(1280,720)
+
+ if (prefs.getBoolean("fullscreen",false)) {
+ peer.setExtendedState(peer.getExtendedState() | Frame.MAXIMIZED_BOTH)
+ }
+ else {
+ size = new Dimension(prefs.getInt("screenwidth",1280),prefs.getInt("screenheight",720))
+ peer.setLocation(prefs.getInt("locationx",300),prefs.getInt("locationy",300))
+ }
+ peer.setVisible(true)
+
menuBar = new MenuBar {
- contents += (FileMenu, EditMenu, DeriveMenu, WindowMenu, ExportMenu, HelpMenu)
+ contents += (FileMenu, TheoryMenu, EditMenu, DeriveMenu, RuleMenu, GraphMenu, WindowMenu, ExportMenu, HelpMenu)
}
import javax.swing.WindowConstants.DO_NOTHING_ON_CLOSE
@@ -842,4 +1345,17 @@ object QuantoDerive extends SimpleSwingApplication {
if (quitQuanto()) scala.sys.exit(0)
}
}
+
+ def top = _mainframe
+
+ def refreshTitle(): Unit = {
+ try {
+ top.refreshTitle()
+ } catch {
+ case _: NullPointerException =>
+ // Null Pointer Exception thrown when accessing GUI too early
+ }
+ }
+
+ refreshAllMenusAndTitle()
}
diff --git a/scala/src/main/scala/quanto/gui/RewriteController.scala b/scala/src/main/scala/quanto/gui/RewriteController.scala
index 1421e846..95415d74 100644
--- a/scala/src/main/scala/quanto/gui/RewriteController.scala
+++ b/scala/src/main/scala/quanto/gui/RewriteController.scala
@@ -3,29 +3,35 @@ package quanto.gui
import quanto.data._
import quanto.data.Names._
import quanto.rewrite._
-import scala.concurrent.{Future, Lock}
+import java.util.concurrent.locks.ReentrantLock
+
+import scala.concurrent.Future
import scala.swing._
import scala.swing.event._
import scala.swing.event.ButtonClicked
-import scala.util.{Success,Failure}
+import scala.util.{Failure, Success}
import quanto.util.json._
+
import scala.concurrent.ExecutionContext.Implicits.global
import java.io.File
-import quanto.layout.ForceLayout
+
+import quanto.data.Theory.ValueType
+import quanto.util.UserAlerts
+import quanto.util.UserAlerts.Elevation
class RewriteController(panel: DerivationPanel) extends Publisher {
implicit val timeout = QuantoDerive.timeout
var queryId = 0
- val resultLock = new Lock
+ val resultLock = new ReentrantLock()
var resultSet = ResultSet(Vector())
def theory = panel.project.theory
class ResultGraphRef(rule: RuleDesc, i: Int) extends HasGraph {
protected def gr_=(g: Graph) {
- resultLock.acquire()
+ resultLock.lock()
resultSet = resultSet.replaceGraph(rule, i, g)
- resultLock.release()
+ resultLock.unlock()
}
protected def gr = resultSet.graph(rule, i)
@@ -33,7 +39,7 @@ class RewriteController(panel: DerivationPanel) extends Publisher {
def rules = resultSet.rules
def rules_=(rules: Vector[RuleDesc]) {
- resultLock.acquire()
+ resultLock.lock()
resultSet = ResultSet(rules)
queryId += 1
@@ -42,21 +48,18 @@ class RewriteController(panel: DerivationPanel) extends Publisher {
else panel.LhsView.graph.verts
for (rd <- rules) {
- val rule = Rule.fromJson(Json.parse(new File(panel.project.rootFolder + "/" + rd.name + ".qrule")), theory)
- val ms = Matcher.initialise(if (rd.inverse) rule.rhs else rule.lhs, panel.LhsView.graph, sel)
- pullRewrite(ms, rd, rule)
-
-// QuantoDerive.core ? Call(theory.coreName, "rewrite", "find_rewrites",
-// JsonObject(
-// "rule" -> Rule.toJson(if (rd.inverse) rule.inverse else rule, theory),
-// "graph" -> Graph.toJson(panel.LhsView.graph, theory),
-// "vertices" -> JsonArray(sel.toVector.map(v => JsonString(v.toString)))
-// ))
-// resp.map { case Success(JsonString(stack)) => pullRewrite(queryId, rd, stack); case _ => }
+ try {
+ val rule = Rule.fromJson(Json.parse(new File(panel.project.rootFolder + "/" + rd.name + ".qrule")), theory)
+ val ms = Matcher.initialise(if (rd.inverse) rule.rhs else rule.lhs, panel.LhsView.graph, sel)
+ pullRewrite(ms, rd, rule)
+ } catch {
+ case RuleLoadException(message, _) =>
+ UserAlerts.alert(s"Could not load ${rd.name}, error: $message", Elevation.WARNING)
+ }
}
- resultLock.release()
-
+ resultLock.unlock()
+
refreshRewriteDisplay(clearSelection = true)
}
@@ -111,10 +114,9 @@ class RewriteController(panel: DerivationPanel) extends Publisher {
name = DSName("s"),
ruleName = rd.name,
rule = rule1,
- variant = if (rd.inverse) RuleInverse else RuleNormal,
graph = graph1.minimise).layout
- resultLock.acquire()
+ resultLock.lock()
// make sure this rewrite query is still in progress, and the rule hasn't been manually removed by the user
if (resultSet.rules.contains(rd)) {
@@ -125,7 +127,7 @@ class RewriteController(panel: DerivationPanel) extends Publisher {
}
}
- resultLock.release()
+ resultLock.unlock()
refreshRewriteDisplay()
case Success(None) => // out of matches
case Failure(t) => println("An error occurred in the matcher: " + t.getMessage)
@@ -135,7 +137,7 @@ class RewriteController(panel: DerivationPanel) extends Publisher {
def refreshRewriteDisplay(clearSelection: Boolean = false) {
Swing.onEDT {
- resultLock.acquire()
+ resultLock.lock()
if (clearSelection) {
panel.ManualRewritePane.PreviousResultButton.enabled = false
@@ -152,10 +154,27 @@ class RewriteController(panel: DerivationPanel) extends Publisher {
}
}
- resultLock.release()
+ resultLock.unlock()
}
}
+ def promptForVariableSpecification(strings: Set[(ValueType, String)]):
+ Map[(ValueType, String), String] = {
+ val d = new SpecifyVariablesDialog(strings.toList.sortBy(_._2))
+ d.centerOnScreen()
+ d.open()
+ d.result
+ }
+
+ def removeSelectedRulesFromList(): Unit = {
+
+ resultLock.lock()
+ panel.ManualRewritePane.Rewrites.selection.items.foreach { line => resultSet -= line.rule }
+ resultLock.unlock()
+ refreshRewriteDisplay()
+ }
+
+
def selectedRule =
if (panel.ManualRewritePane.Rewrites.selection.items.length == 1)
Some(panel.ManualRewritePane.Rewrites.selection.items(0).asInstanceOf[ResultLine].rule)
@@ -165,8 +184,20 @@ class RewriteController(panel: DerivationPanel) extends Publisher {
listenTo(panel.ManualRewritePane.PreviousResultButton, panel.ManualRewritePane.NextResultButton)
listenTo(panel.ManualRewritePane.ApplyButton)
listenTo(panel.ManualRewritePane.Rewrites.selection)
+ listenTo(panel)
+ listenTo(panel.ManualRewritePane.Rewrites.keys)
reactions += {
+ case KeyPressed(_, Key.Delete | Key.BackSpace, _, _) =>
+ if (panel.ManualRewritePane.Rewrites.hasFocus) {
+ removeSelectedRulesFromList()
+ }
+ case SuggestRewriteRule(ruleDesc) =>
+ val currentRules = rules.toSet
+ val newRules = Set(ruleDesc).filter(!currentRules.contains(_))
+
+ if (newRules.nonEmpty) rules ++= newRules
+ rules = rules.sortBy(r => r.name)
case ButtonClicked(panel.ManualRewritePane.AddRuleButton) =>
val d = new AddRuleDialog(panel.project)
d.centerOnScreen()
@@ -178,12 +209,9 @@ class RewriteController(panel: DerivationPanel) extends Publisher {
if (newRules.nonEmpty) rules ++= newRules
rules = rules.sortBy(r => r.name)
case ButtonClicked(panel.ManualRewritePane.RemoveRuleButton) =>
- resultLock.acquire()
- panel.ManualRewritePane.Rewrites.selection.items.foreach { line => resultSet -= line.rule }
- resultLock.release()
- refreshRewriteDisplay()
+ removeSelectedRulesFromList()
case ButtonClicked(panel.ManualRewritePane.PreviousResultButton) =>
- resultLock.acquire()
+ resultLock.lock()
selectedRule match {
case Some(rd) =>
resultSet = resultSet.previousResult(rd)
@@ -191,10 +219,10 @@ class RewriteController(panel: DerivationPanel) extends Publisher {
{ case 0 => panel.DummyRef ; case i => new ResultGraphRef(rd, i) }
case None =>
}
- resultLock.release()
+ resultLock.unlock()
refreshRewriteDisplay()
case ButtonClicked(panel.ManualRewritePane.NextResultButton) =>
- resultLock.acquire()
+ resultLock.lock()
selectedRule match {
case Some(rd) =>
resultSet = resultSet.nextResult(rd)
@@ -202,18 +230,51 @@ class RewriteController(panel: DerivationPanel) extends Publisher {
{ case 0 => panel.DummyRef ; case i => new ResultGraphRef(rd, i) }
case None =>
}
- resultLock.release()
+ resultLock.unlock()
refreshRewriteDisplay()
case ButtonClicked(panel.ManualRewritePane.ApplyButton) =>
- selectedRule.foreach { rd => resultSet.currentResult(rd).map { step =>
+ selectedRule.foreach { rd => resultSet.currentResult(rd).foreach { step =>
val parentOpt = panel.controller.state.step
-
val stepFr = step.copy(name = panel.derivation.steps.freshWithSuggestion(DSName(rd.name.replaceFirst("^.*\\/", "") + "-0")))
panel.ManualRewritePane.Preview.graphRef = panel.DummyRef
+ // Check to see if new variables were introduced
+ val oldHead : Option[Graph] = parentOpt.map(panel.derivation.steps(_).graph)
+ val newHead : Graph = stepFr.graph
+
+ val oldVariables: Set[(ValueType, String)] = oldHead match {
+ case Some(head) => Graph.variablesUsedWithType(theory, head)
+ case None => Set()
+ }
+ val newVariables: Set[(ValueType, String)] = Graph.variablesUsedWithType(theory, newHead) -- oldVariables
+ val subs: Map[(ValueType, String), String] =
+ if (newVariables.nonEmpty) {
+ promptForVariableSpecification(newVariables)
+ } else {
+ Map()
+ }
+ val graphWithReplacements: Graph = if (newVariables.nonEmpty) {
+ newHead.updateAllVData {
+ case node: NodeV =>
+ node.newValue(node.phaseData.substSubVariables(subs).toString)
+ case vd =>
+ vd
+ }
+ } else {
+ newHead
+ }
+
+ val stepWithReplacements: DStep = stepFr.copy(graph = graphWithReplacements)
+
+ val description: String = if (subs.isEmpty) {
+ ""
+ } else {
+ subs.mkString(", ")
+ }
+
panel.document.undoStack.start("Apply rewrite")
- panel.controller.replaceDerivation(panel.derivation.addStep(parentOpt, stepFr), "")
- panel.controller.state = HeadState(Some(stepFr.name))
+ panel.controller.replaceDerivation(panel.derivation.addStep(parentOpt, stepWithReplacements), description)
+ panel.controller.state = HeadState(Some(stepWithReplacements.name))
panel.document.undoStack.commit()
}}
diff --git a/scala/src/main/scala/quanto/gui/RuleDocument.scala b/scala/src/main/scala/quanto/gui/RuleDocument.scala
index 2ef9e3f0..9da17b81 100644
--- a/scala/src/main/scala/quanto/gui/RuleDocument.scala
+++ b/scala/src/main/scala/quanto/gui/RuleDocument.scala
@@ -1,10 +1,13 @@
package quanto.gui
import java.io.File
+
import quanto.data._
import quanto.util.json.Json
+
import scala.swing.Component
import quanto.util.FileHelper.printToFile
+import quanto.util.UserAlerts
class RuleDocument(val parent: Component, theory: Theory) extends Document {
val description = "Rule"
@@ -20,12 +23,13 @@ class RuleDocument(val parent: Component, theory: Theory) extends Document {
}
def rule = Rule(lhsRef.graph, rhsRef.graph, derivation)
- def rule_=(r: Rule) {
- lhsRef.graph = r.lhs
- rhsRef.graph = r.rhs
- derivation = r.derivation
+ def rule_=(newRule: Rule) {
+ lhsRef.graph = newRule.lhs
+ rhsRef.graph = newRule.rhs
+ derivation = newRule.derivation
lhsRef.publish(GraphReplaced(lhsRef, clearSelection = true))
rhsRef.publish(GraphReplaced(rhsRef, clearSelection = true))
+ publish(DocumentChanged(this))
}
private var storedRule: Rule = Rule(Graph(theory), Graph(theory))
diff --git a/scala/src/main/scala/quanto/gui/RuleEditPanel.scala b/scala/src/main/scala/quanto/gui/RuleEditPanel.scala
index 2bbb09c6..a5793ad9 100644
--- a/scala/src/main/scala/quanto/gui/RuleEditPanel.scala
+++ b/scala/src/main/scala/quanto/gui/RuleEditPanel.scala
@@ -1,9 +1,10 @@
package quanto.gui
import quanto.gui.graphview.GraphView
-import quanto.data.{HasGraph, Theory, Graph}
-import scala.swing.{GridPanel, BorderPanel, ScrollPane}
-import scala.swing.event.UIElementResized
+import quanto.data.{Graph, HasGraph, Theory}
+
+import scala.swing.{BorderPanel, GridPanel, ScrollPane}
+import scala.swing.event.{MouseClicked, UIElementResized}
class RuleEditPanel(val theory: Theory, val readOnly: Boolean = false)
extends BorderPanel
@@ -19,11 +20,10 @@ with HasDocument
val lhsController = new GraphEditController(lhsView, document.undoStack, readOnly)
lhsController.controlsOpt = Some(controls)
-
val rhsController = new GraphEditController(rhsView, document.undoStack, readOnly)
rhsController.controlsOpt = Some(controls)
- def focusedController = if (rhsView.hasFocus) rhsController else lhsController
+ def focusedController: GraphEditController = if (rhsView.hasFocus) rhsController else lhsController
val LhsScrollPane = new ScrollPane(lhsView)
val RhsScrollPane = new ScrollPane(rhsView)
@@ -50,8 +50,19 @@ with HasDocument
case UIElementResized(RhsScrollPane) =>
rhsView.resizeViewToFit()
rhsView.repaint()
+ case MouseStateChanged(RequestFocusOnGraph()) =>
+ focusedController.focusOnGraph()
+ case MouseStateChanged(RequestMinimiseGraph()) =>
+ focusedController.minimiseGraph()
+ case MouseStateChanged(RelaxToolDown()) => focusedController.startRelaxGraph(true)
+ case MouseStateChanged(RelaxToolUp()) => focusedController.endRelaxGraph()
case MouseStateChanged(m) =>
+ lhsController.endRelaxGraph()
+ rhsController.endRelaxGraph()
lhsController.mouseState = m
rhsController.mouseState = m
+ case DocumentRequestingNaturalFocus(_) =>
+ lhsView.requestFocus()
}
+
}
diff --git a/scala/src/main/scala/quanto/gui/ScalaEditPanel.hide b/scala/src/main/scala/quanto/gui/ScalaEditPanel.hide
deleted file mode 100644
index 8b7b49f3..00000000
--- a/scala/src/main/scala/quanto/gui/ScalaEditPanel.hide
+++ /dev/null
@@ -1,122 +0,0 @@
-package quanto.gui
-
-import scala.reflect.runtime.universe._
-import scala.tools.reflect.ToolBox
-import scala.swing._
-import org.gjt.sp.jedit.{Registers, Mode}
-import org.gjt.sp.jedit.textarea.StandaloneTextArea
-import java.awt.{Color, BorderLayout}
-import java.awt.event.{KeyEvent, KeyAdapter}
-import javax.swing.ImageIcon
-import quanto.util.swing.ToolBar
-import scala.swing.event.ButtonClicked
-import quanto.util._
-import java.io.{File, PrintStream}
-
-class ScalaEditPanel extends BorderPanel with HasDocument {
- val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask
-
- val sml = new Mode("StandardML")
-
- val scalaModeXml = if (Globals.isBundle) new File("scala.xml").getAbsolutePath
- else getClass.getResource("scala.xml").getPath
- sml.setProperty("file", scalaModeXml)
- //println(sml.getProperty("file"))
- val scalaCode = StandaloneTextArea.createTextArea()
- //mlCode.setFont(new Font("Menlo", Font.PLAIN, 14))
-
- val buf = new JEditBuffer1
- buf.setMode(sml)
-
- var scalaThread : Thread = null
-
- scalaCode.setBuffer(buf)
-
- // mlCode.addKeyListener(new KeyAdapter {
- // override def keyPressed(e: KeyEvent) {
- // if (e.getModifiers == CommandMask) e.getKeyChar match {
- // case 'x' => Registers.cut(mlCode, '$')
- // case 'c' => Registers.copy(mlCode, '$')
- // case 'v' => Registers.paste(mlCode, '$')
- // case _ =>
- // }
- // }
- // })
-
- val document = new CodeDocument("Scala Code", "scala", this, scalaCode)
-
- val toolbox = runtimeMirror(getClass.getClassLoader).mkToolBox() //currentMirror.mkToolBox()
-
-
- val textPanel = new BorderPanel {
- peer.add(scalaCode, BorderLayout.CENTER)
- }
-
- val RunButton = new Button() {
- icon = new ImageIcon(GraphEditor.getClass.getResource("start.png"), "Run scala code")
- tooltip = "Run Scala"
- }
-
- val InterruptButton = new Button() {
- icon = new ImageIcon(GraphEditor.getClass.getResource("stop.png"), "Interrupt execution")
- tooltip = "Interrupt execution"
- }
-
- val ScalaToolbar = new ToolBar {
- contents += (RunButton, InterruptButton)
- }
-
- val outputTextArea = new TextArea()
- outputTextArea.editable = false
- val textOut = new TextAreaOutputStream(outputTextArea)
-
- val scalaOutput = new PrintStream(new TextAreaOutputStream(outputTextArea))
-
- add(ScalaToolbar, BorderPanel.Position.North)
-
- object Split extends SplitPane {
- orientation = Orientation.Horizontal
- contents_=(textPanel, new ScrollPane(outputTextArea))
- }
-
- add(Split, BorderPanel.Position.Center)
-
- listenTo(RunButton, InterruptButton)
-
- reactions += {
- case ButtonClicked(RunButton) =>
- if (scalaThread == null) {
- QuantoDerive.CoreStatus.text = "Running scala code"
- QuantoDerive.CoreStatus.foreground = Color.BLUE
-
- scalaThread = new Thread(new Runnable {
- def run() {
- try {
- val tree = toolbox.parse(scalaCode.getBuffer.getText())
- toolbox.eval(tree)
- QuantoDerive.CoreStatus.text = "Scala compiled sucessfully"
- QuantoDerive.CoreStatus.foreground = new Color(0, 150, 0)
- } catch {
- case e : Throwable =>
- QuantoDerive.CoreStatus.text = "Error in scala code"
- QuantoDerive.CoreStatus.foreground = Color.RED
- Swing.onEDT { e.printStackTrace(scalaOutput) }
- } finally {
- scalaThread = null
- }
- }
- })
- scalaThread.start()
-
- } else {
- QuantoDerive.CoreStatus.text = "Scala already running"
- QuantoDerive.CoreStatus.foreground = Color.RED
- }
-
- case ButtonClicked(InterruptButton) =>
- if (scalaThread != null) {
- scalaThread.interrupt()
- scalaThread = null
- }
- }
-}
diff --git a/scala/src/main/scala/quanto/gui/SimplifyBuiltInController.scala b/scala/src/main/scala/quanto/gui/SimplifyBuiltInController.scala
deleted file mode 100644
index 6e89eba2..00000000
--- a/scala/src/main/scala/quanto/gui/SimplifyBuiltInController.scala
+++ /dev/null
@@ -1,368 +0,0 @@
-package quanto.gui
-
-import java.io.File
-
-import scala.swing._
-import quanto.core._
-import quanto.data._
-import quanto.data.Names._
-import quanto.util.json._
-import akka.pattern.ask
-
-import scala.concurrent.Future
-import scala.concurrent.ExecutionContext.Implicits.global
-import scala.swing.event.ButtonClicked
-import quanto.cosy.SimplificationProcedure
-import quanto.cosy.GraphAnalysis
-import quanto.data.Derivation.DerivationWithHead
-
-import scala.util.Random
-
-
-class SimplifyBuiltInController(panel: DerivationPanel) extends Publisher {
- implicit val timeout = QuantoDerive.timeout
- private var simpId = 0 // incrementing the simpId will (lazily) cancel any pending simplification jobs
-
- listenTo(panel.SimplifyBuiltInPane.SimplifyButton)
-
- def refreshSimprocs() {
- simpId += 1
- // val res = QuantoDerive.core ? Call(theory.coreName, "simplify", "list")
- // res.map {
- // case Success(JsonArray(procs)) =>
- // Swing.onEDT { panel.SimplifyBuiltInPane.Simprocs.listData = procs.map(_.stringValue) }
- // case r => println("ERROR: Unexpected result from core: " + r) // TODO: errror dialogs
- // }
- }
-
- def theory = panel.theory
-
- private def pullSimp(simproc: String, sid: Int, stack: String, parentOpt: Option[DSName]) {
- // if (simpId == sid) {
- // val res = QuantoDerive.core ? Call(theory.coreName, "simplify", "pull_next_step",
- // JsonObject("stack" -> JsonString(stack)))
- //
- // res.map {
- // case Success(JsonNull) => // out of steps
- // Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false }
- // simpId += 1
- // case Success(json) =>
- // if (simpId == sid) {
- // val suggest = simproc + "-" + (json / "rule_name").stringValue.replaceFirst("^.*\\/", "") + "-0"
- // val sname = panel.derivation.steps.freshWithSuggestion(DSName(suggest))
- // val step = DStep.fromJson(sname, json, theory).layout
- //
- // Swing.onEDT {
- // panel.document.derivation = panel.document.derivation.addStep(parentOpt, step)
- // panel.controller.state = HeadState(Some(step.name))
- // pullSimp(simproc, sid, stack, Some(sname))
- // }
- //
- // } else {
- // Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false }
- // QuantoDerive.core ! Call(theory.coreName, "simplify", "delete_stack",
- // JsonObject("stack" -> JsonString(stack)))
- // }
- // case _ => println("ERROR: Unexpected result from core: " + res) // TODO: errror dialogs
- // }
- // } else {
- // Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false }
- // QuantoDerive.core ! Call(theory.coreName, "simplify", "delete_stack",
- // JsonObject("stack" -> JsonString(stack)))
- // }
- }
-
- private def evaluateSimproc(): Unit = {
- val d = new EvaluationInputPanel(panel.project)
- d.centerOnScreen()
- d.open()
-
- val targetString: String = d.TargetText.text.replaceAll(raw"\\", raw"\\\\")
- val replacementString: String = d.ReplacementText.text.replaceAll(raw"\\", raw"\\\\")
- val initialDerivation = (panel.derivation, panel.controller.state.step)
-
- import quanto.cosy.SimplificationProcedure.Evaluation._
-
- if (targetString.length > 0) {
-
- val initialState: State = new State(
- List(),
- 0,
- Some(initialDerivation.verts.size),
- new Random(),
- initialDerivation.verts.toList,
- targetString,
- replacementString,
- Some(initialDerivation.verts.size)
- )
- val simproc = new SimplificationProcedure[State](
- initialDerivation,
- initialState,
- step,
- progress,
- (der, state) => state.currentStep == state.maxSteps.get
- )
- val evluationProgressController = new SimprocProgress[State](
- panel.project, "Evaluation", simproc
- )
- evluationProgressController.centerOnScreen()
- evluationProgressController.open()
- updateDerivation(evluationProgressController.returningDerivation, "evaluation")
- }
- }
-
- private def annealSimproc(): Unit = {
- val d = new SimulatedAnnealingDialog(panel.project)
- d.centerOnScreen()
- d.open()
-
- val timeSteps = d.MainPanel.TimeSteps.text.toInt
- val vertexLimit = d.MainPanel.vertexLimit()
- if (timeSteps > 0) {
-import quanto.cosy.SimplificationProcedure.Annealing._
- val initialState: State = new State(allowedRules, 0, Some(timeSteps), new Random(), 3, vertexLimit)
- val simproc = new SimplificationProcedure[State](
- (panel.derivation, panel.controller.state.step),
- initialState,
- step,
- progress,
- (_, state) => state.currentStep == state.maxSteps.get
- )
- val simulatedAnnealingController = new SimprocProgress[State](
- panel.project, "Simulated Annealing", simproc
- )
- simulatedAnnealingController.centerOnScreen()
- simulatedAnnealingController.open()
- updateDerivation(simulatedAnnealingController.returningDerivation, "annealing reduce")
- }
- }
-
- refreshSimprocs()
-
- private def pullErrorsSimproc(): Unit = {
- val initialDerivation = (panel.derivation, panel.controller.state.step)
- val graph = Derivation.derivationHeadPairToGraph(initialDerivation)
- val boundaries = graph.verts.filter(v => graph.vdata(v).isBoundary)
- val d = new SimpleSelectionPanel(panel.project,
- "Select target boundaries:",
- boundaries.toList.sorted.map(_.toString))
- d.centerOnScreen()
- d.open()
-
- val targets = d.MainPanel.OptionList.selection.items.map(s => VName(s)).toList
-
- val e = new SimpleSelectionPanel(panel.project,
- "Select greedy rules:",
- allowedRules.map(_.description.name))
- e.centerOnScreen()
- e.open()
-
- val greedyRules = e.MainPanel.OptionList.selection.items.map(s => ruleFromDesc(RuleDesc(s))).toList
-
-
-
- println(targets)
- import quanto.cosy.SimplificationProcedure.PullErrors
- if (targets.nonEmpty && allowedRules.nonEmpty) {
-
- val initialState: PullErrors.State = PullErrors.State(
- allowedRules.filterNot(r => greedyRules.contains(r)),
- 0,
- None,
- new Random(),
- PullErrors.errorsDistance(targets.toSet),
- greedyRules = Some(greedyRules),
- currentDistance = None,
- heldVertices = None,
- vertexLimit = None
- )
-
-
- val simplificationProcedure = new quanto.cosy.SimplificationProcedure[PullErrors.State](
- initialDerivation,
- initialState,
- PullErrors.step,
- PullErrors.progress,
- (_, state) => state.currentStep == state.maxSteps.getOrElse(-1) || state.currentDistance.getOrElse(2.0) < 1
- )
- val progressController = new SimprocProgress[PullErrors.State](
- panel.project, "Pull Errors Through", simplificationProcedure
- )
- progressController.centerOnScreen()
- progressController.open()
- updateDerivation(progressController.returningDerivation, "pull errors")
- }
- }
-
-
- private def pullSpecialsSimproc(): Unit = {
- val initialDerivation = (panel.derivation, panel.controller.state.step)
- val graph = Derivation.derivationHeadPairToGraph(initialDerivation)
- val boundaries = graph.verts.filter(v => graph.vdata(v).isBoundary)
- val d = new SimpleSelectionPanel(panel.project,
- "Select target boundaries:",
- boundaries.toList.sorted.map(_.toString))
- d.centerOnScreen()
- d.open()
-
- val e = new SimpleSelectionPanel(panel.project,
- "Select vertices to hold in place:",
- graph.verts.toList.sorted.map(_.toString))
- e.centerOnScreen()
- e.open()
-
- val targets = d.MainPanel.OptionList.selection.items.map(s => VName(s)).toList
- val specials = e.MainPanel.OptionList.selection.items.map(s => VName(s)).toList
- import quanto.cosy.SimplificationProcedure.LTEByWeight._
- println(targets)
- if (targets.nonEmpty) {
- val initialState: State = State(
- allowedRules,
- 0,
- None,
- new Random(),
- quanto.cosy.GraphAnalysis.distanceSpecialFromEnds(specials)(targets),
- None,
- heldVertices = Some(specials.toSet),
- None
- )
- val simproc = new SimplificationProcedure[State](
- (panel.derivation, panel.controller.state.step),
- initialState,
- step,
- progress,
- (_, state) => state.currentStep == state.maxSteps.getOrElse(-1) || state.currentDistance.getOrElse(1) == 0
- )
- val progressController = new SimprocProgress[State](
- panel.project, "Pull Errors Through", simproc
- )
- progressController.centerOnScreen()
- progressController.open()
- updateDerivation(progressController.returningDerivation, "pull errors")
- }
- }
-
- private def greedySimproc(): Unit = {
- import quanto.cosy.SimplificationProcedure.Greedy._
- val initialState: State = new State(
- allowedRules,
- 0,
- None,
- new Random(),
- allowedRules,
- None)
- val simplificationProcedure = new SimplificationProcedure[State](
- (panel.derivation, panel.controller.state.step),
- initialState,
- step,
- progress,
- (_, state) => state.currentStep == state.maxSteps.getOrElse(-1) || state.remainingRules.isEmpty
- )
- val progressController = new SimprocProgress[State](
- panel.project, "Greedy Reduction", simplificationProcedure
- )
- progressController.centerOnScreen()
- progressController.open()
- updateDerivation(progressController.returningDerivation, "greedy reduce")
- }
-
- private def lteSimproc(): Unit = {
- import quanto.cosy.SimplificationProcedure.LTEByWeight._
- val initialState: State = State(
- rules = allowedRules,
- currentStep = 0,
- currentDistance = None,
- maxSteps = Some(100),
- seed = new Random(),
- weightFunction = g => Some(g.verts.size + g.edges.size),
- heldVertices = None,
- vertexLimit = None)
- val simproc = new SimplificationProcedure[State](
- (panel.derivation, panel.controller.state.step),
- initialState,
- step,
- progress,
- (_, state) => state.currentStep == state.maxSteps.getOrElse(-1)
- )
- val progressController = new SimprocProgress[State](
- panel.project, "Greedy Reduction", simproc
- )
- progressController.centerOnScreen()
- progressController.open()
- updateDerivation(progressController.returningDerivation, "lte reduce")
- }
-
-
- private def randomSimproc(): Unit = {
- import quanto.cosy.SimplificationProcedure.LTEByWeight._
- val initialState: State = State(
- rules = allowedRules,
- currentStep = 0,
- currentDistance = None,
- maxSteps = Some(100),
- seed = new Random(),
- weightFunction = _ => None,
- heldVertices = None,
- vertexLimit = None)
- val simproc = new SimplificationProcedure[State](
- (panel.derivation, panel.controller.state.step),
- initialState,
- step,
- progress,
- (_, state) => state.currentStep == state.maxSteps.getOrElse(-1)
- )
- val progressController = new SimprocProgress[State](
- panel.project, "Random Rule Application", simproc
- )
- progressController.centerOnScreen()
- progressController.open()
- updateDerivation(progressController.returningDerivation, "random apply")
- }
-
-
- val availableProcedures: Map[String, () => Unit] = Map(
- "Random x 100" -> randomSimproc,
- "Graph shrink x 100" -> lteSimproc,
- "Greedy reduce" -> greedySimproc,
- "Pull specials" -> pullSpecialsSimproc,
- "Pull pi-errors" -> pullErrorsSimproc,
- "Anneal" -> annealSimproc,
- "Evaluate" -> evaluateSimproc
- )
- Swing.onEDT { panel.SimplifyBuiltInPane.Simprocs.listData = availableProcedures.keys.toSeq }
-
- implicit def ruleFromDesc(ruleDesc: RuleDesc): Rule = {
- Rule.fromJson(Json.parse(new File(panel.project.rootFolder + "/" + ruleDesc.name + ".qrule")),
- theory,
- description = Some(ruleDesc))
- }
-
- private def allowedRules = panel.rewriteController.rules.map(ruleFromDesc).toList
-
-
- private def updateDerivation(derivationWithHead: DerivationWithHead, desc: String): Unit = {
- println("updating derivation")
- val currentDerivation = panel.document.derivation
-
- panel.document.undoStack.register(desc) {
- updateDerivation((currentDerivation, panel.controller.state.step), desc)
- }
-
- panel.document.derivation = derivationWithHead._1
- derivationWithHead._2 match {
- case Some(stepName) => panel.controller.state = StepState(stepName)
- case None => panel.controller.state = HeadState(None)
- }
- }
-
-
- reactions += {
- case ButtonClicked(panel.SimplifyBuiltInPane.SimplifyButton) =>
- if (panel.SimplifyBuiltInPane.Simprocs.selection.indices.nonEmpty) {
- val procedureName: String = panel.SimplifyBuiltInPane.Simprocs.selection.items(0)
- val procedure = availableProcedures(procedureName)
- procedure.apply()
- }
- }
-
-}
diff --git a/scala/src/main/scala/quanto/gui/SimplifyController.scala b/scala/src/main/scala/quanto/gui/SimplifyController.scala
index 37af3f61..15ed1d43 100644
--- a/scala/src/main/scala/quanto/gui/SimplifyController.scala
+++ b/scala/src/main/scala/quanto/gui/SimplifyController.scala
@@ -6,8 +6,10 @@ import quanto.data._
import quanto.data.Names._
import quanto.util.json._
import akka.pattern.ask
+import quanto.util.UserAlerts
+
import scala.concurrent.ExecutionContext.Implicits.global
-import scala.swing.event.ButtonClicked
+import scala.swing.event.{ButtonClicked, Event}
import scala.util.{Failure, Success, Try}
import quanto.util.UserAlerts.SelfAlertingProcess
@@ -18,7 +20,19 @@ class SimplifyController(panel: DerivationPanel) extends Publisher {
private var activeSimp: Option[Future[Boolean]] = None
- listenTo(panel.SimplifyPane.RefreshButton, panel.SimplifyPane.SimplifyButton, panel.SimplifyPane.StopButton)
+
+ case class StartSimproc(simproc: String) extends Event
+ case class RefreshSimprocs() extends Event
+ case class ReRunSimproc() extends Event
+ case class HaltSimproc() extends Event
+ case class SimprocHaltedWhileRunning() extends Exception("Simproc halted")
+
+ listenTo(this,
+ panel.SimplifyPane.RefreshButton,
+ panel.SimplifyPane.SimplifyButton,
+ panel.SimplifyPane.StopButton,
+ panel,
+ PythonEditPanel)
def theory = panel.theory
@@ -33,98 +47,79 @@ class SimplifyController(panel: DerivationPanel) extends Publisher {
refreshSimprocs()
- private def pullSimp(simproc: String, sid: Int, stack: String, parentOpt: Option[DSName]) {
- if (simpId == sid) {
- //val res = QuantoDerive.core ? Call(theory.coreName, "simplify", "pull_next_step",
- // JsonObject("stack" -> JsonString(stack)))
-
-// res.map {
-// case Success(JsonNull) => // out of steps
-// Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false }
-// simpId += 1
-// case Success(json) =>
-// if (simpId == sid) {
-// val suggest = simproc + "-" + (json / "rule_name").stringValue.replaceFirst("^.*\\/", "") + "-0"
-// val sname = panel.derivation.steps.freshWithSuggestion(DSName(suggest))
-// val step = DStep.fromJson(sname, json, theory).layout
-//
-// Swing.onEDT {
-// panel.document.derivation = panel.document.derivation.addStep(parentOpt, step)
-// panel.controller.state = HeadState(Some(step.name))
-// pullSimp(simproc, sid, stack, Some(sname))
-// }
-//
-// } else {
-// Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false }
-// QuantoDerive.core ! Call(theory.coreName, "simplify", "delete_stack",
-// JsonObject("stack" -> JsonString(stack)))
-// }
-// case _ => println("ERROR: Unexpected result from core: " + res) // TODO: errror dialogs
-// }
- None
- } else {
-// Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false }
-// QuantoDerive.core ! Call(theory.coreName, "simplify", "delete_stack",
-// JsonObject("stack" -> JsonString(stack)))
- None
- }
- }
+ private var _lastRunSimproc : String = ""
reactions += {
- case ButtonClicked(panel.SimplifyPane.RefreshButton) => refreshSimprocs()
- case ButtonClicked(panel.SimplifyPane.SimplifyButton) =>
- if (panel.SimplifyPane.Simprocs.selection.indices.nonEmpty) {
- simpId += 1
- val simpName = panel.SimplifyPane.Simprocs.selection.items(0)
-
- QuantoDerive.CurrentProject.flatMap { pr => pr.simprocs.get(simpName) }.foreach { simproc =>
- var parentOpt = panel.controller.state.step
- val processReporting = new SelfAlertingProcess("Simproc: " + simpName)
-
- val res = Future[Boolean] {
- for ((graph, rule) <- simproc.simp(panel.LhsView.graph)) {
- val suggest = simpName + "-" + rule.name.replaceFirst("^.*\\/", "") + "-0"
- val step = DStep(
- name = panel.derivation.steps.freshWithSuggestion(DSName(suggest)),
- rule = rule,
- graph = graph.minimise) // layout is already done by simproc now
+ case RequestReRunSimproc() =>
+ publish(ReRunSimproc())
+ case ReRunSimproc() =>
+ publish(StartSimproc(_lastRunSimproc))
+ case StartSimproc(simpName) =>
+ _lastRunSimproc = simpName
+ QuantoDerive.CurrentProject.flatMap { pr => pr.simprocs.get(simpName) }.foreach { simproc =>
+ var parentOpt = panel.controller.state.step
+ val sourceMessage = s"Running simproc '$simpName'".concat(
+ if (simproc.sourceFile != "") {
+ s" from ${simproc.sourceFile}"
+ } else {
+ ""
+ }
+ )
+ UserAlerts.alert(sourceMessage)
+ val processReporting = new SelfAlertingProcess("Simproc: " + simpName)
+ val simpIdAtStart = simpId
+ val res = Future[Boolean] {
+ val iteratedSimp : Iterator[(Graph, Rule)] = simproc.simp(panel.LhsView.graph)
+
+ // Don't update the derivation if the simpId call has changed
+ while(iteratedSimp.hasNext && simpId == simpIdAtStart){
+ val (graph, rule) = iteratedSimp.next()
+ val suggest = simpName + "-" + rule.name.replaceFirst("^.*\\/", "") + "-0"
+ val step = DStep(
+ name = panel.derivation.steps.freshWithSuggestion(DSName(suggest)),
+ rule = rule,
+ graph = graph.minimise) // layout is already done by simproc now
panel.document.derivation = panel.document.derivation.addStep(parentOpt, step)
parentOpt = Some(step.name)
+ Swing.onEDT {
+ panel.controller.state = HeadState(Some(step.name))
+ }
+ }
- Swing.onEDT { panel.controller.state = HeadState(Some(step.name)) }
- }
- true
+ if(simpId != simpIdAtStart) {
+ throw SimprocHaltedWhileRunning()
}
- res.onComplete {
- case Success(b) =>
- processReporting.finish()
- case Failure(e) =>
- processReporting.fail()
- e.printStackTrace()
- }
+ true
}
-
-// val res = QuantoDerive.core ? Call(theory.coreName, "simplify", "simplify", JsonObject(
-// "simproc" -> JsonString(simproc),
-// "graph" -> Graph.toJson(panel.LhsView.graph, theory)
-// ))
-// res.map {
-// case Success(JsonString(stack)) =>
-// Swing.onEDT {
-// QuantoDerive.ConsoleProgress.indeterminate = true
-// pullSimp(simproc, simpId, stack, panel.controller.state.step)
-// }
-//
-// case _ => println("ERROR: Unexpected result from core: " + res) // TODO: errror dialogs
-// }
+ res.onComplete {
+ case Success(b) =>
+ processReporting.finish()
+ case Failure(SimprocHaltedWhileRunning()) =>
+ processReporting.halt()
+ case Failure(e) =>
+ e.printStackTrace()
+ }
}
- case ButtonClicked(panel.SimplifyPane.StopButton) =>
- Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false }
+ case HaltSimproc() =>
simpId += 1
+ case RefreshSimprocs() =>
+ refreshSimprocs()
+ case ButtonClicked(panel.SimplifyPane.RefreshButton) =>
+ publish(RefreshSimprocs())
+ case ButtonClicked(panel.SimplifyPane.SimplifyButton) =>
+ if (panel.SimplifyPane.Simprocs.selection.indices.nonEmpty) {
+ simpId += 1
+ val simpName = panel.SimplifyPane.Simprocs.selection.items(0)
+ publish(StartSimproc(simpName))
+ }
+ case ButtonClicked(panel.SimplifyPane.StopButton) =>
+ publish(HaltSimproc())
+ case SimprocsUpdated() =>
+ refreshSimprocs()
}
}
diff --git a/scala/src/main/scala/quanto/gui/SpecifyVariablesDialog.scala b/scala/src/main/scala/quanto/gui/SpecifyVariablesDialog.scala
new file mode 100644
index 00000000..9f6475bc
--- /dev/null
+++ b/scala/src/main/scala/quanto/gui/SpecifyVariablesDialog.scala
@@ -0,0 +1,93 @@
+package quanto.gui
+
+import quanto.data.Theory.ValueType
+
+import scala.swing._
+import scala.swing.event.{ButtonClicked, Key, KeyPressed, ValueChanged}
+import quanto.data._
+import quanto.util.{Globals, UserOptions}
+
+import scala.util.matching
+import scala.util.matching.Regex
+import quanto.util.UserOptions.{scale, scaleInt}
+
+class SpecifyVariablesDialog(variables: List[(ValueType, String)]) extends Dialog {
+ modal = true
+ title = "Specify values for introduced variables?"
+ val smallGap: Int = UserOptions.scaleInt(5)
+
+
+
+ val AddButton = new Button("Add")
+ val CancelButton = new Button("Cancel")
+ defaultButton = Some(AddButton)
+
+ def result: Map[(ValueType, String), String] = {
+ if(!cancelled) {
+ variables.map(sv => sv -> textFields(sv).text).toMap
+ } else {
+ variables.map(sv => sv -> sv._2).toMap
+ }
+ }
+
+
+
+ implicit def buttonIsSelected(radButton: RadioButton): Boolean = radButton.selected
+ var cancelled = false
+
+
+ // val dir = Files.newDirectoryStream(Paths.get(rootDir), "**/*.qrule")
+ // for (p <- dir.asScala) println(p)
+ val textFields : Map[(ValueType, String), TextField] = variables.map(name => {
+ name -> new TextField
+ }).toMap
+
+ val MainPanel = new BoxPanel(Orientation.Vertical) {
+
+ contents += Swing.VStrut(smallGap)
+
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(smallGap),
+ new Label("This rule introduces new variables. Would you like to specify values?"),
+ Swing.HStrut(smallGap))
+ }
+ contents += Swing.VStrut(smallGap)
+
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(smallGap),
+ new GridPanel(variables.length, 2) {
+ variables.foreach(name => {
+ val label = new Label(s"${name._2} (${TheoryEditPanel.valueTypesAsHumanReadable(name._1)})")
+ label.horizontalAlignment = Alignment.Left
+ contents += label
+ val tf = textFields(name)
+ contents += tf
+ tf.text = name._2
+ })
+ },
+ Swing.HStrut(smallGap))
+ }
+ contents += Swing.VStrut(smallGap)
+
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (AddButton, Swing.HStrut(smallGap), CancelButton)
+ }
+
+ contents += Swing.VStrut(2*smallGap)
+
+ }
+
+
+ contents = MainPanel
+
+ listenTo(AddButton, CancelButton)
+
+ reactions += {
+ case ButtonClicked(AddButton) =>
+ cancelled = false
+ close()
+ case ButtonClicked(CancelButton) =>
+ cancelled = true
+ close()
+ }
+}
diff --git a/scala/src/main/scala/quanto/gui/StatusBar.scala b/scala/src/main/scala/quanto/gui/StatusBar.scala
new file mode 100644
index 00000000..88a8a562
--- /dev/null
+++ b/scala/src/main/scala/quanto/gui/StatusBar.scala
@@ -0,0 +1,118 @@
+package quanto.gui
+
+import java.awt.Desktop
+import java.awt.event.{MouseAdapter, MouseEvent}
+
+import quanto.gui.QuantoDerive.{Split, listenTo, popup, CurrentProject}
+import quanto.util.{UserAlerts, UserOptions}
+
+import scala.swing.{Action, BorderPanel, Dimension, FlowPanel, GridPanel, Label, MenuItem, ProgressBar, Publisher, Separator}
+import scala.swing.event.{Event, Key}
+
+class StatusBar extends BorderPanel {
+
+
+ def MessagesContextMenu() : PopupMenu = new PopupMenu {
+ menu =>
+
+ def showAlert(a : UserAlerts.Alert) : Unit = {
+ menu.contents += new Label(" "+a.toString)
+ }
+
+ val numAlerts : Int = UserAlerts.alerts.count(_ => true)
+ if(numAlerts < 6) {
+ // Few alerts, so list all of them
+ UserAlerts.alerts.reverse.foreach(showAlert)
+ } else {
+ menu.contents += new Label(" ...")
+ // List only first 5
+ UserAlerts.alerts.slice(0,5).reverse.foreach(showAlert)
+ }
+
+
+ def whichLoggingTitle : String = {
+ if (UserOptions.logging) {"Disable logging"} else {"Enable logging"}
+ }
+
+ val ShowLogAction: Action = new Action(whichLoggingTitle) {
+
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.L
+ }
+
+ def apply() {
+ UserOptions.logging = !UserOptions.logging
+ }
+ }
+
+ menu.contents += new Separator()
+
+ val ClearMessagesAction: Action = new Action("Clear message") {
+ menu.contents += new MenuItem(this) {
+ mnemonic = Key.C
+ }
+
+ def apply() {
+ clearStatusBar()
+ }
+ }
+
+ }
+
+
+ val UserMessage = new Label(UserAlerts.latestMessage.toString)
+ val ConsoleProgress = new ProgressBar
+ val ConsoleProgressLabel = new Label(" ")
+ val Status : GridPanel = new GridPanel(1, 2) {
+ contents += new FlowPanel(FlowPanel.Alignment.Left)(UserMessage)
+ contents += new FlowPanel(FlowPanel.Alignment.Right)(ConsoleProgressLabel, ConsoleProgress)
+ }
+
+ ConsoleProgress.preferredSize = ConsoleProgressSize //Currently doesn't respond to UI scaling
+
+ def ConsoleProgressSize: Dimension = new Dimension(UserOptions.scaleInt(100), UserOptions.scaleInt(15))
+
+ Status.peer.addMouseListener(new MouseAdapter {
+ override def mousePressed(e: MouseEvent) {
+ e.getButton match {
+ case _ =>
+ if (e.isPopupTrigger) {
+ popup(MessagesContextMenu(), Some(e))
+ }
+ }
+ }
+
+ })
+
+ def clearStatusBar() : Unit = {
+ UserMessage.text = ""
+ ConsoleProgressLabel.text = " "
+ ConsoleProgress.value = 0
+ ConsoleProgress.indeterminate = false
+ }
+
+
+ listenTo(UserAlerts.AlertPublisher)
+ reactions += {
+ case UserAlerts.UserAlertEvent(alert: UserAlerts.Alert) =>
+ UserMessage.text = alert.toString
+ UserMessage.foreground = alert.color
+ case UserAlerts.UserProcessUpdate(_) =>
+ UserAlerts.leastCompleteProcess match {
+ case Some(process) => if (process.determinate) {
+ ConsoleProgress.indeterminate = false
+ ConsoleProgress.value = process.value
+ } else {
+ ConsoleProgress.indeterminate = true
+ }
+ case _ => ConsoleProgress.value = 100
+ }
+ val ongoing = UserAlerts.ongoingProcesses.filter(op => op.value < 100)
+ ongoing.count(_ => true) match {
+ case 0 => ConsoleProgressLabel.text = " " //keep non-empty so the progressbar stays in line with text
+ case 1 => ConsoleProgressLabel.text = ongoing.head.name
+ case n => ConsoleProgressLabel.text = n.toString + " processes ongoing"
+ }
+ }
+
+}
diff --git a/scala/src/main/scala/quanto/gui/TextEditor.scala b/scala/src/main/scala/quanto/gui/TextEditor.scala
new file mode 100644
index 00000000..7f7b96a5
--- /dev/null
+++ b/scala/src/main/scala/quanto/gui/TextEditor.scala
@@ -0,0 +1,107 @@
+package quanto.gui
+
+import java.awt.BorderLayout
+import java.awt.event.{KeyAdapter, KeyEvent}
+import java.io.{File, PrintWriter}
+
+import org.gjt.sp.jedit.buffer.JEditBuffer
+import org.gjt.sp.jedit.textarea.StandaloneTextArea
+import org.gjt.sp.jedit.{Mode, Registers}
+import quanto.util.UserOptions
+
+import scala.io.Source
+import scala.swing.{BorderPanel, Label}
+
+class TextEditor(val mode: Mode) extends BorderPanel {
+ lazy val Component: BorderPanel = new BorderPanel {
+ peer.add(TextArea, BorderLayout.CENTER)
+ }
+
+ import org.gjt.sp.jedit.IPropertyManager
+ import org.gjt.sp.jedit.textarea.StandaloneTextArea
+
+ // TODO: check font availability and/or allow user to select one
+ val props = new java.util.Properties()
+ val propFile = this.getClass.getResourceAsStream("jedit.props")
+ val keyFile = this.getClass.getResourceAsStream("jEdit_keys.props")
+ props.load(propFile)
+ props.load(keyFile)
+ //props.setProperty("view.font", "Arial")
+ props.setProperty("view.fontsize", UserOptions.fontSize.toString)
+ //props.setProperty("view.fontstyle", "0")
+ //props.putAll(loadProperties("/keymaps/jEdit_keys.props"))
+ //props.putAll(loadProperties("/org/gjt/sp/jedit/jedit.props"))
+ val TextArea = new StandaloneTextArea(new IPropertyManager() {
+ override def getProperty(name: String): String = props.getProperty(name)
+ })
+ //textArea.getBuffer.setProperty("folding", "explicit")
+ //val TextArea: StandaloneTextArea = StandaloneTextArea.createTextArea()
+
+ private val buf = new JEditBuffer1
+ private val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask
+
+ def registerBuffer() {
+ TextArea.setFont(UserOptions.font)
+ buf.setMode(mode)
+ TextArea.setBuffer(buf)
+ TextArea.addKeyListener(new KeyAdapter {
+ override def keyPressed(e: KeyEvent) {
+ if (e.getModifiers == CommandMask) e.getKeyChar match {
+ case 'x' => Registers.cut(TextArea, '$')
+ case 'c' => Registers.copy(TextArea, '$')
+ case 'v' => Registers.paste(TextArea, '$')
+ case _ =>
+ }
+ }
+ })
+ }
+
+ def getText: String = {
+ TextArea.getBuffer.getText
+ }
+
+ registerBuffer()
+}
+
+object TextEditor {
+ // JEdit can't access scala resources directly, so need to make a dummy file to hold the xml
+ def makeDummyXMLFile(name: String): String = {
+ val f = File.createTempFile(name, "xml")
+ f.deleteOnExit()
+ val pr = new PrintWriter(f)
+ Source.fromInputStream(getClass.getResourceAsStream(s"$name.xml")).foreach(pr.print)
+ pr.close()
+ f.getCanonicalPath
+ }
+
+ object Modes {
+ def python: Mode = {
+ val mode = new Mode("Python")
+ val modeFile = makeDummyXMLFile("python")
+ mode.setProperty("file", modeFile)
+ mode
+ }
+
+ def markdown: Mode = {
+ val mode = new Mode("Markdown")
+ val modeFile = makeDummyXMLFile("markdown")
+ mode.setProperty("file", modeFile)
+ mode
+ }
+
+ def blank : Mode = {
+ val mode = new Mode("Blank")
+ val modeFile = makeDummyXMLFile("blank")
+ mode.setProperty("file", modeFile)
+ mode
+ }
+
+ def fromFile(path: String, identifier: String): Mode = {
+ val mode = new Mode(identifier)
+ val modeFile = path
+ mode.setProperty("file", modeFile)
+ mode
+ }
+ }
+
+}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/gui/TheoryDocument.scala b/scala/src/main/scala/quanto/gui/TheoryDocument.scala
new file mode 100644
index 00000000..f2ec1c92
--- /dev/null
+++ b/scala/src/main/scala/quanto/gui/TheoryDocument.scala
@@ -0,0 +1,85 @@
+package quanto.gui
+
+import java.io.File
+
+import quanto.data._
+import quanto.util.UserAlerts
+import quanto.util.json.Json
+
+import scala.collection.mutable
+import scala.ref.Reference
+import scala.swing.Reactions.Reaction
+import scala.swing.{Component, Publisher, RefSet}
+import scala.swing.event.Event
+import scala.swing.event.Event
+
+
+case class TheoryChanged() extends Event
+
+// Will infer the theory from QuantoDerive!
+// Otherwise can have multiple theories editing at once, but only one should be affecting the project
+class TheoryDocument(val parent: Component) extends Document with Publisher {
+ val description = "Theory"
+ val fileExtension = "qtheory"
+ private var _theory: Theory = getTheory()
+
+ def getTheory(): Theory = {
+ if (QuantoDerive.CurrentProject.isEmpty) {
+ blankTheory()
+ } else {
+ QuantoDerive.CurrentProject.get.theory
+ }
+ }
+
+ protected def clearDocument() {
+ _theory = blankTheory()
+ publish(TheoryChanged())
+ }
+
+ def blankTheory(): Theory = {
+ new Theory(name = "unsaved",
+ coreName = "unsaved",
+ vertexTypes = Map(), edgeTypes = Map(),
+ defaultVertexType = "")
+ }
+
+ override def loadDocument(f: File) {
+ // Not convinced this should ever be called
+ val json = Json.parse(f)
+ val newTheory = Theory.fromJson(json)
+ //theory = newTheory
+ publish(TheoryChanged())
+ }
+
+ // There should never be unsaved changes
+ override def unsavedChanges: Boolean = false
+
+ protected def saveDocument(f: File) {
+ val json = Theory.toJson(theory)
+ json.writeTo(f)
+ publish(TheoryChanged())
+ }
+
+ def theory: Theory = _theory
+
+ def theory_=(th: Theory): Unit = {
+ if (QuantoDerive.CurrentProject.isEmpty){
+ UserAlerts.alert("Please open a project before altering theory files", UserAlerts.Elevation.WARNING)
+ } else {
+ val project = QuantoDerive.CurrentProject.get
+ QuantoDerive.CurrentProject = Some(new Project(th, project.projectFile, project.name))
+ QuantoDerive.updateProjectFile(project.projectFile)
+ }
+ _theory = getTheory()
+ publish(TheoryChanged())
+ }
+
+
+ override protected def exportDocument(f: File) {
+ showSaveAsDialog(None)
+ }
+
+ override def titleDescription: String = "Theory Editor"
+
+
+}
diff --git a/scala/src/main/scala/quanto/gui/TheoryEditPanel.scala b/scala/src/main/scala/quanto/gui/TheoryEditPanel.scala
new file mode 100644
index 00000000..66c41629
--- /dev/null
+++ b/scala/src/main/scala/quanto/gui/TheoryEditPanel.scala
@@ -0,0 +1,757 @@
+package quanto.gui
+
+import java.awt.{Color, Shape}
+
+import org.lindenb.svg.SVGUtils
+import quanto.data.Theory._
+import quanto.data.{CompositeExpression, GenericParseException, Theory, TheoryLoadException}
+import quanto.util.UserAlerts.{Elevation, alert}
+import quanto.util.UserOptions.scaleInt
+import quanto.util._
+import quanto.util.json.{Json, JsonObject}
+import quanto.util.swing.ToolBar
+
+import scala.language.postfixOps
+
+import scala.swing._
+import scala.swing.event.{ButtonClicked, SelectionChanged}
+
+import TheoryEditPanel._
+
+class TheoryEditPanel() extends BorderPanel with HasDocument {
+ val document = new TheoryDocument(this)
+ val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask
+ val Toolbar = new ToolBar {
+ //contents
+ }
+ val EditorsCombined = new BoxPanel(Orientation.Vertical)
+ val TopScrollablePane = new ScrollPane(EditorsCombined)
+
+
+
+ def chooseNodeDataType(node: String, current: String): Unit = {
+ val dialog = new DataTypePickingDialog(node, current)
+ dialog.centerOnScreen()
+ dialog.open()
+ if(dialog.wasAccepted){
+ val typeSelected = dialog.DataComboBox.selection.item
+ val newTypeVector: Vector[ValueType] = if(typeSelected == "composite"){
+ try{
+ CompositeExpression.parseTypes(dialog.CustomText.text)
+ } catch {
+ case _ : GenericParseException => Vector()
+ }
+ } else {
+ CompositeExpression.parseTypes(typeSelected)
+ }
+
+
+ val oldVertexDesc = theory.vertexTypes(node)
+ val newTypeValue : ValueDesc = new ValueDesc(typ = newTypeVector,
+ oldVertexDesc.value.enumOptions,
+ oldVertexDesc.value.latexConstants,
+ oldVertexDesc.value.validateWithCore)
+
+ val newVertexDesc : VertexDesc = new VertexDesc(newTypeValue,
+ oldVertexDesc.style,
+ oldVertexDesc.defaultData)
+
+ val newVertexTypes: Map[String, VertexDesc] = theory.vertexTypes + (node -> newVertexDesc)
+ // Update the base document's theory
+ theory = theory.copy(vertexTypes = newVertexTypes)
+ }
+ }
+
+ def chooseNodeBorderWidth(node: String, current: Int): Unit = {
+
+ val dialog = new SizeDialog(s"$node's border", current)
+ dialog.centerOnScreen()
+ dialog.open()
+ val newSize = dialog.SizeComboBox.selection.item.toInt
+
+ val oldStyle: VertexStyleDesc = theory.vertexTypes(node).style
+ val newStyle: VertexStyleDesc = oldStyle.copy(strokeWidth = newSize)
+
+ val oldVertexDesc = theory.vertexTypes(node)
+ val newVertexDesc = new VertexDesc(oldVertexDesc.value,
+ newStyle,
+ oldVertexDesc.defaultData)
+
+ val newVertexTypes: Map[String, VertexDesc] = theory.vertexTypes + (node -> newVertexDesc)
+ // Update the base document's theory
+ theory = theory.copy(vertexTypes = newVertexTypes)
+ }
+
+ def chooseNodeColour(node: String, current: Color): Unit = {
+ val dialog = new ColourPickingDialog(node, current)
+ dialog.centerOnScreen()
+ dialog.open()
+ val newColourHex: String = dialog.CustomText.text
+ val newColour = Color.decode(newColourHex)
+ val oldStyle: VertexStyleDesc = theory.vertexTypes(node).style
+ val newStyle: VertexStyleDesc = oldStyle.copy(fillColor = newColour)
+
+ val oldVertexDesc = theory.vertexTypes(node)
+ val newVertexDesc = new VertexDesc(oldVertexDesc.value,
+ newStyle,
+ oldVertexDesc.defaultData)
+
+ val newVertexTypes: Map[String, VertexDesc] = theory.vertexTypes + (node -> newVertexDesc)
+ // Update the base document's theory
+ theory = theory.copy(vertexTypes = newVertexTypes)
+ }
+
+ // Method for choosing the strokecolour of an edge
+ def chooseEdgeColour(edge: String, current: Color): Unit = {
+ val dialog = new ColourPickingDialog(s"Choose colour for edge $edge", current)
+ dialog.centerOnScreen()
+ dialog.open()
+ val newColourHex: String = dialog.CustomText.text
+ val newColour = Color.decode(newColourHex)
+ val oldStyle: EdgeStyleDesc = theory.edgeTypes(edge).style
+ val newStyle: EdgeStyleDesc = oldStyle.copy(strokeColor = newColour)
+
+ val oldEdgeDesc = theory.edgeTypes(edge)
+ val newEdgeDesc = new EdgeDesc(oldEdgeDesc.value,
+ newStyle,
+ oldEdgeDesc.defaultData)
+
+ val newEdgeTypes: Map[String, EdgeDesc] = theory.edgeTypes + (edge -> newEdgeDesc)
+ // Update the base document's theory
+ theory = theory.copy(edgeTypes = newEdgeTypes)
+ }
+
+ // Method for choosing the strokewidth of an edge
+ def chooseEdgeWidth(edge: String, current: Int): Unit = {
+ val dialog = new SizeDialog(edge, current)
+ dialog.centerOnScreen()
+ dialog.open()
+ val newSize = dialog.SizeComboBox.selection.item.toInt
+ val oldStyle: EdgeStyleDesc = theory.edgeTypes(edge).style
+ val newStyle: EdgeStyleDesc = oldStyle.copy(strokeWidth = newSize)
+
+ val oldEdgeDesc = theory.edgeTypes(edge)
+ val newEdgeDesc = new EdgeDesc(oldEdgeDesc.value,
+ newStyle,
+ oldEdgeDesc.defaultData)
+
+ val newEdgeTypes: Map[String, EdgeDesc] = theory.edgeTypes + (edge -> newEdgeDesc)
+ // Update the base document's theory
+ theory = theory.copy(edgeTypes = newEdgeTypes)
+ }
+
+ // Choose label position for current node
+ def chooseLabelPlacement(node: String, current: VertexLabelPosition): Unit = {
+ val dialog = new LabelPlacementDialog(node, current)
+ dialog.centerOnScreen()
+ dialog.open()
+ val newPlacementName: String = dialog.PlacementComboBox.item
+ val newPlacement: VertexLabelPosition = VertexLabelPosition.fromName(newPlacementName).
+ getOrElse(VertexLabelPosition.values.head)
+ val oldStyle: VertexStyleDesc = theory.vertexTypes(node).style
+ val newStyle: VertexStyleDesc = oldStyle.copy(labelPosition = newPlacement)
+
+ val oldVertexDesc = theory.vertexTypes(node)
+ val newVertexDesc = new VertexDesc(oldVertexDesc.value,
+ newStyle,
+ oldVertexDesc.defaultData)
+
+ val newVertexTypes: Map[String, VertexDesc] = theory.vertexTypes + (node -> newVertexDesc)
+ // Update the base document's theory
+ theory = theory.copy(vertexTypes = newVertexTypes)
+ }
+
+ def theory: Theory = document.theory
+
+ def theory_=(th: Theory) = document.theory_=(th)
+
+ // Prompts the user to select a type of node, including "custom" as an option
+ def chooseNodeShape(node: String, current: (VertexShape, Option[Shape])): Unit = {
+ val currentPath = current._2 match {
+ case None => None
+ case Some(shape) => Some(SVGUtils.shapeToPath(shape))
+ }
+ val dialog = new NodeShapeDialog(node, current._1, currentPath.getOrElse(""))
+ dialog.centerOnScreen()
+ dialog.open()
+ val newShapeName: String = dialog.ShapeComboBox.item
+ val newShape = VertexShape.fromName(newShapeName).getOrElse(VertexShape.Circle)
+ val oldStyle: VertexStyleDesc = theory.vertexTypes(node).style
+ val newStyle: VertexStyleDesc = oldStyle.copy(
+ newShape,
+ newShape match {
+ case VertexShape.Custom =>
+ try {
+ Some(SVGUtils.pathToShape(dialog.CustomText.text))
+ } catch {
+ case e: Exception => new TheoryLoadException("Could not interpret custom shape path", e)
+ None
+ }
+ case _ => None
+ }
+ )
+
+ val oldVertexDesc = theory.vertexTypes(node)
+ val newVertexDesc = new VertexDesc(oldVertexDesc.value,
+ newStyle,
+ oldVertexDesc.defaultData)
+
+ val newVertexTypes: Map[String, VertexDesc] = theory.vertexTypes + (node -> newVertexDesc)
+ // Update the base document's theory
+ theory = theory.copy(vertexTypes = newVertexTypes)
+ }
+
+ def createPageComponents(): Unit = {
+
+ val separation = UserOptions.scaleInt(20)
+
+ def horizontalWrap(component: Component): Unit = {
+ EditorsCombined.contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), component, Swing.HStrut(10))
+ }
+ }
+
+ val buttonSize = new Dimension(maxGridSize.width, scaleInt(30))
+
+ EditorsCombined.contents.clear
+
+ EditorsCombined.contents += Swing.VStrut(separation)
+ horizontalWrap(new Label(
+ """All changes are saved immediately. Re-open graphs to see the changes."""))
+ horizontalWrap(new Label(
+ """It is recommended that you back up your project before altering anything on this page!"""))
+
+ EditorsCombined.contents += Swing.VStrut(separation)
+ EditorsCombined.contents += new Separator()
+ EditorsCombined.contents += Swing.VStrut(separation)
+
+ horizontalWrap(new Label("Vertices:"))
+ theory.vertexTypes.toSeq.sortBy(vt => vt._1).map(vt => {
+ EditorsCombined.contents += new NodeEditor(vt._1, vt._2)
+ EditorsCombined.contents += Swing.VStrut(separation)
+ })
+ val AddVertexButton = new Button("Add Vertex Type")
+ AddVertexButton.preferredSize = buttonSize
+ horizontalWrap(AddVertexButton)
+ EditorsCombined.contents += Swing.VStrut(separation)
+ EditorsCombined.contents += new Separator()
+ EditorsCombined.contents += Swing.VStrut(separation)
+ horizontalWrap(new Label("Edges:"))
+ theory.edgeTypes.toSeq.sortBy(et => et._1).map(et => {
+ EditorsCombined.contents += new EdgeEditor(et._1, et._2)
+ EditorsCombined.contents += Swing.VStrut(separation)
+ }
+ )
+ val AddEdgeButton = new Button("Add Edge Type")
+ AddEdgeButton.preferredSize = buttonSize
+ horizontalWrap(AddEdgeButton)
+ EditorsCombined.contents += Swing.VStrut(separation)
+
+ listenTo(AddEdgeButton, AddVertexButton)
+ reactions += {
+ case ButtonClicked(AddEdgeButton) =>
+ addNewEdge()
+ case ButtonClicked(AddVertexButton) =>
+ addNewVertex()
+
+ }
+ repaint()
+ }
+
+ def addNewVertex(): Unit = {
+ val d = new ChooseStringDialog("Name for new vertex type")
+ d.centerOnScreen()
+ d.open()
+
+ val result = d.StringField.text
+ if (result != "" && !theory.vertexTypes.keys.exists(k => k == result)) {
+ val newVertexDesc = new VertexDesc(
+ value = new ValueDesc(),
+ style = new VertexStyleDesc(shape = VertexShape.Circle),
+ defaultData = JsonObject(
+ "type" -> Json.stringToJson(result),
+ "value" -> Json.stringToJson("")
+ )
+ )
+ val newVertexTypes: Map[String, Theory.VertexDesc] = theory.vertexTypes ++ Map(result -> newVertexDesc)
+ // Update the base document's theory
+ theory = theory.copy(vertexTypes = newVertexTypes)
+ }
+ }
+
+ def addNewEdge(): Unit = {
+ val d = new ChooseStringDialog("Name for new edge type")
+ d.centerOnScreen()
+ d.open()
+
+ val result = d.StringField.text
+ if (result != "") {
+ if (!theory.edgeTypes.keys.exists(k => k == result)) {
+ val newEdgeTypes: Map[String, Theory.EdgeDesc] = theory.edgeTypes ++ Map(result -> EdgeDesc(
+ value = ValueDesc(typ = Vector(ValueType.Empty)),
+ style = EdgeStyleDesc(),
+ defaultData = JsonObject("type" -> result)
+ ))
+ // Update the base document's theory
+ theory = new Theory(name = theory.name,
+ coreName = theory.coreName,
+ vertexTypes = theory.vertexTypes,
+ edgeTypes = newEdgeTypes,
+ defaultVertexType = theory.defaultVertexType,
+ defaultEdgeType = theory.defaultEdgeType
+ )
+ } else {
+ UserAlerts.alert("That name is already in use", Elevation.WARNING)
+ }
+ }
+ }
+
+ def colourToString(colour: Color): String = {
+ var returnColour = f"#${0xFFFFFF & colour.hashCode()}%06X"
+ for ((colourName, colourValue) <- approvedColours) {
+ if (colourValue == colour) {
+ returnColour = colourName
+ }
+ }
+ returnColour
+ }
+
+ class ChooseStringDialog(title: String) extends Dialog {
+ modal = true
+
+
+ val AcceptButton = new Button("Accept")
+ val CancelButton = new Button("Cancel")
+ val StringField = new TextField("")
+ defaultButton = Some(AcceptButton)
+ val ShapeEditorPanel : BoxPanel = new BoxPanel(Orientation.Vertical) {
+
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), new Label("Name:"), Swing.HStrut(5), StringField, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), AcceptButton, Swing.HStrut(5), CancelButton, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ }
+
+ contents = ShapeEditorPanel
+
+ listenTo(AcceptButton, CancelButton)
+
+ reactions += {
+ case ButtonClicked(AcceptButton) =>
+ close()
+ case ButtonClicked(CancelButton) =>
+ StringField.text = ""
+ close()
+ }
+ }
+
+ class SizeDialog(identifier: String,
+ current: Int,
+ sizeOptions: Seq[String] = (1 to 5).map(_.toString)) extends Dialog {
+ title = s"Choose the size of $identifier"
+ modal = true
+
+ val AcceptButton = new Button("Accept")
+ val CancelButton = new Button("Cancel")
+ val SizeComboBox = new ComboBox(sizeOptions)
+ SizeComboBox.selection.item = current.toString
+ defaultButton = Some(CancelButton)
+ val ShapeEditorPanel : BoxPanel = new BoxPanel(Orientation.Vertical) {
+
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), new Label("Size:"), Swing.HStrut(5), SizeComboBox, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), AcceptButton, Swing.HStrut(5), CancelButton, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ }
+
+ contents = ShapeEditorPanel
+
+ listenTo(AcceptButton, CancelButton, SizeComboBox.selection)
+
+ reactions += {
+ case ButtonClicked(AcceptButton) =>
+ close()
+ case ButtonClicked(CancelButton) =>
+ SizeComboBox.item = current.toString
+ close()
+ case SelectionChanged(SizeComboBox) =>
+ }
+ }
+
+ class LabelPlacementDialog(identifier: String, current: VertexLabelPosition) extends Dialog {
+ title = s"Choose the label placement for $identifier"
+ modal = true
+
+ val positionOptions: Seq[String] = VertexLabelPosition.values.
+ map(vt => vt.toString).filterNot(s => s == "custom").toSeq
+
+
+ val AcceptButton = new Button("Accept")
+ val CancelButton = new Button("Cancel")
+ val PlacementComboBox = new ComboBox(positionOptions)
+ PlacementComboBox.selection.item = current.toString
+ defaultButton = Some(CancelButton)
+ val ShapeEditorPanel : BoxPanel = new BoxPanel(Orientation.Vertical) {
+
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), new Label("Label placement:"), Swing.HStrut(5), PlacementComboBox, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), AcceptButton, Swing.HStrut(5), CancelButton, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ }
+
+ contents = ShapeEditorPanel
+
+ listenTo(AcceptButton, CancelButton, PlacementComboBox.selection)
+
+ reactions += {
+ case ButtonClicked(AcceptButton) =>
+ close()
+ case ButtonClicked(CancelButton) =>
+ PlacementComboBox.item = current.toString
+ close()
+ case SelectionChanged(PlacementComboBox) =>
+ }
+ }
+
+ class DataTypePickingDialog(str: String, current: String) extends Dialog {
+ modal = true
+
+ private var currentForBox : String = current
+
+ private val currentParsedType : Option[ValueType] = {
+ try {
+ val types = CompositeExpression.parseTypes(current)
+ if(types.length == 1) {
+ currentForBox = ""
+ Some(types.head)
+ } else {
+ currentForBox = valueTypeVectorToString(types)
+ None
+ }
+ }
+ catch {
+ case GenericParseException(msg) => {
+ currentForBox = ""
+ None
+ }
+ }
+ }
+
+ val dataTypeOptions: Seq[String] = valueTypesAsHumanReadable.values.toSet.toSeq.sorted :+ "composite"
+ val AcceptButton = new Button("Accept")
+ val CancelButton = new Button("Cancel")
+ val DataComboBox = new ComboBox(dataTypeOptions)
+ DataComboBox.selection.item = {
+ if (currentParsedType.nonEmpty) {
+ current
+ } else {
+ "composite"
+ }
+ }
+ val CustomText = new TextField()
+ CustomText.text = currentForBox
+ val ShapeEditorPanel: BoxPanel = new BoxPanel(Orientation.Vertical) {
+
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), new Label("Data Type:"), Swing.HStrut(5), DataComboBox, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), new Label("Composite Type:"), Swing.HStrut(5), CustomText, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), AcceptButton, Swing.HStrut(5), CancelButton, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ }
+
+ enableCustom()
+ defaultButton = Some(CancelButton)
+
+
+ def enableCustom(): Unit = {
+ DataComboBox.selection.item match {
+ case "composite" => CustomText.enabled = true
+ case _ => CustomText.enabled = false
+ }
+ }
+
+ contents = ShapeEditorPanel
+
+ listenTo(AcceptButton, CancelButton, DataComboBox.selection)
+
+ var wasAccepted : Boolean = false
+
+ reactions += {
+ case ButtonClicked(AcceptButton) =>
+ wasAccepted = true
+ close()
+ case ButtonClicked(CancelButton) =>
+ wasAccepted = false
+ close()
+ case SelectionChanged(DataComboBox) =>
+ enableCustom()
+ }
+ }
+
+ // example: ColourPickingDialog("X", "foreground colour", Color(10,10,200))
+ class ColourPickingDialog(title: String, current: Color) extends Dialog {
+ modal = true
+
+ val currentColourName = colourToString(current)
+
+ val colourOptions: Seq[String] = approvedColours.keys.toSet.toSeq.sorted :+ "custom"
+ val AcceptButton = new Button("Accept")
+ val CancelButton = new Button("Cancel")
+ val ColourComboBox = new ComboBox(colourOptions)
+ ColourComboBox.selection.item = currentColourName
+ val CustomText = new TextField()
+ displayColour(current)
+ val ShapeEditorPanel: BoxPanel = new BoxPanel(Orientation.Vertical) {
+
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), new Label("Colour:"), Swing.HStrut(5), ColourComboBox, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), new Label("Custom data:"), Swing.HStrut(5), CustomText, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), AcceptButton, Swing.HStrut(5), CancelButton, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ }
+
+ enableCustom()
+ defaultButton = Some(CancelButton)
+
+ def displayColour(colour: Color): Unit = {
+ CustomText.text = f"#${0xFFFFFF & colour.hashCode()}%06X"
+ }
+
+ def enableCustom(): Unit = {
+ ColourComboBox.selection.item match {
+ case "custom" => CustomText.enabled = true
+ case _ => CustomText.enabled = false
+ }
+ }
+
+ contents = ShapeEditorPanel
+
+ listenTo(AcceptButton, CancelButton, ColourComboBox.selection)
+
+ reactions += {
+ case ButtonClicked(AcceptButton) =>
+ close()
+ case ButtonClicked(CancelButton) =>
+ ColourComboBox.item = currentColourName
+ close()
+ case SelectionChanged(ColourComboBox) =>
+ ColourComboBox.selection.item match {
+ case "custom" =>
+ case s => displayColour(approvedColours(s))
+ }
+ enableCustom()
+ }
+ }
+
+ class NodeShapeDialog(name: String, current: VertexShape, currentCustom: String) extends Dialog {
+ title = s"Choose the shape for node $name"
+ modal = true
+
+ val shapeOptions: Seq[String] = VertexShape.values.
+ map(vt => vt.toString).filterNot(s => s == "custom").toSeq
+
+
+ val AcceptButton = new Button("Accept")
+ val CancelButton = new Button("Cancel")
+ val ShapeComboBox = new ComboBox(shapeOptions)
+ ShapeComboBox.selection.item = current.toString
+ val CustomText = new TextField(currentCustom)
+ enableCustom()
+ defaultButton = Some(CancelButton)
+ val ShapeEditorPanel = new BoxPanel(Orientation.Vertical) {
+
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), new Label("Vertex shape:"), Swing.HStrut(5), ShapeComboBox, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), new Label("Custom data:"), Swing.HStrut(5), CustomText, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ contents += new BoxPanel(Orientation.Horizontal) {
+ contents += (Swing.HStrut(10), AcceptButton, Swing.HStrut(5), CancelButton, Swing.HStrut(10))
+ }
+ contents += Swing.VStrut(10)
+ }
+
+ def enableCustom(): Unit = {
+ ShapeComboBox.selection.item match {
+ case "custom" => CustomText.enabled = true
+ case _ => CustomText.enabled = false
+ }
+ }
+
+ contents = ShapeEditorPanel
+
+ listenTo(AcceptButton, CancelButton, ShapeComboBox.selection)
+
+ reactions += {
+ case ButtonClicked(AcceptButton) =>
+ close()
+ case ButtonClicked(CancelButton) =>
+ ShapeComboBox.item = current.toString
+ close()
+ case SelectionChanged(ShapeComboBox) =>
+ enableCustom()
+ }
+ }
+
+ class NodeEditor(nodeName: String, desc: VertexDesc) extends BoxPanel(Orientation.Vertical) {
+ // Note that GridPanel expands to fill parent
+ maximumSize = maxGridSize
+ contents += new GridPanel(5, 2) {
+
+ contents += new Label("Name")
+ contents += new Label(nodeName) // Currently doesn't support renaming nodes (and shouldn't?)
+
+ contents += new Label("Shape")
+ val EditShapeButton: Button = new Button(desc.style.shape.toString)
+ contents += EditShapeButton
+
+ contents += new Label("Colour")
+ val EditColourButton: Button = new Button(colourToString(desc.style.fillColor))
+ contents += EditColourButton
+
+ contents += new Label("Values")
+ val EditValueTypeButton: Button = new Button(valueTypeVectorToString(desc.value.typ))
+ contents += EditValueTypeButton
+
+ contents += new Label("Border width")
+ val EditBorderWidthButton: Button = new Button(desc.style.strokeWidth.toString)
+ contents += EditBorderWidthButton
+
+
+ //contents += new Label("Label placement")
+ //val EditPlacementButton: Button = new Button(desc.style.labelPosition.toString)
+ //contents += EditPlacementButton
+ //contents += new Label("Example")
+ //contents += new Label("Example here")
+ listenTo(EditShapeButton, EditColourButton, EditValueTypeButton, EditBorderWidthButton)
+ reactions += {
+ case ButtonClicked(EditShapeButton) =>
+ chooseNodeShape(nodeName, (desc.style.shape, desc.style.customShape))
+ case ButtonClicked(EditColourButton) =>
+ chooseNodeColour(nodeName, desc.style.fillColor)
+ case ButtonClicked(EditValueTypeButton) =>
+ chooseNodeDataType(nodeName, TheoryEditPanel.valueTypeVectorToString(desc.value.typ))
+ case ButtonClicked(EditBorderWidthButton) =>
+ chooseNodeBorderWidth(nodeName, desc.style.strokeWidth)
+ }
+ }
+ }
+val maxGridSize = new Dimension(scaleInt(200), scaleInt(100))
+
+ class EdgeEditor(edgeName: String, desc: EdgeDesc) extends BoxPanel(Orientation.Vertical) {
+ // Note that GridPanel expands to fill parent
+ maximumSize = maxGridSize
+ contents += new GridPanel(3, 2) {
+ contents += new Label("Name")
+ contents += new Label(edgeName) // Currently doesn't support renaming nodes (and shouldn't?)
+ contents += new Label("Width")
+ val EditShapeButton: Button = new Button(desc.style.strokeWidth.toString)
+ contents += EditShapeButton
+ contents += new Label("Colour")
+ val EditColourButton: Button = new Button(colourToString(desc.style.strokeColor))
+ contents += EditColourButton
+ //contents += new Label("Label placement")
+ //val EditPlacementButton: Button = new Button(desc.style.labelPosition.toString)
+ //contents += EditPlacementButton
+ //contents += new Label("Example")
+ //contents += new Label("Example here")
+ listenTo(EditShapeButton, EditColourButton)
+ reactions += {
+ case ButtonClicked(EditShapeButton) =>
+ chooseEdgeWidth(edgeName, desc.style.strokeWidth)
+ case ButtonClicked(EditColourButton) =>
+ chooseEdgeColour(edgeName, desc.style.strokeColor)
+ //case ButtonClicked(EditPlacementButton) =>
+ // chooseLabelPlacement(name, desc.style.labelPosition)
+ }
+ }
+ }
+
+ createPageComponents()
+
+
+ add(TopScrollablePane, BorderPanel.Position.Center)
+
+ add(Toolbar, BorderPanel.Position.North)
+
+ listenTo(document)
+
+ reactions += {
+ case TheoryChanged() =>
+ alert("Theory changed", Elevation.DEBUG)
+ createPageComponents()
+
+ }
+}
+
+object TheoryEditPanel {
+
+
+ implicit private def valueTypeVectorToString(vs: Vector[ValueType]) : String =
+ vs.map(valueTypesAsHumanReadable).mkString("",", ","")
+
+ def colourMute(c: Color): Color = {
+ def m(i: Int): Int = math.floor(i * 1).toInt
+
+ new Color(m(c.getRed), m(c.getGreen), m(c.getBlue))
+ }
+
+ val valueTypesAsHumanReadable: Map[ValueType, String] = ValueType.values.toList map (v =>
+ v -> (v match {
+ case ValueType.AngleExpr => "angle"
+ case ValueType.Boolean => "boolean"
+ case ValueType.Integer => "integer"
+ case ValueType.Rational => "rational"
+ case ValueType.String => "string"
+ case ValueType.Long => "long"
+ case ValueType.Enum => "string"
+ case ValueType.Empty => "empty"
+ })) toMap
+
+ val approvedColours: Map[String, Color] = Map(
+ "red" -> colourMute(new Color(255, 0, 0)),
+ "green" -> colourMute(new Color(0, 255, 0)),
+ "blue" -> colourMute(new Color(0, 0, 255)),
+ "white" -> colourMute(new Color(255, 255, 255)),
+ "black" -> colourMute(new Color(0, 0, 0)),
+ "yellow" -> colourMute(new Color(255, 255, 0)),
+ "magenta" -> colourMute(new Color(255, 0, 255)),
+ "cyan" -> colourMute(new Color(0, 255, 255))
+ )
+
+}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/gui/Transformer.scala b/scala/src/main/scala/quanto/gui/Transformer.scala
index ff75181c..dbed175c 100644
--- a/scala/src/main/scala/quanto/gui/Transformer.scala
+++ b/scala/src/main/scala/quanto/gui/Transformer.scala
@@ -3,11 +3,12 @@ package quanto.gui
class Transformer(
var scale: Double = 50.0f,
- var origin: (Double, Double) = (250.0f, 250.0f)
+ var screenDrawOrigin: (Double, Double) = (250f, 250f)
)
{
- def toScreen(pt: (Double,Double)) = (pt._1 * scale + origin._1, origin._2 - pt._2 * scale)
- def fromScreen(pt: (Double,Double)) = ((pt._1 - origin._1) / scale, ((-1 * pt._2) + origin._2) / scale)
+ // screenDrawOrigin is where _on the screen_ _in pixels_ we want the origin to be drawn
+ def toScreen(pt: (Double,Double)) = (pt._1 * scale + screenDrawOrigin._1, screenDrawOrigin._2 - pt._2 * scale)
+ def fromScreen(pt: (Double,Double)) = ((pt._1 - screenDrawOrigin._1) / scale, ((-1 * pt._2) + screenDrawOrigin._2) / scale)
def scaleToScreen(x: Double) = x * scale
def scaleFromScreen(x: Double) = x / scale
diff --git a/scala/src/main/scala/quanto/gui/graphview/BBoxDisplayData.scala b/scala/src/main/scala/quanto/gui/graphview/BBoxDisplayData.scala
index 34ec21b4..f984093b 100644
--- a/scala/src/main/scala/quanto/gui/graphview/BBoxDisplayData.scala
+++ b/scala/src/main/scala/quanto/gui/graphview/BBoxDisplayData.scala
@@ -29,17 +29,17 @@ trait BBoxDisplayData { self: VertexDisplayData =>
protected def computeBBoxDisplay() {
bboxDisplay.clear()
- var offset = Math.max(boundsForVertexSet(graph.verts).getMaxX, trans.origin._1)
+ var offset = Math.max(boundsForVertexSet(graph.verts).getMaxX, trans.screenDrawOrigin._1)
// used to compute relative padding sizes
- val em = trans.scaleToScreen(0.1)
+ val em = trans.scaleToScreen(0.25)
graph.bboxesChildrenFirst.foreach { bbox =>
val vset = graph.contents(bbox)
val rect = if (vset.isEmpty) {
- offset += 8*em
- new Rectangle2D.Double(offset, trans.origin._2 - 2*em, 4*em, 4*em)
+ offset += 4.0*em
+ new Rectangle2D.Double(offset, trans.screenDrawOrigin._2 - 2*em, 4*em, 4*em)
} else {
/* bounds determined by vertices of bbox */
@@ -53,18 +53,18 @@ trait BBoxDisplayData { self: VertexDisplayData =>
val rect = bbd.rect
if(bounds.contains(rect)) {
- val ulx = min(rect.getMinX - 5.0*em, bounds.getMinX)
- val uly = min(rect.getMinY - 5.0*em, bounds.getMinY)
- val lrx = max(rect.getMaxX + 5.0*em, bounds.getMaxX)
- val lry = max(rect.getMaxY + 5.0*em, bounds.getMaxY)
+ val ulx = min(rect.getMinX - em, bounds.getMinX)
+ val uly = min(rect.getMinY - em, bounds.getMinY)
+ val lrx = max(rect.getMaxX + em, bounds.getMaxX)
+ val lry = max(rect.getMaxY + em, bounds.getMaxY)
bounds = new Rectangle2D.Double(ulx, uly, lrx - ulx, lry - uly)
}
else if(rect.contains(bounds)) {
- val ulx = min(rect.getMinX, bounds.getMinX - 5.0*em)
- val uly = min(rect.getMinY, bounds.getMinY - 5.0*em)
- val lrx = max(rect.getMaxX, bounds.getMaxX + 5.0*em)
- val lry = max(rect.getMaxY, bounds.getMaxY + 5.0*em)
+ val ulx = min(rect.getMinX, bounds.getMinX - em)
+ val uly = min(rect.getMinY, bounds.getMinY - em)
+ val lrx = max(rect.getMaxX, bounds.getMaxX + em)
+ val lry = max(rect.getMaxY, bounds.getMaxY + em)
val new_bounds = new Rectangle2D.Double(ulx, uly, lrx - ulx, lry - uly)
bboxDisplay += bb -> BBDisplay(new_bounds)
diff --git a/scala/src/main/scala/quanto/gui/graphview/EdgeDisplayData.scala b/scala/src/main/scala/quanto/gui/graphview/EdgeDisplayData.scala
index 8f80c395..9e6f0414 100644
--- a/scala/src/main/scala/quanto/gui/graphview/EdgeDisplayData.scala
+++ b/scala/src/main/scala/quanto/gui/graphview/EdgeDisplayData.scala
@@ -3,9 +3,18 @@ package quanto.gui.graphview
import quanto.data._
import quanto.util.RichCubicCurve._
import java.awt.geom._
+
+import quanto.data.Theory.EdgeDesc
+import quanto.util.UserOptions
+
import math._
+import scala.swing.Color
-case class EDisplay(path: Path2D.Double, lines: List[Line2D.Double], label: Option[LabelDisplayData]) {
+case class EDisplay(path: Path2D.Double,
+ width: Int,
+ color: Color,
+ lines: List[Line2D.Double],
+ label: Option[LabelDisplayData]) {
def pointHit(pt: Point2D) = {
lines exists { l =>
//println("line starts " + l.getP1 + " ends " + l.getP2 + ", distance to " + pt + " is " + l.ptSegDistSq(pt))
@@ -23,6 +32,7 @@ trait EdgeDisplayData { self: GraphView with VertexDisplayData =>
for ((v1,sd) <- graph.vdata; (v2,td) <- graph.vdata if v1 <= v2) {
val edges = graph.source.codf(v1) intersect graph.target.codf(v2)
val rEdges = if (v1 == v2) Set[EName]() else graph.target.codf(v1) intersect graph.source.codf(v2)
+ // Count total edges v1 -> v2, and reverse edges (v2 -> v1)
val numEdges = edges.size + rEdges.size
if (numEdges != 0) {
@@ -115,7 +125,7 @@ trait EdgeDisplayData { self: GraphView with VertexDisplayData =>
}
val labelDisplay = edgeData.typeInfo.value.typ match {
- case Theory.ValueType.String =>
+ case Vector(Theory.ValueType.String) =>
val fm = peer.getGraphics.getFontMetrics(GraphView.EdgeLabelFont)
val text = edgeData.label
if (text == "") None
@@ -126,9 +136,10 @@ trait EdgeDisplayData { self: GraphView with VertexDisplayData =>
edgeData.typeInfo.style.labelBackgroundColor))
case _ => None
}
+ val width = UserOptions.scaleInt(edgeData.typeInfo.style.strokeWidth)
+ val color = edgeData.typeInfo.style.strokeColor
-
- edgeDisplay(e) = EDisplay(p, lines, labelDisplay)
+ edgeDisplay(e) = EDisplay(p, width, color, lines, labelDisplay)
i += 1
}
diff --git a/scala/src/main/scala/quanto/gui/graphview/GraphView.scala b/scala/src/main/scala/quanto/gui/graphview/GraphView.scala
index 8777b276..f4751913 100644
--- a/scala/src/main/scala/quanto/gui/graphview/GraphView.scala
+++ b/scala/src/main/scala/quanto/gui/graphview/GraphView.scala
@@ -15,7 +15,8 @@ import java.awt.font.TextLayout
import java.io.File
import quanto.util.FileHelper.printToFile
-import quanto.util.UserOptions
+import quanto.util.UserAlerts.alert
+import quanto.util.{UserAlerts, UserOptions}
import quanto.util.json.JsonString
@@ -60,6 +61,7 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
// gets called when the component is first painted
lazy val init = {
+ focusOnGraph()
resizeViewToFit()
}
@@ -92,11 +94,39 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
private var _zoom = 1.0
def zoom = _zoom
def zoom_=(d: Double) {
+ // log where the middle of the screen currently is
_zoom = d
trans.scale = _zoom * 50.0
- trans.origin = (_zoom * 250.0, _zoom * 250.0)
- invalidateGraph(clearSelection = false)
resizeViewToFit()
+ invalidateGraph(false)
+ repaint()
+ }
+
+private def viewportOffset (): (Double, Double) = {
+ val viewRect = peer.getVisibleRect
+ (viewRect.width / 2, viewRect.height / 2)
+}
+
+ def graphFocus : (Double, Double) = {
+ val viewShift = viewportOffset()
+ trans.fromScreen(viewShift)
+ }
+
+ private def renderOriginAt(x : Double, y: Double): Unit ={
+ trans.screenDrawOrigin = (x, y)
+ }
+
+ def graphFocus_=(pointOnGraph : (Double, Double)) : Unit = {
+
+ val screenFocus = (trans scaleToScreen pointOnGraph._1, trans scaleToScreen pointOnGraph._2)
+
+ val viewShift = viewportOffset()
+
+ renderOriginAt(0-screenFocus._1 + viewShift._1, 0-screenFocus._2 + viewShift._2)
+
+ resizeViewToFit()
+ invalidateGraph(false)
+ revalidate()
repaint()
}
@@ -142,6 +172,8 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
}
def resizeViewToFit() {
+
+ val bufferFocus = graphFocus
// top left and bottom right of bounds, in screen coordinates
val graphTopLeft = graph.vdata.foldLeft(0.0,0.0) { (c,v) =>
(min(c._1, v._2.coord._1), max(c._2, v._2.coord._2))
@@ -152,14 +184,14 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
} match {case (x,y) => trans toScreen (x + 1.0, y - 1.0)}
// default bounds, based on the current position of the origin and the size of the visible region
- val vRect = peer.getVisibleRect
- val defaultTopLeft = (trans.origin._1 - (vRect.getWidth/2.0), trans.origin._2 - (vRect.getHeight/2.0))
- val defaultBottomRight = (trans.origin._1 + (vRect.getWidth/2.0), trans.origin._2 + (vRect.getHeight/2.0))
+ val offset = viewportOffset()
+ val screenTopLeft = (trans.screenDrawOrigin._1 - offset._1, trans.screenDrawOrigin._2 - offset._2)
+ val screenBottomRight = (trans.screenDrawOrigin._1 + offset._1, trans.screenDrawOrigin._2 + offset._2)
- val topLeft = (min(graphTopLeft._1, defaultTopLeft._1),
- min(graphTopLeft._2, defaultTopLeft._2))
- val bottomRight = (max(graphBottomRight._1, defaultBottomRight._1),
- max(graphBottomRight._2, defaultBottomRight._2))
+ val topLeft = (min(graphTopLeft._1, screenTopLeft._1),
+ min(graphTopLeft._2, screenTopLeft._2))
+ val bottomRight = (max(graphBottomRight._1, screenBottomRight._1),
+ max(graphBottomRight._2, screenBottomRight._2))
val (w,h) = (bottomRight._1 - topLeft._1,
bottomRight._2 - topLeft._2)
@@ -168,10 +200,12 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
(preferredSize.width != w) || (preferredSize.height != h)
if (changed) {
- trans.origin = (trans.origin._1 - topLeft._1, trans.origin._2 - topLeft._2)
+ trans.screenDrawOrigin = (trans.screenDrawOrigin._1 - topLeft._1, trans.screenDrawOrigin._2 - topLeft._2)
+ //graphFocus = bufferFocus
preferredSize = new Dimension(w.toInt, h.toInt)
invalidateGraph(clearSelection = false)
revalidate()
+ repaint()
}
}
@@ -190,6 +224,48 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
super.repaint()
}
+ def focusOnGraph(): Unit = {
+ // top left and bottom right of bounds, in screen coordinates
+ val graphTopLeft = graph.vdata.foldLeft(0.0, 0.0) { (c, v) =>
+ (min(c._1, v._2.coord._1), max(c._2, v._2.coord._2))
+ }
+
+
+ val graphBottomRight = graph.vdata.foldLeft((0.0, 0.0)) { (c, v) =>
+ (max(c._1, v._2.coord._1), min(c._2, v._2.coord._2))
+ }
+
+
+ val defaultTopLeft = (-5,5)
+ val defaultBottomRight = (5,-5)
+
+ val topLeft = (min(graphTopLeft._1, defaultTopLeft._1),
+ max(graphTopLeft._2, defaultTopLeft._2))
+ val bottomRight = (max(graphBottomRight._1, defaultBottomRight._1),
+ min(graphBottomRight._2, defaultBottomRight._2))
+
+ val (w, h) = (bottomRight._1 - topLeft._1,
+ topLeft._2 - bottomRight._2)
+
+ val graphCentre = ((topLeft._1 + bottomRight._1) / 2, (topLeft._2 + bottomRight._2) / 2)
+
+ // default bounds, based on the current position of the origin and the size of the visible region
+ val vRect = peer.getVisibleRect
+
+ val widthOnScreen = trans scaleToScreen w
+ val heightOnScreen = trans scaleToScreen h
+
+ val widthChange = widthOnScreen / vRect.width
+ val heightChange = heightOnScreen / vRect.height
+ val zoomChangeNeeded = max(widthChange, heightChange)
+
+ zoom = 0.9*zoom / zoomChangeNeeded
+
+ //trans.scale = trans.scale * zoomChangeNeeded
+
+ graphFocus = graphCentre
+ }
+
override def paintComponent(g: Graphics2D) {
super.paintComponent(g)
g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON)
@@ -265,16 +341,10 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
for ((e, ed) <- edgeDisplay) {
if (selectedEdges contains e) {
g.setColor(Color.BLUE)
- g.setStroke(new BasicStroke(2))
+ g.setStroke(new BasicStroke(ed.width))
} else {
- if (graph.edata(e).isDirected) {
- g.setColor(Color.GRAY)
- g.setStroke(new BasicStroke(1))
- } else {
- g.setColor(Color.BLACK)
- g.setStroke(new BasicStroke(2))
- }
-
+ g.setColor(ed.color)
+ g.setStroke(new BasicStroke(ed.width))
}
g.draw(ed.path)
@@ -294,7 +364,7 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
g.setStroke(new BasicStroke(1))
var a = g.getColor
- for ((v, VDisplay(shape,color,label)) <- vertexDisplay) {
+ for ((v, VDisplay(shape, strokeWidth, color,label)) <- vertexDisplay) {
/* draw red line if vertex coordinates are within !-box rectangle
* but the vertex is not a member of the !-box and also write
@@ -343,10 +413,10 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
if (selectedVerts contains v) {
g.setColor(Color.BLUE)
- g.setStroke(new BasicStroke(2))
+ g.setStroke(new BasicStroke(strokeWidth + 1))
} else {
g.setColor(Color.BLACK)
- g.setStroke(new BasicStroke(1))
+ g.setStroke(new BasicStroke(strokeWidth))
}
g.draw(shape)
@@ -501,9 +571,12 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
}
}
+ val names = scala.collection.mutable.Map[String,Int]()
+
/* Output view to a tikzit-readable file */
printToFile(f, append)(p => {
- p.println("\\begin{tikzpicture}[baseline={([yshift=-.5ex]current bounding box.center)}]")
+ //p.println("\\begin{tikzpicture}[baseline={([yshift=-.5ex]current bounding box.center)}]")
+ p.println("\\begin{tikzpicture}[quanto]")
p.println("\t\\begin{pgfonlayer}{nodelayer}")
/* fill in all vertices */
@@ -513,7 +586,9 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
case _ : WireV => "wire"
}
- val number = vn.toString
+ names += vn.toString -> names.size
+
+ val number = names.size - 1
val disp_rec = vertexDisplay(vn).shape.getBounds
val trans_coord = trans.fromScreen(disp_rec.getCenterX, disp_rec.getCenterY)
min_max(trans_coord)
@@ -530,25 +605,29 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
/* fill in corners of !-boxes */
for ((bbn,bbd) <- bboxDisplay) {
- val number_ul = bbn.toString + "ul"
+ names += (bbn.toString + "ul") -> names.size
+ val number_ul = names.size - 1
val trans_coord_ul = trans.fromScreen(bbd.rect.getMinX, bbd.rect.getMinY)
min_max(trans_coord_ul)
val coord_ul = coordToString(trans_coord_ul)
p.println("\t\t\\node [style=bbox] (" + number_ul + ") at " + coord_ul + " {};")
- val number_ur = bbn.toString + "ur"
+ names += (bbn.toString + "ur") -> names.size
+ val number_ur = names.size - 1
val trans_coord_ur = trans.fromScreen(bbd.rect.getMaxX, bbd.rect.getMinY)
min_max(trans_coord_ur)
val coord_ur = coordToString(trans_coord_ur)
p.println("\t\t\\node [style=none] (" + number_ur + ") at " + coord_ur + " {};")
- val number_ll = bbn.toString + "ll"
+ names += (bbn.toString + "ll") -> names.size
+ val number_ll = names.size - 1
val trans_coord_ll = trans.fromScreen(bbd.rect.getMinX, bbd.rect.getMaxY)
min_max(trans_coord_ll)
val coord_ll = coordToString(trans_coord_ll)
p.println("\t\t\\node [style=none] (" + number_ll + ") at " + coord_ll + " {};")
- val number_lr = bbn.toString + "lr"
+ names += (bbn.toString + "lr") -> names.size
+ val number_lr = names.size - 1
val trans_coord_lr = trans.fromScreen(bbd.rect.getMaxX, bbd.rect.getMaxY)
min_max(trans_coord_lr)
val coord_lr = coordToString(trans_coord_lr)
@@ -556,16 +635,16 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
}
/* output 4 nodes used for padding*/
- val pad_size = 1.0
- minX -= pad_size
- maxX += pad_size
- minY -= pad_size
- maxY += pad_size
-
- p.println("\t\t\\node [style=none] (padl) at " + coordToString(minX, minY) + " {};")
- p.println("\t\t\\node [style=none] (padr) at " + coordToString(maxX, maxY) + " {};")
- p.println("\t\t\\node [style=none] (padu) at " + coordToString(minX, maxY) + " {};")
- p.println("\t\t\\node [style=none] (padd) at " + coordToString(maxX, minY) + " {};")
+// val pad_size = 1.0
+// minX -= pad_size
+// maxX += pad_size
+// minY -= pad_size
+// maxY += pad_size
+//
+// p.println("\t\t\\node [style=none] (padl) at " + coordToString(minX, minY) + " {};")
+// p.println("\t\t\\node [style=none] (padr) at " + coordToString(maxX, maxY) + " {};")
+// p.println("\t\t\\node [style=none] (padu) at " + coordToString(minX, maxY) + " {};")
+// p.println("\t\t\\node [style=none] (padd) at " + coordToString(maxX, minY) + " {};")
p.println("\t\\end{pgfonlayer}")
@@ -574,9 +653,9 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
/* fill in all graph edges */
for (edge_set <- graph.edgePartition) {
val edge_arr : Array[EName] = edge_set.toArray
- val size = edge_arr.size
- val canonical_source = graph.source(edge_arr(0)).toString
- val canonical_target = graph.target(edge_arr(0)).toString
+ val size = edge_arr.length
+ val canonical_source = names(graph.source(edge_arr(0)).toString)
+ val canonical_target = names(graph.target(edge_arr(0)).toString)
if (canonical_source != canonical_target) {
/* edges are between different nodes */
@@ -588,7 +667,9 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
val en = edge_arr(0)
val ed = graph.edata(en)
val style = if (ed.isDirected) "directed" else "simple"
- p.println("\t\t\\draw [style=" + style + "] (" + graph.source(en).toString + ") to (" + graph.target(en).toString + ");" )
+ p.println("\t\t\\draw [style=" + style + "] (" +
+ names(graph.source(en).toString) + ") to (" +
+ names(graph.target(en).toString) + ");" )
start = 1
}
@@ -597,13 +678,13 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
var right_left = "left="
/* draw the rest of the edges as arcs by setting the bend angle */
- for (i <- start to size-1) {
+ for (i <- start until size) {
val en = edge_arr(i)
val ed = graph.edata(en)
val style = if (ed.isDirected) "directed" else "simple"
val angle = angle_it.toString
- val source = graph.source(en).toString
- val target = graph.target(en).toString
+ val source = names(graph.source(en).toString)
+ val target = names(graph.target(en).toString)
/* alternate bending left or right */
if ((i-start) % 2 == 0) {
@@ -625,7 +706,7 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
/* edges have same source and target, so we need to output loops */
var looseness = 4.5
- for (i <- 0 to size-1) {
+ for (i <- 0 until size) {
val en = edge_arr(i)
val ed = graph.edata(en)
val style = if (ed.isDirected) "directed" else "simple"
@@ -643,10 +724,10 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
for ((bbn, _) <- bboxDisplay) {
/* fill in edges connecting !-box corners */
- val number_ul = bbn.toString + "ul.center"
- val number_ur = bbn.toString + "ur.center"
- val number_ll = bbn.toString + "ll.center"
- val number_lr = bbn.toString + "lr.center"
+ val number_ul = names(bbn.toString + "ul") + ".center"
+ val number_ur = names(bbn.toString + "ur") + ".center"
+ val number_ll = names(bbn.toString + "ll") + ".center"
+ val number_lr = names(bbn.toString + "lr") + ".center"
p.println("\t\t\\draw [style=blue] (" + number_ul + ") to (" + number_ur + ");" )
p.println("\t\t\\draw [style=blue] (" + number_ul + ") to (" + number_ll + ");" )
p.println("\t\t\\draw [style=blue] (" + number_ll + ") to (" + number_lr + ");" )
@@ -654,10 +735,9 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel
/* draw edges indicating nested !-boxes */
graph.bboxParent.get(bbn) match {
- case Some(bb_parent) => {
- val parent_number_ul = bb_parent.toString + "ul.center"
+ case Some(bb_parent) =>
+ val parent_number_ul = names(bb_parent.toString + "ul") + ".center"
p.println("\t\t\\draw [style=blue] (" + number_ul + ") to (" + parent_number_ul + ");" )
- }
case None =>
}
}
diff --git a/scala/src/main/scala/quanto/gui/graphview/VertexDisplayData.scala b/scala/src/main/scala/quanto/gui/graphview/VertexDisplayData.scala
index 8afa91a6..e310794b 100644
--- a/scala/src/main/scala/quanto/gui/graphview/VertexDisplayData.scala
+++ b/scala/src/main/scala/quanto/gui/graphview/VertexDisplayData.scala
@@ -1,13 +1,15 @@
package quanto.gui.graphview
-import java.awt.geom.{Rectangle2D, Ellipse2D, Point2D}
-import java.awt.{FontMetrics, Color, Shape}
+import java.awt.geom.{Ellipse2D, Point2D, Rectangle2D}
+import java.awt.{Color, FontMetrics, Shape}
+
import math._
import quanto.data._
import quanto.gui._
import quanto.core.data.TexConstants
+import quanto.util.UserOptions
-case class VDisplay(shape: Shape, color: Color, label: Option[LabelDisplayData]) {
+case class VDisplay(shape: Shape, borderWidth: Int, color: Color, label: Option[LabelDisplayData]) {
def pointHit(pt: Point2D) = {
val bnd = shape.getBounds2D
pt.getX >= bnd.getMinX - GraphView.VertexSelectionTolerence &&
@@ -52,13 +54,14 @@ trait VertexDisplayData { self: GraphView =>
protected def computeVertexDisplay() {
val trWireWidth = 0.707 * (trans scaleToScreen GraphView.WireRadius)
+ // Go through every vertex in the graph
for ((v,data) <- graph.vdata if !vertexDisplay.contains(v)) {
val (x,y) = trans toScreen data.coord
vertexDisplay(v) = data match {
case vertexData : NodeV =>
val style = vertexData.typeInfo.style
- val label = if (vertexData.hasAngle) vertexData.angle.toString else vertexData.value
+ val label = if (vertexData.hasValue) vertexData.phaseData.toString else vertexData.value
val text = if(zoom < GraphView.zoomCutOut && label != "") "~"
else TexConstants.translate(label)
/*vertexData.typeInfo.value.typ match {
@@ -66,19 +69,26 @@ trait VertexDisplayData { self: GraphView =>
case _ => ""
}*/
+ // Build the label
val fm = peer.getGraphics.getFontMetrics(GraphView.VertexLabelFont)
val labelDisplay = LabelDisplayData(
- text, (x,y), fm,
+ text, (x, y), fm,
vertexData.typeInfo.style.labelForegroundColor,
vertexData.typeInfo.style.labelBackgroundColor)
-
+ // Build the shape around the label
val shape = style.shape match {
case Theory.VertexShape.Rectangle =>
+ val buffer = trans.scaleToScreen(GraphView.NodeTextPadding)
+ val height = labelDisplay.bounds.getHeight + buffer
+ val widthFromLabel = labelDisplay.bounds.getWidth + buffer
+ // Default to square if no data, and stretch horizontally if needed
+ val width = max(widthFromLabel, height)
+
+ val x = labelDisplay.bounds.getMinX - (width - labelDisplay.bounds.getWidth) / 2.0
+ val y = labelDisplay.bounds.getMinY - (height - labelDisplay.bounds.getHeight) / 2.0
- new Rectangle2D.Double(
- labelDisplay.bounds.getMinX - 5.0, labelDisplay.bounds.getMinY - 3.0,
- labelDisplay.bounds.getWidth + 10.0, labelDisplay.bounds.getHeight + 6.0)
+ new Rectangle2D.Double(x, y, width, height)
case Theory.VertexShape.Circle =>
// radius should fit to label if required
val r = max(
@@ -86,30 +96,38 @@ trait VertexDisplayData { self: GraphView =>
trans.scaleToScreen(GraphView.NodeRadius)
)
- new Ellipse2D.Double(
- labelDisplay.bounds.getCenterX - r,
- labelDisplay.bounds.getCenterY -r,
- 2.0 * r, 2.0 * r)
+ val midX = labelDisplay.bounds.getCenterX - r
+ val midY = labelDisplay.bounds.getCenterY - r
+ new Ellipse2D.Double(midX, midY, 2.0 * r, 2.0 * r)
case _ => throw new Exception("Shape not supported yet")
}
- VDisplay(shape, style.fillColor, Some(labelDisplay))
+ VDisplay(shape, style.strokeWidth, style.fillColor, Some(labelDisplay))
case _: WireV =>
VDisplay(
new Rectangle2D.Double(
x - trWireWidth, y - trWireWidth,
2.0 * trWireWidth, 2.0 * trWireWidth),
- Color.GRAY,None)
+ 1,
+ Color.GRAY, None)
}
}
}
- protected def boundsForVertexSet(vset: Set[VName]) = {
+ protected def boundsForVertexSet(vset: Set[VName]): Rectangle2D.Double = {
var init = false
var ulx,uly,lrx,lry = 0.0
+ val em = trans.scaleToScreen(0.25)
+
vset.foreach { v =>
- val rect = vertexDisplay(v).shape.getBounds
+ // grow the bounding box until it snaps to the grid
+ val bds = vertexDisplay(v).shape.getBounds
+ val rx = Math.ceil(bds.width / (2.0 * em)) * em
+ val ry = Math.ceil(bds.height / (2.0 * em)) * em
+ val p = trans toScreen graph.vdata(v).coord
+ val rect = new Rectangle2D.Double(p._1 - rx, p._2 - ry, 2.0 * rx, 2.0 * ry)
+
if (init) {
ulx = min(ulx, rect.getX)
uly = min(uly, rect.getY)
@@ -125,9 +143,8 @@ trait VertexDisplayData { self: GraphView =>
}
val bounds = new Rectangle2D.Double(ulx, uly, lrx - ulx, lry - uly)
- val em = trans.scaleToScreen(0.1)
- val p = (bounds.getX - 3*em, bounds.getY - 3*em)
- val q = (bounds.getWidth + 6*em, bounds.getHeight + 6*em)
+ val p = (bounds.getX - em, bounds.getY - em)
+ val q = (bounds.getWidth + 2*em, bounds.getHeight + 2*em)
new Rectangle2D.Double(p._1, p._2, q._1, q._2)
}
diff --git a/scala/src/main/scala/quanto/gui/histview/HistView.scala b/scala/src/main/scala/quanto/gui/histview/HistView.scala
index d92b74db..aa7686c7 100644
--- a/scala/src/main/scala/quanto/gui/histview/HistView.scala
+++ b/scala/src/main/scala/quanto/gui/histview/HistView.scala
@@ -31,11 +31,10 @@ class HistView[A <: HistNode](data: TreeSeq[A]) extends ListView[(Seq[TreeSeq.De
}
renderer = new ListView.Renderer[(Seq[TreeSeq.Decoration[A]], A)] {
- def componentFor(list: ListView[_], isSelected: Boolean,
- focused: Boolean, a: (Seq[TreeSeq.Decoration[A]], A), index: Int): Component =
+ override def componentFor(list: ListView[_ <: (Seq[TreeSeq.Decoration[A]], A)], isSelected: Boolean, focused: Boolean, a: (Seq[TreeSeq.Decoration[A]], A), index: Int): Component =
{
- if (itemWidth == -1) computeItemWidth()
- new HistViewItem[A](a._1, a._2, isSelected, new Dimension(itemWidth,scaleInt(30)))
+ if (itemWidth == -1) computeItemWidth()
+ new HistViewItem[A](a._1, a._2, isSelected, new Dimension(itemWidth,scaleInt(30)))
}
}
@@ -48,6 +47,10 @@ class HistView[A <: HistNode](data: TreeSeq[A]) extends ListView[(Seq[TreeSeq.De
selectIndices()
}
+ def selectedIndex() : Int = {
+ peer.getSelectedIndex
+ }
+
def selectedNode: Option[A] =
if (selection.indices.isEmpty) None
else Some(treeData.toSeq(selection.indices.head))
diff --git a/scala/src/main/scala/quanto/layout/Constraints.scala b/scala/src/main/scala/quanto/layout/Constraints.scala
index b7a1ad37..98a71552 100644
--- a/scala/src/main/scala/quanto/layout/Constraints.scala
+++ b/scala/src/main/scala/quanto/layout/Constraints.scala
@@ -1,58 +1,25 @@
package quanto.layout
-import quanto.data._
-import quanto.data.Graph
-import quanto.data.VName
+import quanto.data.{Graph, VName}
class ConstraintException(msg: String) extends Exception(msg)
/**
- * A mixin for GraphLayouts which provides distance-based constraint functionality as in
- * [1] "Scalable, Versatile and Simple Constrained Graph Layout", Dwyer 2009
- */
+ * A mixin for GraphLayouts which provides distance-based constraint functionality as in
+ * [1] "Scalable, Versatile and Simple Constrained Graph Layout", Dwyer 2009
+ */
trait Constraints extends GraphLayout {
- def alpha: Double // cooling factor for soft constraints
+ val constraints = new ConstraintSeq
var constraintIterations = 10
var bug = false
- val constraints = new ConstraintSeq
+
+ def alpha: Double // cooling factor for soft constraints
override def initialize(g: Graph, randomCoords: Boolean = true) {
super.initialize(g, randomCoords)
constraints.clear()
}
- def isConstraintSatisfied (c: Constraint): Boolean = {
- val (v1x,v1y) = coord(c.v1)
- val (v2x,v2y) = coord(c.v2)
- val (v1v2x,v1v2y) = (v2x-v1x, v2y-v1y)
-
- val dir = c.direction match {
- case Some(d) => d
- case None => (0.0,0.0)
- }
-
- val length = projectLength((v1v2x,v1v2y),dir)
- if (c.order == 0) projectLength((v1v2x,v1v2y),dir) == c.length
- else if(c.order > 0) projectLength((v1v2x,v1v2y),dir) >= c.length
- else projectLength((v1v2x,v1v2y),dir) <= c.length
-
- }
- // project v1 on v2 // direction is unit vector
- def projectVector (v1 :(Double,Double), v2 : (Double,Double)) : (Double,Double)= {
- val length = vectorProduct(v1,v2)
- (length * v2._1, length * v2._2)
- }
-
- def projectLength (v1 :(Double,Double), v2 : (Double,Double)) : Double = {
- val length = vectorProduct(v1,v2)
- length
- }
-
- def vectorProduct (v1 : (Double,Double), v2 : (Double, Double)) : Double = {
- v1._1*v2._1+v1._2*v2._2
- }
-
-
def projectConstraints() {
var feasible = false // flag for all constraints satisfied
var maxLayer = constraints.currentLayer
@@ -65,78 +32,80 @@ trait Constraints extends GraphLayout {
iteration = 0
}
- for ((constraint,layer) <- constraints; if layer <= maxLayer) {
-
- val (p1,p2) = (coord(constraint.v1), coord(constraint.v2))
-
+ for ((constraint, layer) <- constraints; if layer <= maxLayer) {
+
+ val (p1, p2) = (coord(constraint.v1), coord(constraint.v2))
+
val shift = constraint.direction match {
case Some(dir) =>
- // val (dx,dy) = ((p2._1 - p1._1) * dir._1, (p2._2 - p1._2) * dir._2)
+ // val (dx,dy) = ((p2._1 - p1._1) * dir._1, (p2._2 - p1._2) * dir._2)
// the coordinates of vector projected on direction
- val p1p2 = (p2._1 - p1._1,p2._2 - p1._2)
- val (dx,dy) = projectVector((p1p2._1, p1p2._2), dir)
-
+ val p1p2 = (p2._1 - p1._1, p2._2 - p1._2)
+ val (dx, dy) = projectVector((p1p2._1, p1p2._2), dir)
+
// Add direction to the distance
// if the angle is acute angle then we need to do nothing
// if the angle is obtuse angle then we need to move on of the vertex.
- if (vectorProduct(p1p2,dir)<0) {
+ if (vectorProduct(p1p2, dir) < 0) {
// swap the two nodes
-// val temp = coord(constraint.v1);
-// setCoord(constraint.v1, coord(constraint.v2))
-// setCoord(constraint.v2,temp)
- // or we can just move p2 to another side
- val (nx,ny) = (2*p1._1 - p2._1, 2*p1._2 - p2._2);
- setCoord(constraint.v2, (nx,ny))
-//
+ // val temp = coord(constraint.v1);
+ // setCoord(constraint.v1, coord(constraint.v2))
+ // setCoord(constraint.v2,temp)
+ // or we can just move p2 to another side
+ val (nx, ny) = (2 * p1._1 - p2._1, 2 * p1._2 - p2._2)
+ setCoord(constraint.v2, (nx, ny))
+ //
}
-// // if it is 0 then they are perpendicular
+ // // if it is 0 then they are perpendicular
// this constraint cannot be projected.
- if ((vectorProduct(p1p2,dir)==0.0)&& !isConstraintSatisfied(constraint)){
- if (bug) println("constraint " + constraint)
- if (bug) println("vector p1 p2 is " + p1p2)
- if (bug) println("vector direction " + dir)
-
- val (v2x,v2y) = coord(constraint.v2)
- setCoord(constraint.v2, (v2x+dir._1/50,v2y+dir._2/50))
+ if ((vectorProduct(p1p2, dir) == 0.0) && !isConstraintSatisfied(constraint)) {
+ if (bug) println("constraint " + constraint)
+ if (bug) println("vector p1 p2 is " + p1p2)
+ if (bug) println("vector direction " + dir)
+
+ val (v2x, v2y) = coord(constraint.v2)
+ setCoord(constraint.v2, (v2x + dir._1 / 50, v2y + dir._2 / 50))
}
-
+
val ideal = (dir._1 * constraint.length, dir._2 * constraint.length)
(
- if ((constraint.order == 0 && dx != ideal._1) ||
- (constraint.order == -1 && dx > ideal._1) ||
- (constraint.order == 1 && dx < ideal._1))
+ if ((constraint.order == 0 && dx != ideal._1) ||
+ (constraint.order == -1 && dx > ideal._1) ||
+ (constraint.order == 1 && dx < ideal._1))
//if (!isConstraintSatisfied(constraint))
-
+
dx - ideal._1
- //ideal._1 - dx
+ //ideal._1 - dx
else 0,
- if ((constraint.order == 0 && dy != ideal._2) ||
- (constraint.order == -1 && dy > ideal._2) ||
- (constraint.order == 1 && dy < ideal._2))
+ if ((constraint.order == 0 && dy != ideal._2) ||
+ (constraint.order == -1 && dy > ideal._2) ||
+ (constraint.order == 1 && dy < ideal._2))
dy - ideal._2
- //ideal._2 - dy
+ //ideal._2 - dy
else 0
)
-
+
case None =>
- val (dx,dy) = (p2._1 - p1._1, p2._2 - p1._2)
- val length = math.sqrt(dx*dx + dy*dy)
- val dir = if (length != 0) (dx/length, dy/length) else (1.0,0.0)
+ val (dx, dy) = (p2._1 - p1._1, p2._2 - p1._2)
+ val length = math.sqrt(dx * dx + dy * dy)
+ val dir = if (length != 0) (dx / length, dy / length) else (1.0, 0.0)
- if ((constraint.order == 0 && length != constraint.length) ||
- (constraint.order == -1 && length > constraint.length) ||
- (constraint.order == 1 && length < constraint.length))
+ if ((constraint.order == 0 && length != constraint.length) ||
+ (constraint.order == -1 && length > constraint.length) ||
+ (constraint.order == 1 && length < constraint.length))
(dir._1 * (length - constraint.length), dir._2 * (length - constraint.length))
- else (0.0,0.0)
+ else (0.0, 0.0)
}
- val (shiftX, shiftY) = shift match { case (x,y) => if (constraint.soft) (x * alpha, y * alpha) else (x,y) }
+ val (shiftX, shiftY) = shift match {
+ case (x, y) => if (constraint.soft) (x * alpha, y * alpha) else (x, y)
+ }
if (shiftX != 0.0 || shiftY != 0.0) {
feasible = false
-
+
setCoord(constraint.v1, (p1._1 + (constraint.mv1 * shiftX), p1._2 + (constraint.mv1 * shiftY)))
setCoord(constraint.v2, (p2._1 - (constraint.mv2 * shiftX), p2._2 - (constraint.mv2 * shiftY)))
}
@@ -145,25 +114,62 @@ trait Constraints extends GraphLayout {
iteration += 1
}
-// if (feasible) {
-// println("feasible solution found after " + iteration + " iterations")
-// } else {
-// println("no feasible solution")
-// }
+ // if (feasible) {
+ // println("feasible solution found after " + iteration + " iterations")
+ // } else {
+ // println("no feasible solution")
+ // }
+ }
+
+ def isConstraintSatisfied(c: Constraint): Boolean = {
+ val (v1x, v1y) = coord(c.v1)
+ val (v2x, v2y) = coord(c.v2)
+ val (v1v2x, v1v2y) = (v2x - v1x, v2y - v1y)
+
+ val dir = c.direction match {
+ case Some(d) => d
+ case None => (0.0, 0.0)
+ }
+
+ if (c.order == 0) projectLength((v1v2x, v1v2y), dir) == c.length
+ else if (c.order > 0) projectLength((v1v2x, v1v2y), dir) >= c.length
+ else projectLength((v1v2x, v1v2y), dir) <= c.length
+
+ }
+
+ def projectLength(v1: (Double, Double), v2: (Double, Double)): Double = {
+ val length = vectorProduct(v1, v2)
+ length
+ }
+
+ // project v1 on v2 // direction is unit vector
+ def projectVector(v1: (Double, Double), v2: (Double, Double)): (Double, Double) = {
+ val length = vectorProduct(v1, v2)
+ (length * v2._1, length * v2._2)
+ }
+
+ def vectorProduct(v1: (Double, Double), v2: (Double, Double)): Double = {
+ v1._1 * v2._1 + v1._2 * v2._2
}
}
-class ConstraintSeq extends Iterable[(Constraint,Int)] {
+class ConstraintSeq extends Iterable[(Constraint, Int)] {
+ private val cs = collection.mutable.ListBuffer[() => Iterator[(Constraint, Int)]]()
private var _currentLayer = 0
- def currentLayer = _currentLayer
- private val cs = collection.mutable.ListBuffer[() => Iterator[(Constraint,Int)]]()
- def nextLayer() { _currentLayer += 1 }
- def clear() { cs.clear(); _currentLayer = 0 }
+ def currentLayer: Int = _currentLayer
+
+ def nextLayer() {
+ _currentLayer += 1
+ }
+
+ def clear() {
+ cs.clear(); _currentLayer = 0
+ }
def +=(c: Constraint) {
val layer = _currentLayer
- cs += (() => Iterator((c,layer)))
+ cs += (() => Iterator((c, layer)))
}
def ++=(cf: => Iterable[Constraint]) {
@@ -171,45 +177,51 @@ class ConstraintSeq extends Iterable[(Constraint,Int)] {
cs += (() => cf.iterator.zip(Iterator.continually(layer)))
}
- def iterator = cs.iterator.map(x => x()).foldLeft(Iterator[(Constraint,Int)]())(_ ++ _)
+ def iterator: Iterator[(Constraint, Int)] = cs.iterator.map(x => x()).foldLeft(Iterator[(Constraint, Int)]())(_ ++ _)
}
- // order: which relation with d > < or =
-case class Constraint(v1: VName, v2: VName, direction: Option[(Double,Double)], length: Double, w1: Double, w2: Double, order: Int, soft: Boolean) {
- lazy val mv1 = if (w1 + w2 != 0.0) w2 / (w1 + w2) else 0.5
- lazy val mv2 = 1.0 - mv1
-
+
+// order: which relation with d > < or =
+case class Constraint(v1: VName, v2: VName, direction: Option[(Double, Double)], length: Double, w1: Double, w2: Double, order: Int, soft: Boolean) {
+ lazy val mv1: Double = if (w1 + w2 != 0.0) w2 / (w1 + w2) else 0.5
+ lazy val mv2: Double = 1.0 - mv1
+
}
object Constraint {
- object distance {
- def from(v1: VName) = new DistanceFromExpr(v1)
- }
-
+
class DistanceFromExpr(v1: VName) {
- def to(v2: VName) = DistanceExpr(v1,v2)
+ def to(v2: VName) = DistanceExpr(v1, v2)
}
-
- case class DistanceExpr(v1: VName, v2: VName, direction: Option[(Double,Double)] = None, w1: Double = 1.0, w2: Double = 1.0) {
- def along(dir: (Double,Double)) = copy(direction = Some(dir))
-
- def weighted(w: (Double,Double)) = copy(w1 = w._1, w2 = w._2)
+
+ case class DistanceExpr(v1: VName, v2: VName, direction: Option[(Double, Double)] = None, w1: Double = 1.0, w2: Double = 1.0) {
+ def along(dir: (Double, Double)): DistanceExpr = copy(direction = Some(dir))
+
+ def weighted(w: (Double, Double)): DistanceExpr = copy(w1 = w._1, w2 = w._2)
+
+ def <=(len: Double) = Constraint(v1, v2, normalizedDir, len, w1, w2, -1, soft = false)
private def normalizedDir = direction.map {
- case (0.0,0.0) => throw new ConstraintException("'along' direction must be a non-zero vector")
- case (x,y) =>
- val length = math.sqrt(x*x + y*y)
+ case (0.0, 0.0) => throw new ConstraintException("'along' direction must be a non-zero vector")
+ case (x, y) =>
+ val length = math.sqrt(x * x + y * y)
// return the normal
(x / length, y / length)
}
-
-
- def <= (len: Double) = Constraint(v1,v2,normalizedDir,len,w1,w2,-1,soft=false)
- def ===(len: Double) = Constraint(v1,v2,normalizedDir,len,w1,w2,0,soft=false)
- def >= (len: Double) = Constraint(v1,v2,normalizedDir,len,w1,w2,1,soft=false)
-
- def ~<= (len: Double) = Constraint(v1,v2,normalizedDir,len,w1,w2,-1,soft=true)
- def ~== (len: Double) = Constraint(v1,v2,normalizedDir,len,w1,w2,0,soft=true)
- def ~>= (len: Double) = Constraint(v1,v2,normalizedDir,len,w1,w2,1,soft=true)
+
+ def ===(len: Double) = Constraint(v1, v2, normalizedDir, len, w1, w2, 0, soft = false)
+
+ def >=(len: Double) = Constraint(v1, v2, normalizedDir, len, w1, w2, 1, soft = false)
+
+ def ~<=(len: Double) = Constraint(v1, v2, normalizedDir, len, w1, w2, -1, soft = true)
+
+ def ~==(len: Double) = Constraint(v1, v2, normalizedDir, len, w1, w2, 0, soft = true)
+
+ def ~>=(len: Double) = Constraint(v1, v2, normalizedDir, len, w1, w2, 1, soft = true)
}
+
+ object distance {
+ def from(v1: VName) = new DistanceFromExpr(v1)
+ }
+
}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/layout/DeriveLayout.scala b/scala/src/main/scala/quanto/layout/DeriveLayout.scala
index cfb393cc..863f1b52 100644
--- a/scala/src/main/scala/quanto/layout/DeriveLayout.scala
+++ b/scala/src/main/scala/quanto/layout/DeriveLayout.scala
@@ -5,21 +5,21 @@ import quanto.layout.constraint._
class DeriveLayout {
- def layout(derivation: Derivation) = {
+ def layout(derivation: Derivation): Derivation = {
var steps = Map[DSName, DStep]()
val layoutProc = new ForceLayout with Clusters
layoutProc.alpha0 = 0.05
layoutProc.keepCentered = false
- while (steps.size < derivation.steps.size) derivation.steps.foreach { case (sname, step) =>
+ while (steps.size < derivation.steps.size) derivation.steps.foreach { case (stepName, step) =>
// try to pull a parent that has already been processed, or root if step has no parent
- val parentOpt = derivation.parentMap.get(sname) match {
+ val parentOpt = derivation.parentMap.get(stepName) match {
case Some(p) => steps.get(p).map(_.graph)
case None => Some(derivation.root)
}
- parentOpt.map { parent: Graph =>
+ parentOpt.foreach { parent: Graph =>
var g = step.graph
layoutProc.clearLockedVertices()
@@ -27,7 +27,9 @@ class DeriveLayout {
// are fresh for parent)
for (v <- parent.verts) {
if (g.verts.contains(v)) {
- g = g.updateVData(v) { _.withCoord(parent.vdata(v).coord) }
+ g = g.updateVData(v) {
+ _.withCoord(parent.vdata(v).coord)
+ }
layoutProc.lockVertex(v)
}
}
@@ -35,11 +37,11 @@ class DeriveLayout {
layoutProc.initialize(g, randomCoords = false)
// relax a bit to layout new coords
- for (i <- 1 to 10) layoutProc.step()
+ for (_ <- 1 to 10) layoutProc.step()
layoutProc.updateGraph()
// layout the graph, and add the step
- steps += sname -> step.copy(graph = layoutProc.graph)
+ steps += stepName -> step.copy(graph = layoutProc.graph)
}
}
diff --git a/scala/src/main/scala/quanto/layout/DotLayout.scala b/scala/src/main/scala/quanto/layout/DotLayout.scala
index bf55319c..7058b96e 100644
--- a/scala/src/main/scala/quanto/layout/DotLayout.scala
+++ b/scala/src/main/scala/quanto/layout/DotLayout.scala
@@ -1,31 +1,55 @@
package quanto.layout
-import sys.process._
-import quanto.data._
import java.io._
-import quanto.data.VName
+
+import quanto.data.{VName, _}
+
+import scala.sys.process._
class DotLayout extends GraphLayout {
var dotString = ""
- var dotProcess: Process = null
+ var dotProcess: Process = _
var dotIn: BufferedReader = _
var dotOut: BufferedWriter = _
+ protected def compute() {
+ val (vid, dotStr) = generateDot(graph)
+ dotString = dotStr
+ var xMax = 0.0
+ var yMax = 0.0
+
+ val Graph = """graph \d+ (\d+(\.\d+)?) (\d+(\.\d+)?)""".r
+ val Node = """node (\d+) (\d+(\.\d+)?) (\d+(\.\d+)?) .*""".r
+
+ var coordMap = Map[Int, (Double, Double)]()
+
+ ("dot -Tplain" #< new ByteArrayInputStream(dotString.getBytes("UTF-8"))).lines_!.foreach {
+ case Graph(xbd, _, ybd, _) =>
+ xMax = xbd.toDouble
+ yMax = ybd.toDouble
+ case Node(id, x, _, y, _) =>
+ coordMap += id.toInt -> ((x.toDouble - 0.5 * xMax) / 10, (0.5 * yMax - y.toDouble) / 10)
+ case _ => ()
+ }
+
+ vid.foreach { case (v, id) => setCoord(v, coordMap(id)) }
+ }
+
private def generateDot(graph: Graph) = {
val sb = new StringBuilder
- var vid = Map[VName,Int]()
+ var vid = Map[VName, Int]()
var i = 0
sb ++= "digraph {\n"
- graph.vdata.foreach { case (v,d) =>
+ graph.vdata.foreach { case (v, _) =>
sb ++= " " + i + " [width=14,height=14]\n"
vid += v -> i
i += 1
}
- graph.edata.foreach { case (e,d) =>
+ graph.edata.foreach { case (e, d) =>
sb ++= " %d %s %d\n".format(
vid(graph.source(e)),
if (d.isDirected) "->" else "--",
@@ -34,7 +58,7 @@ class DotLayout extends GraphLayout {
i = 0
- graph.bbdata.foreach { case(bb,d) =>
+ graph.bbdata.foreach { case (bb, _) =>
sb ++= " subgraph \"cluster_" + i + "\" { \n"
graph.contents(bb).foreach { v => sb ++= " " + vid(v) + "\n" }
sb ++= " }\n"
@@ -43,31 +67,6 @@ class DotLayout extends GraphLayout {
sb ++= "}\n"
- (vid,sb.toString)
- }
-
- protected def compute() {
- val (vid,dotStr) = generateDot(graph)
- dotString = dotStr
- var xMax = 0.0
- var yMax = 0.0
-
- val Graph = """graph \d+ (\d+(\.\d+)?) (\d+(\.\d+)?)""".r
- val Node = """node (\d+) (\d+(\.\d+)?) (\d+(\.\d+)?) .*""".r
-
- var coordMap = Map[Int,(Double,Double)]()
-
- ("dot -Tplain" #< new ByteArrayInputStream(dotString.getBytes("UTF-8"))).lines_!.foreach { line =>
- line match {
- case Graph(xbd,_, ybd,_) =>
- xMax = xbd.toDouble
- yMax = ybd.toDouble
- case Node(id, x,_, y,_) =>
- coordMap += id.toInt -> ((x.toDouble - 0.5 * xMax) / 10, (0.5 * yMax - y.toDouble) / 10)
- case _ => ()
- }
- }
-
- vid.foreach { case (v,id) => setCoord(v, coordMap(id)) }
+ (vid, sb.toString)
}
}
diff --git a/scala/src/main/scala/quanto/layout/ForceLayout.scala b/scala/src/main/scala/quanto/layout/ForceLayout.scala
index 113b250b..ea779474 100644
--- a/scala/src/main/scala/quanto/layout/ForceLayout.scala
+++ b/scala/src/main/scala/quanto/layout/ForceLayout.scala
@@ -1,50 +1,39 @@
package quanto.layout
-import quanto.util._
import quanto.data._
-import math.{min,max,abs}
import quanto.layout.constraint._
+import quanto.util._
+
+import scala.math.abs
/**
- * Force-directed layout algorithm. Parts are based on:
- * [1] force.js from the D3 javascript library (see d3js.org)
- * [2] "Scalable, Versatile and Simple Constrained Graph Layout", Dwyer 2009
- * [3] "Efficient and High Quality Force-Directed Graph Drawing", Hu 2006
- */
+ * Force-directed layout algorithm. Parts are based on:
+ * [1] force.js from the D3 javascript library (see d3js.org)
+ * [2] "Scalable, Versatile and Simple Constrained Graph Layout", Dwyer 2009
+ * [3] "Efficient and High Quality Force-Directed Graph Drawing", Hu 2006
+ */
class ForceLayout extends GraphLayout with Constraints {
// repulsive force between vertices
//var charge: VName => Double = (v => if (graph.vdata(v).isWireVertex) 3.0 else 5.0)
var nodeCharge = 5.0
-
- def charge(v:VName) = if (graph.vdata(v).isWireVertex && nodeCharge != 0.0) 1.0 else nodeCharge
-
// spring strength on edges
var strength = 2.5
-
// preferred length of edge
var edgeLength = 0.5
-
// (small) attractive force toward center of bounds
var gravity = 1.0
-
// Barnes-Hut approximation constant. Higher = coarser
var theta = 0.8
-
// used in Verlet integration
var friction = 0.9
-
// initial step size
var alpha0: Double = 1.0
-
// increase or decrease step size by this amount
var alphaAdjust = 0.7
-
// maximum iterations
var maxIterations = 3000
-
// re-center graph after each iteration
var keepCentered = true
-
// step size alpha is re-computed on the fly using trust region heuristic
var alpha: Double = _
var prevEnergy: Double = _
@@ -52,6 +41,8 @@ class ForceLayout extends GraphLayout with Constraints {
var progress: Int = _
var iteration = 0
+ def charge(v: VName): Double = if (graph.vdata(v).isWireVertex && nodeCharge != 0.0) 1.0 else nodeCharge
+
override def initialize(g: Graph, randomCoords: Boolean = true) {
super.initialize(g, randomCoords)
alpha = alpha0
@@ -60,24 +51,24 @@ class ForceLayout extends GraphLayout with Constraints {
}
// compute the equivalent point charge for every region of space in the quad tree
- def computeCharges(tr: QuadTree[(Option[VName],Double)]): QuadTree[(Option[VName],Double)] = tr match {
+ def computeCharges(tr: QuadTree[(Option[VName], Double)]): QuadTree[(Option[VName], Double)] = tr match {
case leaf: QuadLeaf[_] => leaf
case _: QuadNode[_] =>
- val node = tr.asInstanceOf[QuadNode[(Option[VName],Double)]]
+ val node = tr.asInstanceOf[QuadNode[(Option[VName], Double)]]
- val (v,nCharge) = node.value.getOrElse((None,0.0))
+ val (v, nCharge) = node.value.getOrElse((None, 0.0))
val nw = computeCharges(node.nw)
val ne = computeCharges(node.ne)
val sw = computeCharges(node.sw)
val se = computeCharges(node.se)
- val (p,totalCharge) = Iterator(nw,ne,sw,se).foldLeft((node.p._1 * nCharge, node.p._2 * nCharge), nCharge) {
- case ((pSum,cSum), child) =>
- val (_,c) = child.value.getOrElse((None,0.0))
+ val (p, totalCharge) = Iterator(nw, ne, sw, se).foldLeft((node.p._1 * nCharge, node.p._2 * nCharge), nCharge) {
+ case ((pSum, cSum), child) =>
+ val (_, c) = child.value.getOrElse((None, 0.0))
((child.p._1 * c + pSum._1, child.p._2 * c + pSum._2), c + cSum)
}
- val center = if (totalCharge != 0.0) (p._1 / totalCharge, p._2 / totalCharge) else (0.0,0.0)
- QuadNode(node.x1,node.y1,node.x2,node.y2,Some((v,totalCharge)),center,nw,ne,sw,se)
+ val center = if (totalCharge != 0.0) (p._1 / totalCharge, p._2 / totalCharge) else (0.0, 0.0)
+ QuadNode(node.x1, node.y1, node.x2, node.y2, Some((v, totalCharge)), center, nw, ne, sw, se)
}
// take an unconstrained step in the direction of steepest descent in energy
@@ -98,13 +89,28 @@ class ForceLayout extends GraphLayout with Constraints {
val oldCoords = coords
+ // shake overlapping elements slightly
+ val vertexList: List[VName] = graph.verts.toList
+ for (i <- vertexList; j <- vertexList if i != j) {
+ val p1 = coord(i)
+ val p2 = coord(j)
+ val shake = 0.2
+ if (p1 == p2) {
+ setCoord(i, (
+ p1._1 + shake * Math.random(),
+ p1._2 + shake * Math.random()
+ ))
+ }
+ }
+
+
// apply spring forces
for (e <- graph.edges) {
val sp = coord(graph.source(e))
val tp = coord(graph.target(e))
- val (dx,dy) = if (this.isInstanceOf[Ranking] || this.isInstanceOf[IRanking]) (2.0*(tp._1 - sp._1), tp._2 - sp._2)
- else (tp._1 - sp._1, tp._2 - sp._2)
- val d = math.sqrt(dx*dx + dy*dy)
+ val (dx, dy) = if (this.isInstanceOf[Ranking] || this.isInstanceOf[IRanking]) (2.0 * (tp._1 - sp._1), tp._2 - sp._2)
+ else (tp._1 - sp._1, tp._2 - sp._2)
+ val d = math.sqrt(dx * dx + dy * dy)
if (d != 0.0) {
val displacement = d - edgeLength
val k = (alpha * strength * displacement) / d
@@ -127,18 +133,18 @@ class ForceLayout extends GraphLayout with Constraints {
}
// compute charges
- val quad = computeCharges(QuadTree(graph.verts.toSeq.map { v => (coord(v), (Some(v),charge(v))) }))
+ val quad = computeCharges(QuadTree(graph.verts.toSeq.map { v => (coord(v), (Some(v), charge(v))) }))
// apply charge forces
for (v <- graph.verts if !lockedVertices.contains(v)) {
var p = coord(v)
quad.visit { nd =>
nd.value match {
- case Some((optV,nodeCharge1)) =>
- val (dx1,dy1) = (nd.p._1 - p._1, nd.p._2 - p._2)
+ case Some((optV, nodeCharge1)) =>
+ val (dx1, dy1) = (nd.p._1 - p._1, nd.p._2 - p._2)
val dx = if (abs(dx1) < 0.01) 0.01 else dx1
val dy = if (abs(dy1) < 0.01) 0.01 else dy1
- val d2 = dx*dx + dy*dy
+ val d2 = dx * dx + dy * dy
if (d2 == 0.0) false
else {
@@ -147,7 +153,7 @@ class ForceLayout extends GraphLayout with Constraints {
energy += (charge(v) + nodeCharge1) / d2
val kx = alpha * nodeCharge1 / d2
val ky = if (this.isInstanceOf[Ranking] || this.isInstanceOf[IRanking]) kx * 1.5 else kx
- p = (p._1 - dx*kx, p._2 - dy*ky)
+ p = (p._1 - dx * kx, p._2 - dy * ky)
true
} else {
// if !B-H, but there is a (different) vertex here, act with the point charge
@@ -155,7 +161,7 @@ class ForceLayout extends GraphLayout with Constraints {
case Some(v1) if v1 != v =>
energy += (charge(v) + charge(v1)) / d2
val k = alpha * charge(v1) / d2
- p = (p._1 - dx*k, p._2 - dy*k)
+ p = (p._1 - dx * k, p._2 - dy * k)
case _ =>
}
@@ -171,24 +177,24 @@ class ForceLayout extends GraphLayout with Constraints {
// position verlet integration
for (v <- graph.verts) {
- val (px,py) = oldCoords(v)
- val (x,y) = coord(v)
- setCoord(v, (x - ((px-x) * friction), y - ((py-y)*friction)))
+ val (px, py) = oldCoords(v)
+ val (x, y) = coord(v)
+ setCoord(v, (x - ((px - x) * friction), y - ((py - y) * friction)))
}
}
def recenter() {
- val (sumCoordx,sumCoordy) = graph.verts.foldLeft(0.0,0.0)((pos,name)
- => (pos._1+coord(name)._1,pos._2+coord(name)._2))
- val (centerX,centerY) = (sumCoordx/graph.verts.size,sumCoordy/graph.verts.size)
-
-// if(abs(centerX)> 5|| abs(centerY)> 5){
- graph.verts.foreach(name=>{
- val (px,py) = coord(name)
- setCoord(name, (px-centerX, py-centerY ))
- })
- // }
+ val (sumCoordX, sumCoordY) = graph.verts.foldLeft(0.0, 0.0)((pos, name)
+ => (pos._1 + coord(name)._1, pos._2 + coord(name)._2))
+ val (centerX, centerY) = (sumCoordX / graph.verts.size, sumCoordY / graph.verts.size)
+
+ // if(abs(centerX)> 5|| abs(centerY)> 5){
+ graph.verts.foreach(name => {
+ val (px, py) = coord(name)
+ setCoord(name, (px - centerX, py - centerY))
+ })
+ // }
}
def step() {
diff --git a/scala/src/main/scala/quanto/layout/GraphLayout.scala b/scala/src/main/scala/quanto/layout/GraphLayout.scala
index 887ba130..a8f0a1b6 100644
--- a/scala/src/main/scala/quanto/layout/GraphLayout.scala
+++ b/scala/src/main/scala/quanto/layout/GraphLayout.scala
@@ -1,46 +1,59 @@
package quanto.layout
-import quanto.data.{VName, Graph}
+import quanto.data.{Graph, VName}
+
+import scala.collection.mutable
import scala.util.Random
class LayoutUninitializedException extends Exception("Layout data read before layout() called")
abstract class GraphLayout {
- private var _graph: Graph = null
- def graph = _graph
+ protected val lockedVertices: mutable.Set[VName] = collection.mutable.Set()
+ private val _coords = collection.mutable.Map[VName, (Double, Double)]()
+ private var _graph: Graph = _
- protected val lockedVertices = collection.mutable.Set[VName]()
+ def lockVertex(v: VName) {
+ lockedVertices += v
+ }
- def lockVertex(v: VName) { lockedVertices += v }
- def clearLockedVertices() { lockedVertices.clear() }
+ def clearLockedVertices() {
+ lockedVertices.clear()
+ }
- private val _coords = collection.mutable.Map[VName,(Double,Double)]()
- def setCoord(v: VName, p:(Double,Double)) {
+ def setCoord(v: VName, p: (Double, Double)) {
if (!lockedVertices.contains(v)) _coords(v) = p
}
+
def coord(v: VName) = _coords(v)
- def coords = _coords.clone()
- // override to compute layout data
- protected def compute()
+ def coords: mutable.Map[VName, (Double, Double)] = _coords.clone()
+
+ def layout(g: Graph, randomCoords: Boolean = true): Graph = {
+ initialize(g, randomCoords)
+ compute()
+ updateGraph()
+
+ graph
+ }
def initialize(g: Graph, randomCoords: Boolean = true) {
_graph = g
_coords.clear()
val r = new Random(0xdeadbeef)
- graph.vdata.foreach { case (v,d) =>
- _coords(v) = if (randomCoords) (0.5 - r.nextDouble(), 0.5 - r.nextDouble()) else d.coord }
+ graph.vdata.foreach { case (v, d) =>
+ _coords(v) = if (randomCoords) (0.5 - r.nextDouble(), 0.5 - r.nextDouble()) else d.coord
+ }
}
def updateGraph() {
- _graph = _coords.foldLeft(graph) { case(g,(v,c)) => g.updateVData(v) { _.withCoord(c) } }
+ _graph = _coords.foldLeft(graph) { case (g, (v, c)) => g.updateVData(v) {
+ _.withCoord(c)
+ }
+ }
}
- def layout(g: Graph, randomCoords: Boolean = true): Graph = {
- initialize(g, randomCoords)
- compute()
- updateGraph()
+ def graph: Graph = _graph
- graph
- }
+ // override to compute layout data
+ protected def compute()
}
diff --git a/scala/src/main/scala/quanto/layout/GraphTransform.scala b/scala/src/main/scala/quanto/layout/GraphTransform.scala
index c2a7ee0f..4f5d9b06 100644
--- a/scala/src/main/scala/quanto/layout/GraphTransform.scala
+++ b/scala/src/main/scala/quanto/layout/GraphTransform.scala
@@ -1,28 +1,27 @@
package quanto.layout
-import quanto.data.Graph
-import quanto.data.VData
+import quanto.data.{Graph, VData}
+
+import scala.collection.mutable.ListBuffer
// transfer each bbox to a node
-class GraphTransform (g:Graph){
-
- //val numOfBBox = g.bboxes.size
- val bs = collection.mutable.ListBuffer[() => Iterator[VData]]()
-
- def transform {
- g.bboxes.foreach(bb =>
- for(ver <- g.contents(bb)){
- val target = g.succVerts(ver)
-
-
-
- })
-
-
-
- }
+class GraphTransform(g: Graph) {
+
+ //val numOfBBox = g.bboxes.size
+ val bs: ListBuffer[() => Iterator[VData]] = collection.mutable.ListBuffer[() => Iterator[VData]]()
- def restore{
+ def transform() {
+ g.bboxes.foreach(bb =>
+ for (ver <- g.contents(bb)) {
+ val target = g.succVerts(ver)
+
+
+ })
- }
+
+ }
+
+ def restore() {
+
+ }
}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/layout/constraint/Clusters.scala b/scala/src/main/scala/quanto/layout/constraint/Clusters.scala
index fa27b001..f922ddf3 100644
--- a/scala/src/main/scala/quanto/layout/constraint/Clusters.scala
+++ b/scala/src/main/scala/quanto/layout/constraint/Clusters.scala
@@ -1,37 +1,37 @@
package quanto.layout.constraint
-import quanto.data._
-import collection.mutable.ListBuffer
+import quanto.data.{VName, _}
import quanto.layout._
+import quanto.util.Geometry._
+import quanto.util.{QuadTree, _}
+
+import scala.collection.mutable.ListBuffer
+import scala.math.abs
-import math.{min,max,abs}
-import quanto.util.QuadTree
-import quanto.data.VName
-import quanto.util._
-import Geometry._
-import Names._
/**
- * Constraints to force nodes not in the given clusters outside of the bounding box
- */
+ * Constraints to force nodes not in the given clusters outside of the bounding box
+ */
trait Clusters extends Constraints {
+
import Constraint.distance
+
var debug = false
- var clusters = ListBuffer[Set[VName]]()
+ var clusters: ListBuffer[Set[VName]] = ListBuffer()
var clusterPadding = 0.5
var clusterRadiusPerVertex = 2.0
-
+
// right code?
- def inRect(v:VName, r : ((Double,Double),(Double,Double))) : Boolean = {
- val (px,py) = coord(v)
- val ((lx,ly),(ux,uy)) = r
+ def inRect(v: VName, r: ((Double, Double), (Double, Double))): Boolean = {
+ val (px, py) = coord(v)
+ val ((lx, ly), (ux, uy)) = r
(px > lx) && (px < ux) && (py > ly) && (py < uy)
}
-
+
override def initialize(g: Graph, randomCoords: Boolean = true) {
super.initialize(g, randomCoords)
constraints.nextLayer()
//println("Clusters at layer " + constraints.currentLayer)
-
+
// take each bbox as a cluster
clusters.clear
g.bboxes.foreach(bb => clusters += g.contents(bb))
@@ -40,43 +40,41 @@ trait Clusters extends Constraints {
val coordTree = QuadTree(graph.verts.toSeq.map { v => (coord(v), v) })
- clusters.foldLeft(List[Constraint]()) { (constraints,cluster) =>
- bounds(cluster.map(coord(_))) match {
+ clusters.foldLeft(List[Constraint]()) { (constraints, cluster) =>
+ bounds(cluster.map(coord)) match {
case Some(rect) =>
-
- val bbox = new RichRect(rect)
- val (lb,ub) = bbox.pad(clusterPadding)
+
+ val bbox = new RichRect(rect)
+ val (lb, ub) = bbox.pad(clusterPadding)
//if (debug) println("the bound bbox boundary is " +(lb,ub))
-
- // centre of the bbox
- val (cx,cy) = bbox.center
- //if (debug) (println("the center is " + bbox.center))
- val (wth,hgt) = bbox.size
+
val clusterSize = cluster.size
var cons = constraints
for (v1 <- cluster; v2 <- cluster; if v1 != v2) {
- cons ::= { (distance from v1 to v2) <= (clusterRadiusPerVertex * (clusterSize)-1) } // * 0.95
-// cons ::= { (distance from v1 to v2 along (1.0,0.0)) <= (clusterRadiusPerVertex * (clusterSize)) } // * 0.95
-// cons ::= { (distance from v1 to v2 along (0.0,1.0)) <= (clusterRadiusPerVertex * (clusterSize)) } // * 0.95
+ cons ::= {
+ (distance from v1 to v2) <= (clusterRadiusPerVertex * clusterSize - 1)
+ } // * 0.95
+ // cons ::= { (distance from v1 to v2 along (1.0,0.0)) <= (clusterRadiusPerVertex * (clusterSize)) } // * 0.95
+ // cons ::= { (distance from v1 to v2 along (0.0,1.0)) <= (clusterRadiusPerVertex * (clusterSize)) } // * 0.95
}
-
+
//val verts = g.verts.filter(vname => inRect(vname,(lb,ub)))
- val verts = coordTree.query(lb._1, lb._2, ub._1, ub._2)
+ val vertices = coordTree.query(lb._1, lb._2, ub._1, ub._2)
// quadtree is not right
//println("quadTree is " + verts)
-// if (debug) println ("v3 's coordinate is " + coord("v3"))
-// if (debug) println ("b0 's coordinate is " + coord("b0"))
-// if (debug) println ("v1 's coordinate is " + coord("v1"))
- // println("v0 " + "is " + coord("v0"))
- // println("box " + (lb,ub))
+ // if (debug) println ("v3 's coordinate is " + coord("v3"))
+ // if (debug) println ("b0 's coordinate is " + coord("b0"))
+ // if (debug) println ("v1 's coordinate is " + coord("v1"))
+ // println("v0 " + "is " + coord("v0"))
+ // println("box " + (lb,ub))
//if (debug) println ("all vertices in the bbox\n\t" + verts)
- for (v <- verts; if !cluster.contains(v)) {
+ for (v <- vertices; if !cluster.contains(v)) {
val vc = coord(v)
-
- // println(v + " is in the bbox but not belongs it");
+
+ // println(v + " is in the bbox but not belongs it");
// work out the most efficient shift
/////
@@ -86,69 +84,71 @@ trait Clusters extends Constraints {
*/
// need a better way to shift.
-
+
// move to the side that has more force to pull it
- val (shiftDirX, shiftDirY) = (g.succVerts(v)++g.predVerts(v)).
- foldLeft(0.0,0.0)((pos,name) => (pos._1+coord(name)._1-vc._1,pos._2+coord(name)._2-vc._2))
+ //val (shiftDirX, shiftDirY) = (g.succVerts(v) ++ g.predVerts(v)).
+ // foldLeft(0.0, 0.0)((pos, name) => (pos._1 + coord(name)._1 - vc._1, pos._2 + coord(name)._2 - vc._2))
//println("node " + v + " should move direction" + (shiftDirX, shiftDirY))
-// val xShift = if (vc._1 >= lb._1 && vc._1 <= ub._1) {
-// if(shiftDirX > 0) ub._1 - vc._1 else lb._1 - vc._1
-//
-// } else 0
-//
-// val yShift = if (vc._2 >= lb._2 && vc._2 <= ub._2) {
-// if(shiftDirY > 0){
-// ub._2 - vc._2
-// }
-// else{
-// lb._2 - vc._2
-// }
-// } else 0
-//
-// if (xShift != 0 && yShift != 0){
-// for (v1 <- cluster) {
-// val v1c = coord(v1)
-// val (len,dir) = if (abs(shiftDirX) > abs(shiftDirY)) {
-// if (xShift < 0.0) (abs((vc._1 - v1c._1) + (xShift)), (1.0,0.0))
-// else (abs((vc._1 - v1c._1) + xShift), (-1.0,0.0))
-// } else {
-// if (yShift < 0.0) (abs((vc._2 - v1c._2) + (yShift)), (0.0,1.0))
-// else (abs((vc._2 - v1c._2) + (yShift)), (0.0,-1.0))
-// }
-//
-// cons ::= { (distance from v to v1 along dir weighted (1.0, 1.0)) >= len } //cluster.size.toDouble+1
-// }
-// }
-
+ // val xShift = if (vc._1 >= lb._1 && vc._1 <= ub._1) {
+ // if(shiftDirX > 0) ub._1 - vc._1 else lb._1 - vc._1
+ //
+ // } else 0
+ //
+ // val yShift = if (vc._2 >= lb._2 && vc._2 <= ub._2) {
+ // if(shiftDirY > 0){
+ // ub._2 - vc._2
+ // }
+ // else{
+ // lb._2 - vc._2
+ // }
+ // } else 0
+ //
+ // if (xShift != 0 && yShift != 0){
+ // for (v1 <- cluster) {
+ // val v1c = coord(v1)
+ // val (len,dir) = if (abs(shiftDirX) > abs(shiftDirY)) {
+ // if (xShift < 0.0) (abs((vc._1 - v1c._1) + (xShift)), (1.0,0.0))
+ // else (abs((vc._1 - v1c._1) + xShift), (-1.0,0.0))
+ // } else {
+ // if (yShift < 0.0) (abs((vc._2 - v1c._2) + (yShift)), (0.0,1.0))
+ // else (abs((vc._2 - v1c._2) + (yShift)), (0.0,-1.0))
+ // }
+ //
+ // cons ::= { (distance from v to v1 along dir weighted (1.0, 1.0)) >= len } //cluster.size.toDouble+1
+ // }
+ // }
+
// not sure if we ned the if condition
// since quadtree will make sure the condition is true
val xShift = if (vc._1 >= lb._1 && vc._1 <= ub._1) {
- val left = lb._1 - vc._1
+ val left = lb._1 - vc._1
val right = ub._1 - vc._1
if (abs(left) < right) left else right
} else 0
val yShift = if (vc._2 >= lb._2 && vc._2 <= ub._2) {
- val down = lb._2 - vc._2
+ val down = lb._2 - vc._2
val up = ub._2 - vc._2
if (-down < up) down else up
} else 0
-
- //questions about following code
+
+ //questions about following code
if (xShift != 0 && yShift != 0) {
// val soft = g.isBBoxed(v)
for (v1 <- cluster) {
val v1c = coord(v1)
- val (len,dir) = if (abs(xShift) < abs(yShift)) {
- if (xShift < 0.0) (abs((vc._1 - v1c._1) + (xShift)), (1.0,0.0))
- else (abs((vc._1 - v1c._1) + xShift), (-1.0,0.0))
+ val (len, dir) = if (abs(xShift) < abs(yShift)) {
+ if (xShift < 0.0) (abs((vc._1 - v1c._1) + xShift), (1.0, 0.0))
+ else (abs((vc._1 - v1c._1) + xShift), (-1.0, 0.0))
} else {
- if (yShift < 0.0) (abs((vc._2 - v1c._2) + (yShift)), (0.0,1.0))
- else (abs((vc._2 - v1c._2) + (yShift)), (0.0,-1.0))
+ if (yShift < 0.0) (abs((vc._2 - v1c._2) + yShift), (0.0, 1.0))
+ else (abs((vc._2 - v1c._2) + yShift), (0.0, -1.0))
}
-
- cons ::= { (distance from v to v1 along dir weighted (1.0, 1.0)) >= len } //cluster.size.toDouble+1
-// cons ::= { (distance from v to v1 along dir weighted (1.0, cluster.size.toDouble+1)) >= len } //
+
+ cons ::= {
+ (distance from v to v1 along dir weighted(1.0, 1.0)) >= len
+ } //cluster.size.toDouble+1
+ // cons ::= { (distance from v to v1 along dir weighted (1.0, cluster.size.toDouble+1)) >= len } //
}
}
}
@@ -156,7 +156,7 @@ trait Clusters extends Constraints {
cons
case None => constraints
}
-
+
}
}
}
diff --git a/scala/src/main/scala/quanto/layout/constraint/NoTweaks.scala b/scala/src/main/scala/quanto/layout/constraint/NoTweaks.scala
index 49467101..c8cddec5 100644
--- a/scala/src/main/scala/quanto/layout/constraint/NoTweaks.scala
+++ b/scala/src/main/scala/quanto/layout/constraint/NoTweaks.scala
@@ -1,53 +1,49 @@
package quanto.layout.constraint
import quanto.data._
-import collection.mutable.ListBuffer
import quanto.layout._
-import math.{min,max,abs}
-import quanto.util.QuadTree
-import quanto.data.VName
-import quanto.util._
-import Geometry._
-import Names._
//import quanto.layout.distance
trait NoTweaks extends Constraints {
- val radius = 0.2
- override def initialize(g: Graph, randomCoords: Boolean = true) {
- super.initialize(g, randomCoords)
- constraints.nextLayer()
- println("Clusters at layer " + constraints.currentLayer)
-
- // for edges we get rid of the
-
- for(v <- g.verts){
- val edges = g.inEdges(v) ++ g.outEdges(v)
- val compEdges = g.edges -- edges
-
- // for each vertex, the distance to each compEdges
- // is greater than its radius.
-
- for (e <- compEdges ){
- val s = graph.source(e)
- val t = graph.target(e)
- val (sx,sy) = coord(s)
- val (tx,ty) = coord(t)
- val (pqx,pqy) = (sx-tx, sy-ty)
- var par = (0.0, 0.0)
- if (pqy == 0){
- par = (0.0,1.0)
- }
- else if (pqx == 0){
- par = (1.0, 0.0)
- }
- else {
- par = (-pqy/pqx,1.0)
- }
-
- constraints += { (Constraint.distance from v to s along par) >= radius }
- }
-
- }
- }
+ val radius = 0.2
+
+ override def initialize(g: Graph, randomCoords: Boolean = true) {
+ super.initialize(g, randomCoords)
+ constraints.nextLayer()
+ println("Clusters at layer " + constraints.currentLayer)
+
+ // for edges we get rid of the
+
+ for (v <- g.verts) {
+ val edges = g.inEdges(v) ++ g.outEdges(v)
+ val compEdges = g.edges -- edges
+
+ // for each vertex, the distance to each compEdges
+ // is greater than its radius.
+
+ for (e <- compEdges) {
+ val s = graph.source(e)
+ val t = graph.target(e)
+ val (sx, sy) = coord(s)
+ val (tx, ty) = coord(t)
+ val (pqx, pqy) = (sx - tx, sy - ty)
+ var par = (0.0, 0.0)
+ if (pqy == 0) {
+ par = (0.0, 1.0)
+ }
+ else if (pqx == 0) {
+ par = (1.0, 0.0)
+ }
+ else {
+ par = (-pqy / pqx, 1.0)
+ }
+
+ constraints += {
+ (Constraint.distance from v to s along par) >= radius
+ }
+ }
+
+ }
+ }
}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/layout/constraint/Ranking.scala b/scala/src/main/scala/quanto/layout/constraint/Ranking.scala
index aaa669f8..0e7d0b54 100644
--- a/scala/src/main/scala/quanto/layout/constraint/Ranking.scala
+++ b/scala/src/main/scala/quanto/layout/constraint/Ranking.scala
@@ -4,10 +4,12 @@ import quanto.data._
import quanto.layout._
/**
- * Mix in to add ranking constraints to initialization
- */
+ * Mix in to add ranking constraints to initialization
+ */
trait Ranking extends Constraints {
+
import Constraint.distance
+
var rankSep: Double = 1.0
override def initialize(g: Graph, randomCoords: Boolean = true) {
@@ -25,7 +27,9 @@ trait Ranking extends Constraints {
}
trait IRanking extends Constraints {
+
import Constraint.distance
+
var rankSep: Double = 1.0
override def initialize(g: Graph, randomCoords: Boolean = true) {
@@ -34,6 +38,8 @@ trait IRanking extends Constraints {
val dag = g.dagCopy
for (e <- dag.edges)
- constraints += {(distance from dag.source(e) to dag.target(e) along (0,-1)) ~>= rankSep}
+ constraints += {
+ (distance from dag.source(e) to dag.target(e) along(0, -1)) ~>= rankSep
+ }
}
}
diff --git a/scala/src/main/scala/quanto/layout/constraint/VerticalBoundary.scala b/scala/src/main/scala/quanto/layout/constraint/VerticalBoundary.scala
index e0476a69..4f1a29c3 100644
--- a/scala/src/main/scala/quanto/layout/constraint/VerticalBoundary.scala
+++ b/scala/src/main/scala/quanto/layout/constraint/VerticalBoundary.scala
@@ -1,50 +1,55 @@
package quanto.layout.constraint
-import quanto.layout._
import quanto.data._
+import quanto.layout._
trait VerticalBoundary extends Constraints {
+
import Constraint.distance
-
+
override def initialize(g: Graph, randomCoords: Boolean = true) {
super.initialize(g, randomCoords)
constraints.nextLayer()
//println("VerticalBoundary at layer " + constraints.currentLayer)
-
-// for (bnd <- Iterator(g.inputs,g.outputs); if !bnd.isEmpty) {
-// val it = bnd.iterator
-// val v1 = it.next()
-// for (v2 <- it) {
-//// if (g.isBBoxed(v1) && !g.isBBoxed(v2)) {
-////
-//// }
-// coord(v1)
-// if (g.isBBoxed(v1) || g.isBBoxed(v2))
-// constraints += { (distance from v1 to v2 along (0.0,1.0)) ~== 0.0 }
-// else
-// constraints += { (distance from v1 to v2 along (0.0,1.0)) === 0.0 }
-// //constraints += { (distance from v1 to v2 along (0.0,1.0)) ~== 0.0 }
-// // v1 = v2
-// }
-// }
+
+ // for (bnd <- Iterator(g.inputs,g.outputs); if !bnd.isEmpty) {
+ // val it = bnd.iterator
+ // val v1 = it.next()
+ // for (v2 <- it) {
+ //// if (g.isBBoxed(v1) && !g.isBBoxed(v2)) {
+ ////
+ //// }
+ // coord(v1)
+ // if (g.isBBoxed(v1) || g.isBBoxed(v2))
+ // constraints += { (distance from v1 to v2 along (0.0,1.0)) ~== 0.0 }
+ // else
+ // constraints += { (distance from v1 to v2 along (0.0,1.0)) === 0.0 }
+ // //constraints += { (distance from v1 to v2 along (0.0,1.0)) ~== 0.0 }
+ // // v1 = v2
+ // }
+ // }
g.edges.foreach { e =>
- val (s,t) = (g.source(e), g.target(e))
- if (g.isInput(s) || g.isOutput(t)) {
- if (g.isBBoxed(s) || g.isBBoxed(t)){
- constraints += { (distance from s to t along (1.0,0.0)) ~== 0.0 }
+ val (s, t) = (g.source(e), g.target(e))
+ if (g.isInputWire(s) || g.isOutputWire(t)) {
+ if (g.isBBoxed(s) || g.isBBoxed(t)) {
+ constraints += {
+ (distance from s to t along(1.0, 0.0)) ~== 0.0
+ }
}
else
- constraints += { (distance from s to t along (1.0,0.0)) === 0.0 }
+ constraints += {
+ (distance from s to t along(1.0, 0.0)) === 0.0
+ }
}
}
}
-
-// override def projectConstraints(){
-//
-//
-// }
-
+
+ // override def projectConstraints(){
+ //
+ //
+ // }
+
}
diff --git a/scala/src/main/scala/quanto/rewrite/AngleExpressionMatcher.scala b/scala/src/main/scala/quanto/rewrite/AngleExpressionMatcher.scala
deleted file mode 100644
index bad80bb1..00000000
--- a/scala/src/main/scala/quanto/rewrite/AngleExpressionMatcher.scala
+++ /dev/null
@@ -1,42 +0,0 @@
-package quanto.rewrite
-
-import quanto.data._
-import quanto.util._
-
-
-class AngleExpressionMatcher(pVars : Vector[String], tVars : Vector[String], mat : RationalMatrix) {
- val pvSet = pVars.toSet
- val tvSet = tVars.toSet
-
- def addMatch(pExpr : AngleExpression, tExpr : AngleExpression) : Option[AngleExpressionMatcher] = {
- val pVars1 = pVars ++ (pExpr.vars -- pvSet).toVector
- val tVars1 = tVars ++ (tExpr.vars -- tvSet).toVector
-
- val r1 = pVars1.map { v => pExpr.coeffs.getOrElse(v, Rational(0)) }
- val r2 = tVars1.map { v => tExpr.coeffs.getOrElse(v, Rational(0)) }
- val row = r1 ++ r2 :+ (tExpr.const - pExpr.const)
-
- mat.padTo(pVars1.length, tVars1.length).gaussUpdate(row).map { mat1 =>
- new AngleExpressionMatcher(pVars1, tVars1, mat1)
- }
- }
-
-
- def toMap : Map[String, AngleExpression] =
- if (mat.numCols == 0) Map()
- else mat.rows.foldLeft(Map[String,AngleExpression]()) { (mp, row) =>
- val p = RationalMatrix.findPivot(row)
- var coeffs = Map[String,Rational]()
- for (i <- p+1 until mat.line)
- if (row(i) != Rational(0)) coeffs = coeffs + (pVars(i) -> row(i) * -1)
- for (i <- mat.line to row.length - 2)
- if (row(i) != Rational(0)) coeffs = coeffs + (tVars(i-mat.line) -> row(i))
-
- mp + (pVars(p) -> AngleExpression(row.last, coeffs))
- }
-}
-
-object AngleExpressionMatcher {
- def apply(pVars : Vector[String], tVars : Vector[String]) =
- new AngleExpressionMatcher(pVars, tVars, new RationalMatrix(Vector(), pVars.length))
-}
diff --git a/scala/src/main/scala/quanto/rewrite/CompositeExpressionMatcher.scala b/scala/src/main/scala/quanto/rewrite/CompositeExpressionMatcher.scala
new file mode 100644
index 00000000..9a14b80c
--- /dev/null
+++ b/scala/src/main/scala/quanto/rewrite/CompositeExpressionMatcher.scala
@@ -0,0 +1,86 @@
+package quanto.rewrite
+
+
+import quanto.data.Theory.ValueType
+import quanto.data.{CompositeExpression, PhaseExpression}
+import quanto.util.{Rational, RationalMatrix}
+
+class PhaseExpressionMatcher(pVars: Vector[String], tVars: Vector[String], mat: RationalMatrix) {
+ val pvSet: Set[String] = pVars.toSet
+ val tvSet: Set[String] = tVars.toSet
+
+ def addMatch(patternExpression: PhaseExpression, targetExpression: PhaseExpression): Option[PhaseExpressionMatcher] = {
+ val patternVars1 = pVars ++ (patternExpression.vars.toSet -- pvSet).toVector
+ val targetVars1 = tVars ++ (targetExpression.vars.toSet -- tvSet).toVector
+ val zero = Rational(0)
+ val r1 = patternVars1.map { v => patternExpression.coefficients.getOrElse(v, zero) }
+ val r2 = targetVars1.map { v => targetExpression.coefficients.getOrElse(v, zero) }
+ val row = r1 ++ r2 :+ (targetExpression.constant - patternExpression.constant)
+
+ mat.padTo(patternVars1.length, targetVars1.length).gaussUpdate(row).map { mat1 =>
+ new PhaseExpressionMatcher(patternVars1, targetVars1, mat1)
+ }
+ }
+
+ def toMap(valueType: ValueType): Map[String, PhaseExpression] = toMap.mapValues(_.as(valueType))
+
+ def toMap: Map[String, PhaseExpression] =
+ if (mat.numCols == 0) Map()
+ else mat.rows.foldLeft(Map[String, PhaseExpression]()) { (mp, row) =>
+ val p = RationalMatrix.findPivot(row)
+ var coefficients = Map[String, Rational]()
+ for (i <- p + 1 until mat.line)
+ if (row(i) != Rational(0)) coefficients = coefficients + (pVars(i) -> row(i) * -1)
+ for (i <- mat.line to row.length - 2)
+ if (row(i) != Rational(0)) coefficients = coefficients + (tVars(i - mat.line) -> row(i))
+
+ mp + (pVars(p) -> PhaseExpression(row.last, coefficients, ValueType.Rational))
+ }
+}
+
+object PhaseExpressionMatcher {
+ def empty: PhaseExpressionMatcher = PhaseExpressionMatcher(Vector(), Vector(), None)
+
+ def apply(pVars: Vector[String], tVars: Vector[String], modulus: Option[Int]) =
+ new PhaseExpressionMatcher(pVars, tVars, new RationalMatrix(Vector(), pVars.length, modulus))
+}
+
+
+class CompositeExpressionMatcher(matchers: Map[ValueType, Option[PhaseExpressionMatcher]]) {
+
+ // Add matchings by component
+ def addMatch(pExpr: CompositeExpression, tExpr: CompositeExpression): Option[CompositeExpressionMatcher] = {
+ val valuePairs = pExpr.values.zip(tExpr.values)
+ val typeValueTriples = pExpr.valueTypes.zip(valuePairs)
+ // Loop through each valueType and apply relevant matches
+ // If any of the matches fail then the whole thing needs to fail
+ typeValueTriples.foldLeft(Some(this): Option[CompositeExpressionMatcher])((om, t) => {
+ if (om.nonEmpty) {
+ om.get.addPhaseMatch(t._1, t._2._1.asInstanceOf[PhaseExpression], t._2._2.asInstanceOf[PhaseExpression])
+ } else {
+ None
+ }
+ })
+ }
+
+ // Add a single matching to a specific valueType
+ def addPhaseMatch(valueType: ValueType, pExpr: PhaseExpression, tExpr: PhaseExpression): Option[CompositeExpressionMatcher] = {
+ val updatedSingletonMatcher = matchers.
+ getOrElse(valueType, Some(PhaseExpressionMatcher(Vector(), Vector(), pExpr.modulus))).
+ get.addMatch(pExpr, tExpr)
+ if (updatedSingletonMatcher.nonEmpty) {
+ Some(new CompositeExpressionMatcher(matchers + (valueType -> updatedSingletonMatcher)))
+ } else {
+ None
+ }
+ }
+
+ def toMap: Map[ValueType, Map[String, PhaseExpression]] =
+ matchers.keySet.map(
+ valueType => valueType -> matchers(valueType).getOrElse(PhaseExpressionMatcher.empty).toMap.mapValues(_.as(valueType))
+ ).toMap
+}
+
+object CompositeExpressionMatcher {
+ def apply(): CompositeExpressionMatcher = new CompositeExpressionMatcher(Map())
+}
diff --git a/scala/src/main/scala/quanto/rewrite/Match.scala b/scala/src/main/scala/quanto/rewrite/Match.scala
index a660de38..c9fa44e8 100644
--- a/scala/src/main/scala/quanto/rewrite/Match.scala
+++ b/scala/src/main/scala/quanto/rewrite/Match.scala
@@ -1,7 +1,11 @@
package quanto.rewrite
+import quanto.data.Theory.ValueType
import quanto.data._
+import quanto.util.UserAlerts
-class MatchException(msg: String) extends Exception(msg)
+class MatchException(msg: String) extends Exception(msg) {
+ UserAlerts.alert(s"Match Exception: $msg", UserAlerts.Elevation.WARNING)
+}
case class Match(pattern0: Graph, // the pattern without bbox operations
pattern: Graph,
@@ -9,7 +13,8 @@ case class Match(pattern0: Graph, // the pattern without bbox operations
map: GraphMap = GraphMap(),
bareWireMap: Map[VName, Vector[VName]] = Map(),
bbops: List[BBOp] = List(),
- subst: Map[String,AngleExpression] = Map()) {
+ subst: Map[ValueType, Map[String,PhaseExpression]] = Map() // matching individual phases inside composite phases
+ ) {
def addVertex(vPair: (VName, VName)): Match = {
copy(map = map addVertex vPair)
@@ -41,16 +46,16 @@ case class Match(pattern0: Graph, // the pattern without bbox operations
bareWireMap.headOption match {
case Some((tw, pw +: pws)) =>
val (target1, (newW1, newW2, newE)) = target.expandWire(tw)
- val emap1 = map.e + (pattern.outEdges(pw).head -> newE)
+ val edgeMap1 = map.e + (pattern.outEdges(pw).head -> newE)
- var vmap1 = map.v
+ var vertexMap1 = map.v
for (pw1 <- map.v.codf(tw)) {
- if (pattern.isInput(pw1)) vmap1 = vmap1 + (pw1 -> newW2)
+ if (pattern.isInputWire(pw1)) vertexMap1 = vertexMap1 + (pw1 -> newW2)
}
- vmap1 = vmap1 + (pw -> newW1) + (pattern.succVerts(pw).head -> newW2)
+ vertexMap1 = vertexMap1 + (pw -> newW1) + (pattern.succVerts(pw).head -> newW2)
copy(
- map = map.copy(v = vmap1, e = emap1),
+ map = map.copy(v = vertexMap1, e = edgeMap1),
target = target1,
bareWireMap = bareWireMap + (tw -> pws)
).normalize
diff --git a/scala/src/main/scala/quanto/rewrite/MatchState.scala b/scala/src/main/scala/quanto/rewrite/MatchState.scala
index 4eb56ad3..ecae3120 100644
--- a/scala/src/main/scala/quanto/rewrite/MatchState.scala
+++ b/scala/src/main/scala/quanto/rewrite/MatchState.scala
@@ -1,30 +1,30 @@
package quanto.rewrite
+
import quanto.data._
import scala.annotation.tailrec
case class MatchState(
- m: Match, // the match being built
- tVerts: Set[VName], // restriction of the range of the match
- angleMatcher: AngleExpressionMatcher, // state of matched angle data
- pNodes: Set[VName] = Set(), // nodes with partially-mapped neighbourhood
- psNodes: Set[VName] = Set(), // same, but scheduled for completion
- sBBox: Option[BBName] = None, // a bbox scheduled for matching
- candidateNodes: Option[Set[VName]] = None, // nodes to try matching in the target
- candidateEdges: Option[Set[EName]] = None, // edges to try matching in the target
- candidateWires: Option[Set[(VName,Int)]] = None, // wire-vertices to try matching bare wires on
- candidateBBoxes: Option[Set[BBName]] = None, // bboxes to try matching in the target
- bboxOrbits: PFun[VName, VName] = PFun(), // for smashing redundant matches
- nextState: Option[MatchState] = None // next state to try after search terminates
+ m: Match, // the match being built
+ targetVertices: Set[VName], // restriction of the range of the match
+ expressionMatcher: CompositeExpressionMatcher, // state of matched angle data
+ pNodes: Set[VName] = Set(), // nodes with partially-mapped neighbourhood
+ psNodes: Set[VName] = Set(), // same, but scheduled for completion
+ sBBox: Option[BBName] = None, // a bbox scheduled for matching
+ candidateNodes: Option[Set[VName]] = None, // nodes to try matching in the target
+ candidateEdges: Option[Set[EName]] = None, // edges to try matching in the target
+ candidateWires: Option[Set[(VName, Int)]] = None, // wire-vertices to try matching bare wires on
+ candidateBBoxes: Option[Set[BBName]] = None, // bboxes to try matching in the target
+ bboxOrbits: PFun[VName, VName] = PFun(), // for smashing redundant matches
+ nextState: Option[MatchState] = None // next state to try after search terminates
) {
- val uVerts: Set[VName] = m.pattern.verts.filter(v => bboxesMatched(v) && !m.map.v.domSet.contains(v))
- val uCircles: Set[VName] = uVerts.filter(m.pattern.isCircle)
- lazy val uBareWires: Set[VName] = uVerts.filter(m.pattern.representsBareWire)
- lazy val uNodes: Set[VName] = uVerts.filter { v => !m.pattern.vdata(v).isWireVertex }
- lazy val uWires: Set[VName] = uVerts.filter { v => m.pattern.vdata(v).isWireVertex }
- val uBBoxes: Set[BBName] = m.pattern.bboxes.filter(bb => parentBBoxesMatched(bb) && !m.map.bb.domSet.contains(bb))
-
+ lazy val uBareWires: Set[VName] = unmatchedVertices.filter(m.pattern.representsBareWire)
+ lazy val uNodes: Set[VName] = unmatchedVertices.filter { v => !m.pattern.vdata(v).isWireVertex }
+ lazy val uWires: Set[VName] = unmatchedVertices.filter { v => m.pattern.vdata(v).isWireVertex }
+ val unmatchedVertices: Set[VName] = m.pattern.verts.filter(v => bboxesMatched(v) && !m.map.v.domSet.contains(v))
+ val uCircles: Set[VName] = unmatchedVertices.filter(m.pattern.isCircle)
+ val uBBoxes: Set[BBName] = m.pattern.bboxes.filter(bb => parentBBoxesMatched(bb) && !m.map.bb.domSet.contains(bb))
/**
@@ -34,28 +34,36 @@ case class MatchState(
*/
@tailrec
final def nextMatch(): Option[(Match, Option[MatchState])] = {
+// if (m.map.e.domSet.size == 4) {
+// print("got 4 edges")
+// }
+
// if unmatched circles are found in the pattern, match them first
if (uCircles.nonEmpty) {
val pc = uCircles.min
- tVerts.find(v => m.target.isCircle(v) && reflectsBBoxes(pc, v)) match {
- case None => nextState match { case Some(next) => next.nextMatch(); case None => None }
+ targetVertices.find(v => m.target.isCircle(v) && reflectsBBoxes(pc, v)) match {
+ case None => nextState match {
+ case Some(next) => next.nextMatch();
+ case None => None
+ }
case Some(tc) =>
val pce = m.pattern.inEdges(pc).min
val tce = m.target.inEdges(tc).min
- copy(m = m.addEdge(pce -> tce, pc -> tc), tVerts = tVerts - tc).nextMatch()
+ copy(m = m.addEdge(pce -> tce, pc -> tc), targetVertices = targetVertices - tc).nextMatch()
}
- // if there is a scheduled node, try to match its neighbourhood in every possible way
+ // if there is a scheduled node, try to match its neighbourhood in every possible way
} else if (psNodes.nonEmpty) {
val np = psNodes.min
- if (pVertexMayBeCompleted(np)) {
+ val cmp = pVertexMayBeCompleted(np)
+ if (cmp) {
val nt = m.map.v(np)
// get the next matchable edge in the neighbourhood of np
val uEdges = m.pattern.adjacentEdges(np).filter(e =>
!m.map.e.domSet.contains(e) &&
- bboxesMatched(m.pattern.edgeGetOtherVertex(e, np))
+ bboxesMatched(m.pattern.edgeGetOtherVertex(e, np))
)
val epOpt = if (uEdges.isEmpty) None else Some(uEdges.min)
epOpt match {
@@ -66,7 +74,7 @@ case class MatchState(
case None =>
copy(candidateEdges = Some(m.target.adjacentEdges(nt).filter { e =>
!m.map.e.codSet.contains(e) &&
- tVerts.contains(m.target.edgeGetOtherVertex(e, nt))
+ targetVertices.contains(m.target.edgeGetOtherVertex(e, nt))
})).nextMatch()
case Some(candidateEdges1) =>
if (candidateEdges1.isEmpty) {
@@ -78,8 +86,10 @@ case class MatchState(
val et = candidateEdges1.min
val next = copy(candidateEdges = Some(candidateEdges1 - et))
matchNewWire(np, ep, nt, et) match {
- case Some(ms1) => ms1.copy(candidateEdges = None, nextState = Some(next)).nextMatch()
- case None => next.nextMatch()
+ case Some(ms1) =>
+ ms1.copy(candidateEdges = None, nextState = Some(next)).nextMatch()
+ case None =>
+ next.nextMatch()
}
}
}
@@ -100,13 +110,13 @@ case class MatchState(
}
- // if there are no scheduled nodes, pick a new unmatched node in the pattern, match it in every possible way
- // and schedule its neighbourhood for matching
+ // if there are no scheduled nodes, pick a new unmatched node in the pattern, match it in every possible way
+ // and schedule its neighbourhood for matching
} else if (uNodes.nonEmpty) {
val np = uNodes.min
candidateNodes match {
case None =>
- copy(candidateNodes = Some(tVerts.filter { v =>
+ copy(candidateNodes = Some(targetVertices.filter { v =>
!m.target.vdata(v).isWireVertex
})).nextMatch()
case Some(candidateNodes1) =>
@@ -125,18 +135,18 @@ case class MatchState(
}
}
- // if there are bare wires remaining, add them in all possible ways
+ // if there are bare wires remaining, add them in all possible ways
} else if (uBareWires.nonEmpty) {
val pbw = uBareWires.min
candidateWires match {
case None =>
// pull all the candidate locations for matching this bare wire. If a wire already has n bare wires matched
// on it, this is n+1 possible locations.
- val cwires =
- for (v <- tVerts if m.target.representsWire(v) && reflectsBBoxes(pbw, v);
+ val cWires =
+ for (v <- targetVertices if m.target.representsWire(v) && reflectsBBoxes(pbw, v);
i <- 0 to m.bareWireMap.get(v).map(_.length).getOrElse(0))
- yield (v,i)
- copy(candidateWires = Some(cwires.toSet)).nextMatch()
+ yield (v, i)
+ copy(candidateWires = Some(cWires)).nextMatch()
case Some(candidateWires1) =>
if (candidateWires1.isEmpty) {
nextState match {
@@ -156,8 +166,8 @@ case class MatchState(
}
}
- // if all matchable verts are matched, pull the first top-level bbox and try to kill, expand, or copy+match
- } else if (uBBoxes.nonEmpty && uVerts.isEmpty) {
+ // if all matchable verts are matched, pull the first top-level bbox and try to kill, expand, or copy+match
+ } else if (uBBoxes.nonEmpty && unmatchedVertices.isEmpty) {
val pbb = uBBoxes.min
candidateBBoxes match {
case Some(candidateBBoxes1) =>
@@ -169,11 +179,11 @@ case class MatchState(
} else {
val tbb = candidateBBoxes1.min
val next = copy(candidateBBoxes = Some(candidateBBoxes1 - tbb))
-// val schedule =
-// m.pattern.adjacentVerts(m.pattern.contents(pbb)).filter { v =>
-// !m.pattern.vdata(v).isWireVertex &&
-// bboxesMatched(v)
-// }
+ // val schedule =
+ // m.pattern.adjacentVerts(m.pattern.contents(pbb)).filter { v =>
+ // !m.pattern.vdata(v).isWireVertex &&
+ // bboxesMatched(v)
+ // }
copy(
m = m.addBBox(pbb -> tbb),
psNodes = pNodes, // re-schedule everything
@@ -192,7 +202,7 @@ case class MatchState(
// if a !-box is wild, drop it, rather than copy/expand
val (dropGraph, dropOp) = m.pattern.dropBBox(pbb)
copy(m = m.copy(pattern = dropGraph, bbops = dropOp :: m.bbops),
- nextState = Some(killState))
+ nextState = Some(killState))
} else { // only expand/copy non-wild !-boxes, or we'll get infinite matchings
val minV = m.pattern.contents(pbb).min
val (expandGraph, expandOp) = m.pattern.expandBBox(pbb)
@@ -206,7 +216,7 @@ case class MatchState(
else
copy(
m = m.copy(pattern = copyGraph, bbops = copyOp :: m.bbops),
- candidateBBoxes = Some(m.target.bboxes.filter { tbb => reflectsParentBBoxes(pbb, tbb)}),
+ candidateBBoxes = Some(m.target.bboxes.filter { tbb => reflectsParentBBoxes(pbb, tbb) }),
bboxOrbits = bboxOrbits + (copyOp.mp.v(minV) -> minV),
nextState = Some(killState)
)
@@ -231,12 +241,12 @@ case class MatchState(
expState.nextMatch()
}
- // if there is nothing left to do, check if the match is complete and return it if so. If not, continue
- // the search from nextState
+ // if there is nothing left to do, check if the match is complete and return it if so. If not, continue
+ // the search from nextState
} else {
if (pNodes.isEmpty && m.isTotal) {
if (MatchState.countMatches) MatchState.matchCounter += 1
- val ms = copy(m = m.copy(subst = angleMatcher.toMap))
+ val ms = copy(m = m.copy(subst = expressionMatcher.toMap))
Some((ms.m, nextState))
} else {
nextState match {
@@ -256,33 +266,37 @@ case class MatchState(
bboxOrbits.codf(rep).forall { pv1 =>
m.map.v.get(pv1) match {
case None => true
- case Some(tv1) => (pv <= pv1) == (tv <= tv1)
+ case Some(tv1) => if (pv <= pv1) { tv <= tv1 } else { true }
}
}
case None => true
}
- def pVertexMayBeCompleted(vp: VName) = {
- val allVerts = m.pattern.adjacentVerts(vp)
- val concreteVerts = allVerts.filter(v => m.pattern.bboxesContaining(v).isEmpty)
- val hasBBox = allVerts.size > concreteVerts.size
+ def pVertexMayBeCompleted(vp: VName): Boolean = {
+ val allEdges = m.pattern.adjacentEdges(vp)
+ val concreteEdges = allEdges.filter(e => m.pattern.bboxesContaining(m.pattern.edgeGetOtherVertex(e, vp)).isEmpty)
+ val hasBBox = allEdges.size > concreteEdges.size
val tArity = m.target.arity(m.map.v(vp))
-
- concreteVerts.size == tArity || (hasBBox && concreteVerts.size <= tArity)
+ concreteEdges.size == tArity || (hasBBox && concreteEdges.size <= tArity)
+ //val allVertices = m.pattern.adjacentVerts(vp)
+ //val concreteVertices = allVertices.filter(v => m.pattern.bboxesContaining(v).isEmpty)
+ //val hasBBox = allVertices.size > concreteVertices.size
+ //val tArity = m.target.arity(m.map.v(vp))
+ //concreteVertices.size == tArity || (hasBBox && concreteVertices.size <= tArity)
}
// TODO: we may only need to check the closest parent, not all parents
- def reflectsBBoxes(vp: VName, vt: VName) =
+ def reflectsBBoxes(vp: VName, vt: VName): Boolean =
m.map.bb.directImage(m.pattern.bboxesContaining(vp)) == m.target.bboxesContaining(vt)
- def bboxesMatched(vp: VName) =
+ def bboxesMatched(vp: VName): Boolean =
m.pattern.bboxesContaining(vp).forall(m.map.bb.domSet.contains)
- def reflectsParentBBoxes(bbp: BBName, bbt: BBName) =
+ def reflectsParentBBoxes(bbp: BBName, bbt: BBName): Boolean =
m.map.bb.directImage(m.pattern.bboxParents(bbp)) == m.target.bboxParents(bbt)
- def parentBBoxesMatched(bbp: BBName) = {
+ def parentBBoxesMatched(bbp: BBName): Boolean = {
m.pattern.bboxParents(bbp).forall(m.map.bb.domSet.contains)
}
@@ -301,22 +315,23 @@ case class MatchState(
(m.pattern.vdata(np), m.target.vdata(nt)) match {
case (pd: NodeV, td: NodeV) =>
if (pd.typ == td.typ) {
- if (pd.hasAngle)
- angleMatcher.addMatch(pd.angle, td.angle).map { angleMatcher1 =>
- copy(
- m = m.addVertex(np -> nt),
- pNodes = pNodes + np,
- psNodes = psNodes + np,
- tVerts = tVerts - nt,
- angleMatcher = angleMatcher1
- )
+ if (pd.hasValue)
+ expressionMatcher.addMatch(pd.phaseData, td.phaseData).map {
+ angleMatcher1 =>
+ copy(
+ m = m.addVertex(np -> nt),
+ pNodes = pNodes + np,
+ psNodes = psNodes + np,
+ targetVertices = targetVertices - nt,
+ expressionMatcher = angleMatcher1
+ )
}
else if (pd.value == td.value)
Some(copy(
m = m.addVertex(np -> nt),
pNodes = pNodes + np,
psNodes = psNodes + np,
- tVerts = tVerts - nt
+ targetVertices = targetVertices - nt
))
else None
} else None
@@ -324,24 +339,28 @@ case class MatchState(
}
/**
- * Try to recursively add wire to matching, starting with the given head
- * vertex and edge. Return NONE on failure.
- *
- * (ported from the ML function tryadd_wire)
- *
- * @param vp already-matched vertex
- * @param ep unmatched edge incident to vp (other end must be in P, Uw or Un)
- * @param vt target of vp
- * @param et unmatched edge incident to vt
- */
- def matchNewWire(vp:VName, ep:EName, vt:VName, et:EName): Option[MatchState] = {
- val pdir = m.pattern.edata(ep).isDirected
- val tdir = m.target.edata(et).isDirected
+ * Try to recursively add wire to matching, starting with the given head
+ * vertex and edge. Return NONE on failure.
+ *
+ * (ported from the ML function tryadd_wire)
+ *
+ * @param vp already-matched vertex
+ * @param ep unmatched edge incident to vp (other end must be in P, Uw or Un)
+ * @param vt target of vp
+ * @param et unmatched edge incident to vt
+ */
+ def matchNewWire(vp: VName, ep: EName, vt: VName, et: EName): Option[MatchState] = {
+ val pDir = m.pattern.edata(ep).isDirected
+ val tDir = m.target.edata(et).isDirected
val pOutEdge = m.pattern.source(ep) == vp
val tOutEdge = m.target.source(et) == vt
+ val pType = m.pattern.edata(ep).typ
+ val tType = m.target.edata(et).typ
// match directedness and, if the edge is directed, direction
- if ((pdir && tdir && pOutEdge == tOutEdge) || (!pdir && !tdir)) {
+ // also match on type
+ // TODO: Match data on edges
+ if (((pDir && tDir && pOutEdge == tOutEdge) || (!pDir && !tDir)) && (pType == tType)) {
val newVp = m.pattern.edgeGetOtherVertex(ep, vp)
val newVt = m.target.edgeGetOtherVertex(et, vt)
@@ -349,14 +368,14 @@ case class MatchState(
if (m.map.v contains (newVp -> newVt))
Some(copy(psNodes = psNodes + newVp, m = m.addEdge(ep -> et)))
else None
- } else if (tVerts contains newVt) {
+ } else if (targetVertices contains newVt) {
(m.pattern.vdata(newVp), m.target.vdata(newVt)) match {
case (_: WireV, _: WireV) if reflectsBBoxes(newVp, newVt) && matchIsMonotone(newVp, newVt) =>
(m.pattern.wireVertexGetOtherEdge(newVp, ep), m.target.wireVertexGetOtherEdge(newVt, et)) match {
case (Some(newEp), Some(newEt)) =>
copy(
m = m.addEdge(ep -> et, newVp -> newVt),
- tVerts = tVerts - newVt
+ targetVertices = targetVertices - newVt
).matchNewWire(newVp, newEp, newVt, newEt)
case (Some(_), None) => None
case (None, _) =>
@@ -376,14 +395,13 @@ case class MatchState(
}
object MatchState {
+ // use !-box orbits to ignore redundant matches
+ var smashSymmetries = true
// for testing e.g. laziness in a single thread
private var matchCounter = 0
private var countMatches = false
- // use !-box orbits to ignore redundant matches
- var smashSymmetries = true
-
- def startCountingMatches() = {
+ def startCountingMatches(): Unit = {
matchCounter = 0
countMatches = true
}
diff --git a/scala/src/main/scala/quanto/rewrite/Matcher.scala b/scala/src/main/scala/quanto/rewrite/Matcher.scala
index f1e462c0..80b0dbeb 100644
--- a/scala/src/main/scala/quanto/rewrite/Matcher.scala
+++ b/scala/src/main/scala/quanto/rewrite/Matcher.scala
@@ -1,20 +1,9 @@
package quanto.rewrite
-import quanto.data._
-import scala.annotation.tailrec
+import quanto.data._
object Matcher {
- private def matchMain(ms: MatchState): Stream[Match] =
- ms.nextMatch() match {
- case Some((m1,Some(next))) => m1 #:: matchMain(next)
- case Some((m1,None)) => Stream(m1)
- case None => Stream()
- }
-
def initialise(pat: Graph, tgt: Graph, restrictTo: Set[VName]): MatchState = {
- // TODO: new free vars should be fresh w.r.t. vars in target
- val patVars = pat.freeVars.toVector
- val tgtVars = tgt.freeVars.toVector
val patN = pat.normalise
val tgtN = tgt.normalise
val restrict0 = restrictTo intersect tgtN.verts
@@ -26,8 +15,8 @@ object Matcher {
MatchState(
m = Match(pattern0 = patN, pattern = patN, target = tgtN),
- tVerts = restrict1,
- angleMatcher = AngleExpressionMatcher(patVars,tgtVars))
+ targetVertices = restrict1,
+ expressionMatcher = CompositeExpressionMatcher()) // Create the matcher empty, it will fill itself in in time
}
def findMatches(pat: Graph, tgt: Graph, restrictTo: Set[VName]): Stream[Match] = {
@@ -37,4 +26,11 @@ object Matcher {
def findMatches(pat: Graph, tgt: Graph): Stream[Match] =
findMatches(pat, tgt, tgt.verts)
+ private def matchMain(ms: MatchState): Stream[Match] =
+ ms.nextMatch() match {
+ case Some((m1, Some(next))) => m1 #:: matchMain(next)
+ case Some((m1, None)) => Stream(m1)
+ case None => Stream()
+ }
+
}
diff --git a/scala/src/main/scala/quanto/rewrite/Rewriter.scala b/scala/src/main/scala/quanto/rewrite/Rewriter.scala
index 3331ea60..ac862830 100644
--- a/scala/src/main/scala/quanto/rewrite/Rewriter.scala
+++ b/scala/src/main/scala/quanto/rewrite/Rewriter.scala
@@ -1,32 +1,10 @@
package quanto.rewrite
-import quanto.data._
import quanto.data.Names._
+import quanto.data._
object Rewriter {
- def expandRhs(m: Match, rhs: Graph): Graph = {
- // ensure that *all* boundary names used in expanding bbops are avoided
- val fullBoundary = m.bbops.foldRight(m.pattern0.boundary) { (bbop, vs) =>
- bbop match {
- case BBExpand(_, mp) => vs union mp.v.directImage(vs)
- case BBCopy(_, mp) => vs union mp.v.directImage(vs)
- case _ => vs
- }
- }
-
- val rhs1 = m.bbops.foldRight(rhs) { (bbop, g) => g.applyBBOp(bbop, fullBoundary) }
-
- val vdata = rhs1.vdata.mapValues {
- case d: NodeV =>
- val data = d.data.setPath("$.value", d.angle.subst(m.subst).toString).asObject
- d.copy(data = data)
- case d: WireV => d
- }
-
- rhs1.copy(vdata = vdata)
- }
-
- def rewrite(m: Match, rhs: Graph, desc: RuleDesc = RuleDesc()): (Graph, Rule) = {
+ def rewrite(m: Match, rhs: Graph, desc: RuleDesc = RuleDesc()): (Graph, Rule) = {
// expand bare wires in the match
val m1 = m.normalize
@@ -44,23 +22,48 @@ object Rewriter {
.deleteEdges(m1.map.e.codSet)
.deleteVertices(m1.map.v.directImage(interiorLhs))
- val vmap = interiorRhs.foldRight(m1.map.v.restrictDom(boundary)) { (v, mp) =>
+ val vertexMap = interiorRhs.foldRight(m1.map.v.restrictDom(boundary)) { (v, mp) =>
mp + (v -> (context.verts union mp.codSet).freshWithSuggestion(v))
}
- val emap = rhsE.edges.foldRight(PFun[EName,EName]()) { (e, mp) =>
+ val edgeMap = rhsE.edges.foldRight(PFun[EName, EName]()) { (e, mp) =>
mp + (e -> (context.edges union mp.codSet).freshWithSuggestion(e))
}
// quotient the lhs and rhs such that pairs of boundaries mapped to the same vertex are identified
val quotientLhs = m1.pattern.rename(m1.map.v.toMap, m1.map.e.toMap, m1.map.bb.toMap)
- val quotientRhs = rhsE.rename(vmap.toMap, emap.toMap, m1.map.bb.toMap)
+ val quotientRhs = rhsE.rename(vertexMap.toMap, edgeMap.toMap, m1.map.bb.toMap)
val ruleInst = if (desc.inverse) Rule(quotientRhs, quotientLhs, description = desc)
- else Rule(quotientLhs, quotientRhs, description = desc)
+ else Rule(quotientLhs, quotientRhs, description = desc)
// compute the pushout as a union of the context with the quotiented domain of the matching
(quotientRhs.appendGraph(context), ruleInst)
}
+
+ def expandRhs(m: Match, rhs: Graph): Graph = {
+ // ensure that *all* boundary names used in expanding bbops are avoided
+ val fullBoundary = m.bbops.foldRight(m.pattern0.boundary) { (bbop, vs) =>
+ bbop match {
+ case BBExpand(_, mp) => vs union mp.v.directImage(vs)
+ case BBCopy(_, mp) => vs union mp.v.directImage(vs)
+ case _ => vs
+ }
+ }
+
+ val rhs1 = m.bbops.foldRight(rhs) { (bbop, g) => g.applyBBOp(bbop, fullBoundary) }
+
+ // Apply all substitutions from our matches
+ val vdata = rhs1.vdata.mapValues {
+ case d: NodeV =>
+ val data = d.data.setPath(
+ "$.value", m.subst.foldLeft(d.phaseData)((data, t) => data.substSubValues(t._2)).toString
+ ).asObject
+ d.copy(data = data)
+ case d: WireV => d
+ }
+
+ rhs1.copy(vdata = vdata)
+ }
}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/rewrite/Simproc.scala b/scala/src/main/scala/quanto/rewrite/Simproc.scala
index 0722e45e..849256cd 100644
--- a/scala/src/main/scala/quanto/rewrite/Simproc.scala
+++ b/scala/src/main/scala/quanto/rewrite/Simproc.scala
@@ -1,20 +1,31 @@
package quanto.rewrite
+import quanto.cosy.{AutoReduce, RuleSynthesis}
+import quanto.data.Derivation.DerivationWithHead
import quanto.data._
import quanto.layout.ForceLayout
+import quanto.util.UserAlerts
+
+import scala.util.Random
abstract class Simproc {
+ var sourceFile: String = ""
+ var sourceCode: String = ""
+
def simp(g: Graph): Iterator[(Graph, Rule)]
+ // jython binding for >>
+ def __rshift__(t: Simproc): Simproc = this >> t
+
// chain two simprocs together
def >>(t: Simproc) = {
val s = this
new Simproc {
- override def simp(g: Graph): Iterator[(Graph, Rule)] = new Iterator[(Graph,Rule)] {
- var iterS: Iterator[(Graph,Rule)] = s.simp(g)
- var iterT: Iterator[(Graph,Rule)] = null
- var lastGraphS = g
+ override def simp(g: Graph): Iterator[(Graph, Rule)] = new Iterator[(Graph, Rule)] {
+ var iterS: Iterator[(Graph, Rule)] = s.simp(g)
+ var iterT: Iterator[(Graph, Rule)] = _
+ var lastGraphS: Graph = g
override def hasNext: Boolean =
if (iterT != null) iterT.hasNext
@@ -28,9 +39,9 @@ abstract class Simproc {
if (iterT != null) {
iterT.next()
} else if (iterS.hasNext) {
- val (g1,r1) = iterS.next()
+ val (g1, r1) = iterS.next()
lastGraphS = g1
- (g1,r1)
+ (g1, r1)
} else {
iterT = t.simp(lastGraphS)
iterT.next()
@@ -38,12 +49,66 @@ abstract class Simproc {
}
}
}
-
- // jython binding for >>
- def __rshift__(t: Simproc): Simproc = this >> t
}
object Simproc {
+
+ // Converts a (Derivation, Head) pair into an iterated series of steps
+ // Allows gluing together of simprocs and derivations
+ implicit def fromDerivationWithHead(d: DerivationWithHead): Iterator[(Graph, Rule)] = {
+ if (d._2.nonEmpty) {
+ d._1.stepsTo(d._2.get).map(d._1.steps).map(step => (step.graph, step.rule)).toIterator
+ } else {
+ Iterator.empty
+ }
+ }
+
+ /**
+ * Anneals the graph using only the rules (forwards only), using vertex size as the metric
+ * No initial heat specified, just accepts worse states with a (decreasing-over-time) random chance
+ *
+ * @param rules List of rules, taken forwards only
+ * @param steps Number of steps to be taken
+ * @param dilation How slowly we stop accepting worse states
+ * @return The resulting derivation will appear (all at once) in the side bar
+ */
+ def ANNEAL(rules: List[Rule],
+ steps: Int,
+ dilation: Double,
+ seed: Random = new Random(),
+ vertexLimit: Option[Int] = None) = new Simproc {
+ override def simp(g: Graph): Iterator[(Graph, Rule)] = {
+ val reduced = AutoReduce.annealingReduce(
+ RuleSynthesis.basicGraphComparison,
+ RuleSynthesis.graphToDerivation(g),
+ rules,
+ steps,
+ dilation,
+ seed,
+ vertexLimit)
+ fromDerivationWithHead(reduced)
+ }
+ }
+
+ def LOG(graphEval: Graph => String) = new Simproc {
+ override def simp(g: Graph): Iterator[(Graph, Rule)] = {
+ UserAlerts.alert(graphEval(g), UserAlerts.Elevation.NOTICE)
+ Iterator.empty
+ }
+ }
+
+ // takes a list of rules and rewrites w.r.t. the first that gets a match
+ def REWRITE(rules: List[Rule]) = new Simproc {
+ override def simp(g: Graph): Iterator[(Graph, Rule)] = {
+ for (rule <- rules)
+ Matcher.findMatches(rule.lhs, g).headOption.foreach { m =>
+ return Iterator.single(layout(Rewriter.rewrite(m, rule.rhs, rule.description)))
+ }
+ //println("got no match REWRITE: " + rules.map{_.name}.toString())
+ Iterator.empty
+ }
+ }
+
private def layout(gr: (Graph, Rule)) = {
val (graph, rule) = gr
val layoutProc = new ForceLayout
@@ -55,54 +120,71 @@ object Simproc {
layoutProc.maxIterations = 300
//layoutProc.keepCentered = false
- val rhsi = rule.rhs.verts.filter(!rule.rhs.isBoundary(_))
+ val rhsInterior = rule.rhs.verts.filter(!rule.rhs.isTerminalWire(_))
//println(rhsi)
- graph.verts.foreach { v => if (!rhsi.contains(v)) layoutProc.lockVertex(v) }
+ graph.verts.foreach { v => if (!rhsInterior.contains(v)) layoutProc.lockVertex(v) }
//graph.verts.foreach { v => if (graph.isBoundary(v)) layoutProc.lockVertex(v) }
(layoutProc.layout(graph, randomCoords = false).snapToGrid(), rule)
//(graph, rule)
}
- object EMPTY extends Simproc { override def simp(g: Graph): Iterator[(Graph,Rule)] = Iterator.empty }
-
- // takes a list of rules and rewrites w.r.t. the first that gets a match
- def REWRITE(rules: List[Rule]) = new Simproc {
- override def simp(g: Graph): Iterator[(Graph, Rule)] = {
- for (rule <- rules)
- Matcher.findMatches(rule.lhs, g).headOption.foreach { m =>
- return Iterator.single(layout(Rewriter.rewrite(m, rule.rhs, rule.description)))
- }
- Iterator.empty
- }
- }
-
- def REWRITE_TARGETED(rule: Rule, vp: VName, targ: Graph => Option[VName]) = new Simproc {
+ // Applies rewrite rules, but only if the rule affects the targeted vertex
+ def REWRITE_TARGETED(rule: Rule, vertexInPattern: VName, target: Graph => Option[VName]) = new Simproc {
override def simp(g: Graph): Iterator[(Graph, Rule)] = {
- targ(g).flatMap { vt =>
+ target(g).flatMap { vt =>
//println("REWRITE_TARGETED(" + rule.name + ", " + vt + ")")
if (g.verts contains vt) {
val ms = Matcher.initialise(rule.lhs, g, g.verts)
- ms.matchNewNode(vp, vt).flatMap(_.nextMatch())
+ ms.matchNewNode(vertexInPattern, vt).flatMap(_.nextMatch())
} else None
} match {
- case Some((m,_)) =>
+ case Some((m, _)) =>
//println("SUCCESS")
Iterator.single(layout(Rewriter.rewrite(m, rule.rhs, rule.description)))
case None =>
+ //println("got no match REWRITE_TARGETED: " + rule.name)
//println("FAILED")
Iterator.empty
}
}
}
- def REWRITE_METRIC(rules: List[Rule], metric: Graph => Int) =
+ def REWRITE_TARGET_LIST(rule: Rule, vp: VName, target: List[VName]): Simproc = new Simproc {
+ override def simp(g: Graph): Iterator[(Graph, Rule)] = {
+ for (vt <- target) {
+ if (g.verts contains vt) {
+ val ms = Matcher.initialise(rule.lhs, g, g.verts)
+ ms.matchNewNode(vp, vt).flatMap(_.nextMatch()).map { case (m, _) =>
+ return Iterator.single(layout(Rewriter.rewrite(m, rule.rhs, rule.description)))
+ }
+ }
+ }
+ Iterator.empty
+ // targ(g).flatMap { vt =>
+ // //println("REWRITE_TARGETED(" + rule.name + ", " + vt + ")")
+ // if (g.verts contains vt) {
+ // val ms = Matcher.initialise(rule.lhs, g, g.verts)
+ // ms.matchNewNode(vp, vt).flatMap(_.nextMatch())
+ // } else None
+ // } match {
+ // case Some((m,_)) =>
+ // //println("SUCCESS")
+ // Iterator.single(layout(Rewriter.rewrite(m, rule.rhs, rule.description)))
+ // case None =>
+ // //println("FAILED")
+ // Iterator.empty
+ // }
+ }
+ }
+
+ def REWRITE_METRIC(rules: List[Rule], metric: Graph => Int, target: Int = 0) =
new Simproc {
override def simp(g: Graph): Iterator[(Graph, Rule)] = {
- if (metric(g) <= 0) return Iterator.empty
+ if (metric(g) <= target) return Iterator.empty
for (rule <- rules) {
Matcher.findMatches(rule.lhs, g).foreach { m =>
- val (g1,r1) = Rewriter.rewrite(m, rule.rhs, rule.description)
- if (metric(g1) < metric(g)) return Iterator.single(layout((g1,r1)))
+ val (g1, r1) = Rewriter.rewrite(m, rule.rhs, rule.description)
+ if (metric(g1) < metric(g)) return Iterator.single(layout((g1, r1)))
}
}
Iterator.empty
@@ -115,43 +197,45 @@ object Simproc {
if (metric(g) <= 0) return Iterator.empty
for (rule <- rules) {
Matcher.findMatches(rule.lhs, g).foreach { m =>
- val (g1,r1) = Rewriter.rewrite(m, rule.rhs, rule.description)
- if (metric(g1) <= metric(g)) return Iterator.single(layout((g1,r1)))
+ val (g1, r1) = Rewriter.rewrite(m, rule.rhs, rule.description)
+ if (metric(g1) <= metric(g)) return Iterator.single(layout((g1, r1)))
}
}
Iterator.empty
}
}
- def REPEAT(s: Simproc): Simproc = new Simproc {
- override def simp(g: Graph): Iterator[(Graph, Rule)] = new Iterator[(Graph, Rule)] {
- var iterS: Iterator[(Graph,Rule)] = s.simp(g)
- var lastGraphS: Graph = g
-
- override def hasNext: Boolean =
- if (iterS.hasNext) true
- else {
- val iterT = s.simp(lastGraphS)
- if (iterT.hasNext) {
- iterS = iterT
- true
- } else false
- }
+ def REPEAT(s: Simproc): Simproc = (g: Graph) => new Iterator[(Graph, Rule)] {
+ var iterS: Iterator[(Graph, Rule)] = s.simp(g)
+ var lastGraphS: Graph = g
+
+ override def hasNext: Boolean =
+ if (iterS.hasNext) true
+ else {
+ val iterT = s.simp(lastGraphS)
+ if (iterT.hasNext) {
+ iterS = iterT
+ true
+ } else false
+ }
- override def next(): (Graph,Rule) =
- if (iterS.hasNext) {
- val (g1,r1) = iterS.next()
+ override def next(): (Graph, Rule) =
+ if (iterS.hasNext) {
+ val (g1, r1) = iterS.next()
+ lastGraphS = g1
+ (g1, r1)
+ } else {
+ val iterT = s.simp(lastGraphS)
+ if (iterT.hasNext) {
+ iterS = iterT
+ val (g1, r1) = iterS.next()
lastGraphS = g1
- (g1,r1)
- } else {
- val iterT = s.simp(lastGraphS)
- if (iterT.hasNext) {
- iterS = iterT
- val (g1,r1) = iterS.next()
- lastGraphS = g1
- (g1,r1)
- } else null
- }
- }
+ (g1, r1)
+ } else null
+ }
+ }
+
+ object EMPTY extends Simproc {
+ override def simp(g: Graph): Iterator[(Graph, Rule)] = Iterator.empty
}
}
\ No newline at end of file
diff --git a/scala/src/main/scala/quanto/rewrite/rewriting.sc b/scala/src/main/scala/quanto/rewrite/rewriting.sc
deleted file mode 100644
index 81308318..00000000
--- a/scala/src/main/scala/quanto/rewrite/rewriting.sc
+++ /dev/null
@@ -1,12 +0,0 @@
-import quanto.data._
-import quanto.rewrite._
-val e1 = AngleExpression.parse("x")
-val f1 = AngleExpression.parse("2 a + b")
-val p = Vector("x", "y", "z")
-val t = Vector("a", "b", "c")
-var m = AngleExpressionMatcher(p, t)
-println(m.mat)
-//m = m.mtch(e1, f1).get
-//m.toMap
-
-
diff --git a/scala/src/main/scala/quanto/util/DirectoryWatcher.scala b/scala/src/main/scala/quanto/util/DirectoryWatcher.scala
index 36fd3ac8..7c07abd7 100644
--- a/scala/src/main/scala/quanto/util/DirectoryWatcher.scala
+++ b/scala/src/main/scala/quanto/util/DirectoryWatcher.scala
@@ -3,19 +3,17 @@ package quanto.util
import java.io._
sealed abstract class FileTree
+
case class DirNode(name: String, files: Set[FileTree]) extends FileTree
+
case class FileNode(name: String) extends FileTree
class DirectoryWatcher(val dir: String, onChange: FileTree => Any) extends Thread {
var fileTree: FileTree = FileNode("")
var poll = true
- private def buildFileTree(f: File): FileTree = {
- if (f.isDirectory) DirNode(f.getName, f.listFiles.toSet.map(buildFileTree))
- else FileNode(f.getName)
- }
override def run() {
- while(poll) {
+ while (poll) {
val ft = buildFileTree(new File(dir))
if (ft != fileTree) {
fileTree = ft
@@ -27,7 +25,14 @@ class DirectoryWatcher(val dir: String, onChange: FileTree => Any) extends Threa
}
}
- def stopPolling() { poll = false }
+ def stopPolling() {
+ poll = false
+ }
+
+ private def buildFileTree(f: File): FileTree = {
+ if (f.isDirectory) DirNode(f.getName, f.listFiles.toSet.map(buildFileTree))
+ else FileNode(f.getName)
+ }
}
object DirectoryWatcher {
diff --git a/scala/src/main/scala/quanto/util/FileHelper.scala b/scala/src/main/scala/quanto/util/FileHelper.scala
index dec592ea..8a1010eb 100644
--- a/scala/src/main/scala/quanto/util/FileHelper.scala
+++ b/scala/src/main/scala/quanto/util/FileHelper.scala
@@ -1,12 +1,34 @@
package quanto.util
import java.io.File
+import java.net.URI
import quanto.util.json.Json
-import scala.util.matching.Regex
+import scala.io.Source
+
object FileHelper {
+
+ implicit def uriToFile(uri: URI): File = new File(uri)
+
+ implicit def fileToURI(file: File): URI = file.toURI
+
+ implicit def pathToFile(path: String): File = new File(path)
+
+ val Home: URI = {
+ val uri = System.getProperty("user.home").toURI
+ if (!uri.exists) throw new IllegalStateException("Couldn't access dir: " + uri)
+ uri
+ }
+
+ def printToFile(file_name: File, string: String, append: Boolean) {
+ printToFile(file_name, append) { p => {
+ p.println(string)
+ }
+ }
+ }
+
/**
* Helper method to print to a file.
*
@@ -16,7 +38,7 @@ object FileHelper {
*/
def printToFile(file_name: File, append: Boolean = true)
(op: java.io.PrintWriter => Unit) {
- val p = new java.io.PrintWriter(new java.io.FileWriter(file_name, append))
+ val p = new java.io.PrintWriter(new java.io.FileWriter(ensureParentFolderExists(file_name), append))
try {
op(p)
} finally {
@@ -24,8 +46,43 @@ object FileHelper {
}
}
+ def ensureParentFolderExists(file: File): File = {
+ ensureFolderExists(file.getParentFile)
+ file
+ }
+
+ def ensureFolderExists(file: File): File = {
+ if (!file.exists && !file.mkdirs) throw new IllegalStateException("Couldn't create dir: " + file)
+ file
+ }
+
+ def printJson(fileName: String, json: Json): Unit = {
+ val targetFile = new File(fileName)
+ val parent = targetFile.getParentFile
+ if (!parent.exists && !parent.mkdirs) throw new IllegalStateException("Couldn't create dir: " + parent)
+ json.writeTo(new File(fileName))
+ }
+
def readFile[T](file: File, conversion: Json => T): T = conversion(Json.parse(file))
+ def readJson(file: File): Json = Json.parse(file)
+
+ def readFile(file: File): List[String] = {
+ val bufferedSource = Source.fromFile(file)
+ val lines = bufferedSource.getLines().toList
+ bufferedSource.close
+ lines
+ }
+
+ def extension(file: File): String = {
+ val pattern = """.*\.(\w+)""".r
+ file.getAbsolutePath match {
+ case pattern(extension) => extension
+ case _ => ""
+ }
+ }
+
+
def readAllOfType[T](directory: String, regexFilter: String, conversion: Json => T): List[T] = {
readJSONFromDirectory(directory, regexFilter).map(conversion)
}
diff --git a/scala/src/main/scala/quanto/util/Geometry.scala b/scala/src/main/scala/quanto/util/Geometry.scala
index 849f6e9b..51acbd94 100644
--- a/scala/src/main/scala/quanto/util/Geometry.scala
+++ b/scala/src/main/scala/quanto/util/Geometry.scala
@@ -1,55 +1,75 @@
package quanto.util
-import math.{min,max}
+import scala.math.{max, min}
class RichPt(val p: Geometry.Pt) {
- def +(p1: Geometry.Pt) = (p._1 + p1._1, p._2 + p1._2)
- def -(p1: Geometry.Pt) = (p._1 - p1._1, p._2 - p1._2)
- def unary_-(p1: Geometry.Pt) = (-p._1, -p._2)
+ def +(p1: Geometry.Pt): (Double, Double) = (p._1 + p1._1, p._2 + p1._2)
+
+ def -(p1: Geometry.Pt): (Double, Double) = (p._1 - p1._1, p._2 - p1._2)
+
+ def unary_-(p1: Geometry.Pt): (Double, Double) = (-p._1, -p._2)
}
object RichPt {
- implicit def richPtToPt(rp: RichPt) = rp.p
- implicit def ptToRichPt(p: Geometry.Pt) = new RichPt(p)
+ implicit def richPtToPt(rp: RichPt): (Double, Double) = rp.p
+
+ implicit def ptToRichPt(p: Geometry.Pt): RichPt = new RichPt(p)
}
class RichRect(val r: Geometry.Rect) {
- def pad(padding: Double) =
- r match { case (((lx,ly),(ux,uy))) =>
- ((lx - padding, ly - padding), (ux + padding, uy + padding))
+ def pad(padding: Double): ((Double, Double), (Double, Double)) =
+ r match {
+ case (((lx, ly), (ux, uy))) =>
+ ((lx - padding, ly - padding), (ux + padding, uy + padding))
}
- def coords = (r._1,r._2)
- def center: Geometry.Pt = r match { case (((lx,ly),(ux,uy))) => ((lx+ux)/2.0,(ly+uy)/2.0) }
- def width: Double = r match { case ((lx,ux),_) => ux - lx }
- def height: Double = r match { case (_,(ly,uy)) => uy - ly }
- def size = (width, height)
+
+ def coords: ((Double, Double), (Double, Double)) = (r._1, r._2)
+
+ def center: Geometry.Pt = r match {
+ case (((lx, ly), (ux, uy))) => ((lx + ux) / 2.0, (ly + uy) / 2.0)
+ }
+
+ def size: (Double, Double) = (width, height)
+
+ def width: Double = r match {
+ case ((lx, ux), _) => ux - lx
+ }
+
+ def height: Double = r match {
+ case (_, (ly, uy)) => uy - ly
+ }
}
object RichRect {
- implicit def richRectToRect(rr: RichRect) = rr.r
- implicit def rectToRichRect(r: Geometry.Rect) = new RichRect(r)
+ implicit def richRectToRect(rr: RichRect): ((Double, Double), (Double, Double)) = rr.r
+
+ implicit def rectToRichRect(r: Geometry.Rect): RichRect = new RichRect(r)
}
object Geometry {
- type Pt = (Double,Double)
- type Rect = (Pt,Pt)
- def bounds(ps: Iterable[Pt]): Option[(Pt,Pt)] = {
+ type Pt = (Double, Double)
+ type Rect = (Pt, Pt)
+
+ def bounds(ps: Iterable[Pt]): Option[(Pt, Pt)] = {
val it = ps.iterator
-
+
if (it.hasNext) {
- var upper = it.next()
- var lower = upper
+ var upper = it.next()
+ var lower = upper
for (p <- it) {
- lower = (min(lower._1,p._1),min(lower._2,p._2))
- upper = (max(upper._1,p._1),max(upper._2,p._2))
+ lower = (min(lower._1, p._1), min(lower._2, p._2))
+ upper = (max(upper._1, p._1), max(upper._2, p._2))
}
- Some(lower,upper)
+ Some(lower, upper)
} else None
}
// implicit conversions for various geometric classes
implicit def ptToPoint(p: Pt): java.awt.Point = new java.awt.Point(p._1.toInt, p._2.toInt)
+
implicit def ptToPoint2D(p: Pt): java.awt.geom.Point2D = new java.awt.geom.Point2D.Double(p._1, p._2)
- implicit def pointToPt(p: java.awt.Point): Pt = (p.getX,p.getY)
- implicit def point2DToPt(p: java.awt.geom.Point2D): Pt = (p.getX,p.getY)
+
+ implicit def pointToPt(p: java.awt.Point): Pt = (p.getX, p.getY)
+
+ implicit def point2DToPt(p: java.awt.geom.Point2D): Pt = (p.getX, p.getY)
}
diff --git a/scala/src/main/scala/quanto/util/Globals.scala b/scala/src/main/scala/quanto/util/Globals.scala
index 54872e78..3207b3e1 100644
--- a/scala/src/main/scala/quanto/util/Globals.scala
+++ b/scala/src/main/scala/quanto/util/Globals.scala
@@ -1,14 +1,19 @@
package quanto.util
-import java.io.File
import java.awt.event.InputEvent
+import java.io.File
object Globals {
- def CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask
- def CommandDownMask = if (CommandMask == InputEvent.META_MASK) InputEvent.META_DOWN_MASK
- else InputEvent.CTRL_DOWN_MASK
+ def CommandDownMask: Int = if (CommandMask == InputEvent.META_MASK) InputEvent.META_DOWN_MASK
+ else InputEvent.CTRL_DOWN_MASK
+
+ def CommandMask: Int = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask
+
+ def isBundle: Boolean = isMacBundle || isWindowsBundle || isLinuxBundle
+
def isMacBundle: Boolean = new File("osx-bundle").exists
+
def isLinuxBundle: Boolean = new File("linux-bundle").exists
+
def isWindowsBundle: Boolean = new File("windows-bundle").exists
- def isBundle: Boolean = isMacBundle || isWindowsBundle || isLinuxBundle
}
diff --git a/scala/src/main/scala/quanto/util/QuadTree.scala b/scala/src/main/scala/quanto/util/QuadTree.scala
index d1889a07..8a64dd6a 100644
--- a/scala/src/main/scala/quanto/util/QuadTree.scala
+++ b/scala/src/main/scala/quanto/util/QuadTree.scala
@@ -1,32 +1,34 @@
package quanto.util
-import math.{abs,min,max}
+import scala.math.{abs, max, min}
/**
- * A quadtree for 2D spacial queries
- */
+ * A quadtree for 2D spacial queries
+ */
sealed abstract class QuadTree[A] {
def x1: Double
+
def y1: Double
+
def x2: Double
+
def y2: Double
def value: Option[A]
- def p: (Double,Double)
- def insert(p: (Double,Double), v: A): QuadTree[A]
- // visit each node/leaf with 'f', recursing until f returns true or a leaf is encountered
- def visit(f : QuadTree[A] => Boolean)
+ def p: (Double, Double)
- private def intervalOverlap(s1: Double, t1: Double, s2: Double, t2: Double) =
- if (s1 <= s2) t1 >= s2 else t2 >= s1
+ def insert(p: (Double, Double), v: A): QuadTree[A]
+
+ // visit each node/leaf with 'f', recursing until f returns true or a leaf is encountered
+ def visit(f: QuadTree[A] => Boolean)
// query the tree, returning all values in regions with the given bound
def query(xLower: Double, yLower: Double, xUpper: Double, yUpper: Double): Iterable[A] = {
val result = collection.mutable.ListBuffer[A]()
visit { tr =>
- if (intervalOverlap(tr.x1,tr.x2,xLower,xUpper) && intervalOverlap(tr.y1,tr.y2,yLower,yUpper)) {
+ if (intervalOverlap(tr.x1, tr.x2, xLower, xUpper) && intervalOverlap(tr.y1, tr.y2, yLower, yUpper)) {
tr.value.map(v => result += v)
false // recurse
} else {
@@ -35,57 +37,59 @@ sealed abstract class QuadTree[A] {
}
result
}
+
+ private def intervalOverlap(s1: Double, t1: Double, s2: Double, t2: Double) =
+ if (s1 <= s2) t1 >= s2 else t2 >= s1
}
object QuadTree {
- def apply[A](items: Iterable[((Double,Double),A)]): QuadTree[A] = {
+ def apply[A](items: Iterable[((Double, Double), A)]): QuadTree[A] = {
items.headOption match {
- case Some(((hx,hy),_)) =>
+ case Some(((hx, hy), _)) =>
// compute bounding box
- val (x1,y1,x2,y2) = items.tail.foldLeft(hx,hy,hx,hy) {
- case ((minX,minY,maxX,maxY),((x,y),_)) => (min(minX,x),min(minY,y),max(maxX,x),max(maxY,y))
+ val (x1, y1, x2, y2) = items.tail.foldLeft(hx, hy, hx, hy) {
+ case ((minX, minY, maxX, maxY), ((x, y), _)) => (min(minX, x), min(minY, y), max(maxX, x), max(maxY, y))
}
// make bounds square
val sz = max(x2 - x1, y2 - y1)
- val emptyTree: QuadTree[A] = QuadNode[A](x1,y1,x1+sz,y1+sz)
+ val emptyTree: QuadTree[A] = QuadNode[A](x1, y1, x1 + sz, y1 + sz)
- items.foldLeft(emptyTree) { case (tr,(p,v)) => tr.insert(p,v) }
- case None => QuadNode[A](0,0,0,0)
+ items.foldLeft(emptyTree) { case (tr, (p, v)) => tr.insert(p, v) }
+ case None => QuadNode[A](0, 0, 0, 0)
}
}
}
case class QuadNode[A](
- x1: Double, y1: Double,
- x2: Double, y2: Double,
-
- value: Option[A],
- p: (Double,Double) = (0,0),
- nw: QuadTree[A],
- ne: QuadTree[A],
- sw: QuadTree[A],
- se: QuadTree[A]) extends QuadTree[A]
-{
- lazy val children = Seq(nw,ne,sw,se)
- val midX = (x1 + x2) / 2
- val midY = (y1 + y2) / 2
-
- def insert(p: (Double,Double), v: A) = p match {
- case (x,y) if x < midX && y < midY => copy(nw = nw.insert(p,v))
- case (x,y) if x >= midX && y < midY => copy(ne = ne.insert(p,v))
- case (x,y) if x < midX && y >= midY => copy(sw = sw.insert(p,v))
- case (x,y) if x >= midX && y >= midY => copy(se = se.insert(p,v))
+ x1: Double, y1: Double,
+ x2: Double, y2: Double,
+
+ value: Option[A],
+ p: (Double, Double) = (0, 0),
+ nw: QuadTree[A],
+ ne: QuadTree[A],
+ sw: QuadTree[A],
+ se: QuadTree[A]) extends QuadTree[A] {
+ lazy val children = Seq(nw, ne, sw, se)
+ val midX: Double = (x1 + x2) / 2
+ val midY: Double = (y1 + y2) / 2
+
+ def insert(p: (Double, Double), v: A): QuadNode[A] = p match {
+ case (x, y) if x < midX && y < midY => copy(nw = nw.insert(p, v))
+ case (x, y) if x >= midX && y < midY => copy(ne = ne.insert(p, v))
+ case (x, y) if x < midX && y >= midY => copy(sw = sw.insert(p, v))
+ case (x, y) if x >= midX && y >= midY => copy(se = se.insert(p, v))
}
- def visit(f : QuadTree[A] => Boolean) {
- if (!f(this)) {
- nw.visit(f)
- ne.visit(f)
- sw.visit(f)
- se.visit(f)
- }
+ def visit(f: QuadTree[A] => Boolean) {
+ if (!f(this)) {
+ nw.visit(f)
+ ne.visit(f)
+ sw.visit(f)
+ se.visit(f)
+ }
}
//def mapValue[B](f: Option[A] => Option[B]) = QuadNode[B](x1,y1,x2,y2,f(value),p,nw,ne,sw,se)
@@ -95,41 +99,43 @@ object QuadNode {
// a new empty quad node with given bounds
def apply[A](x1: Double, y1: Double, x2: Double, y2: Double): QuadNode[A] = {
- QuadNode[A](x1,y1,x2,y2,None,(0.0,0.0))
+ QuadNode[A](x1, y1, x2, y2, None, (0.0, 0.0))
}
// a new node with a value, point, and the given bounds
- def apply[A](x1: Double, y1: Double, x2: Double, y2: Double, value: Option[A], p: (Double,Double)): QuadNode[A] = {
+ def apply[A](x1: Double, y1: Double, x2: Double, y2: Double, value: Option[A], p: (Double, Double)): QuadNode[A] = {
val midX = (x1 + x2) / 2
val midY = (y1 + y2) / 2
- val nw = QuadLeaf[A](x1,y1,midX,midY)
- val ne = QuadLeaf[A](midX,y1,x2,midY)
- val sw = QuadLeaf[A](x1,midY,midX,y2)
- val se = QuadLeaf[A](midX,midY,x2,y2)
- QuadNode[A](x1,y1,x2,y2,value,p,nw,ne,sw,se)
+ val nw = QuadLeaf[A](x1, y1, midX, midY)
+ val ne = QuadLeaf[A](midX, y1, x2, midY)
+ val sw = QuadLeaf[A](x1, midY, midX, y2)
+ val se = QuadLeaf[A](midX, midY, x2, y2)
+ QuadNode[A](x1, y1, x2, y2, value, p, nw, ne, sw, se)
}
}
case class QuadLeaf[A](
- x1: Double, y1: Double,
- x2: Double, y2: Double,
- value: Option[A] = None,
- p: (Double,Double) = (0,0)) extends QuadTree[A]
-{
- def insert(p1: (Double,Double), v1: A) =
+ x1: Double, y1: Double,
+ x2: Double, y2: Double,
+ value: Option[A] = None,
+ p: (Double, Double) = (0, 0)) extends QuadTree[A] {
+ def insert(p1: (Double, Double), v1: A): QuadTree[A] =
value match {
case None => copy(p = p1, value = Some(v1))
case Some(v) =>
if (abs(p._1 - p1._1) < 0.01 && abs(p._2 - p1._2) < 0.01) {
// if they are very close, place the current value on the node, and add the new value as a leaf
- QuadNode[A](x1,y1,x2,y2,Some(v),p).insert(p1,v1)
+ QuadNode[A](x1, y1, x2, y2, Some(v), p).insert(p1, v1)
} else {
// otherwise place them on their own leaves
- QuadNode[A](x1,y1,x2,y2).insert(p,v).insert(p1,v1)
+ QuadNode[A](x1, y1, x2, y2).insert(p, v).insert(p1, v1)
}
}
- def visit(f : QuadTree[A] => Boolean) { f(this) }
- def mapValue[B](f: A => B): QuadLeaf[B] = QuadLeaf[B](x1,y1,x2,y2,value.map(f(_)),p)
+ def visit(f: QuadTree[A] => Boolean) {
+ f(this)
+ }
+
+ def mapValue[B](f: A => B): QuadLeaf[B] = QuadLeaf[B](x1, y1, x2, y2, value.map(f(_)), p)
}
diff --git a/scala/src/main/scala/quanto/util/Rational.scala b/scala/src/main/scala/quanto/util/Rational.scala
index 88075b96..558ceb78 100644
--- a/scala/src/main/scala/quanto/util/Rational.scala
+++ b/scala/src/main/scala/quanto/util/Rational.scala
@@ -3,45 +3,60 @@ package quanto.util
class RationalDivideByZeroException(r: Rational)
extends Exception("Attempted to divide by 0 in (" + r.n + "/" + r.d + ")")
-class Rational(numerator : Int, denominator : Int) extends Ordered[Rational] {
- if (denominator == 0) throw new RationalDivideByZeroException(this)
- private val r = Rational.gcd(numerator,denominator)
+class Rational(numerator: Int, denominator: Int) extends Ordered[Rational] {
+
+ private val r = Rational.gcd(numerator, denominator)
private val dsn = if (denominator < 0) -1 else 1
- val n: Int = dsn * numerator/r
- val d: Int = dsn * denominator/r
+ if (denominator == 0) throw new RationalDivideByZeroException(this)
+ val n: Int = dsn * numerator / r
+ val d: Int = dsn * denominator / r
+
+ def +(r: Rational) = Rational(n * r.d + r.n * d, d * r.d)
+
+ def -(r: Rational) = Rational(n * r.d - r.n * d, d * r.d)
- def +(r : Rational) = Rational(n * r.d + r.n * d, d * r.d)
- def -(r : Rational) = Rational(n * r.d - r.n * d, d * r.d)
- def *(r : Rational) = Rational(n * r.n, d * r.d)
- def *(i : Int) = Rational(n * i, d)
- def /(r : Rational) = Rational(n * r.d, d * r.n)
- def mod(i : Int): Rational =
+ def *(r: Rational) = Rational(n * r.n, d * r.d)
+
+ def *(i: Int) = Rational(n * i, d)
+
+ def /(r: Rational) = Rational(n * r.d, d * r.n)
+
+ def mod(i: Int): Rational =
if (n < 0) Rational((n % (d * i)) + (d * i), d)
else Rational(n % (d * i), d)
- def inv = Rational(d,n)
- override def equals(r : Any):Boolean = r match {
- case r1 : Rational => n == r1.n && d == r1.d
+ def inv = Rational(d, n)
+
+ override def equals(r: Any): Boolean = r match {
+ case r1: Rational => n == r1.n && d == r1.d
case _ => false
}
- override def compare(r : Rational): Int = { n * r.d - r.n * d }
+
+ override def compare(r: Rational): Int = {
+ n * r.d - r.n * d
+ }
+
def isZero: Boolean = n == 0
+
def isOne: Boolean = n == 1 && d == 1
- override def toString:String = if (d == 1) n.toString else "(" + n + "/" + d + ")"
+ override def toString: String = if (d == 1) n.toString else "(" + n + "/" + d + ")"
}
object Rational {
def apply(numerator: Int, denominator: Int) = new Rational(numerator, denominator)
- def apply(numerator: Int) = new Rational(numerator, 1)
- private def gcd(a: Int,b: Int): Int = {
- if(b == 0) Math.abs(a) else gcd(b, a%b)
+ private def gcd(a: Int, b: Int): Int = {
+ if (b == 0) Math.abs(a) else gcd(b, a % b)
}
- implicit def intToRational(i : Int) : Rational = Rational(i)
+ def apply(numerator: Int) = new Rational(numerator, 1)
+
+
+
+ implicit def intToRational(i: Int): Rational = Rational(i)
- implicit def rationalToDouble(r: Rational) : Double = r.n.toFloat / r.d.toFloat
+ implicit def rationalToDouble(r: Rational): Double = r.n.toFloat / r.d.toFloat
}
diff --git a/scala/src/main/scala/quanto/util/RationalMatrix.scala b/scala/src/main/scala/quanto/util/RationalMatrix.scala
index 120bde65..42a135d3 100644
--- a/scala/src/main/scala/quanto/util/RationalMatrix.scala
+++ b/scala/src/main/scala/quanto/util/RationalMatrix.scala
@@ -3,48 +3,52 @@ package quanto.util
/**
*
* A matrix of the form (M1|M2) representing a system of equations:
- * M1 v = M2 c
+ * M1 v = M2 c
* where "v" is a vector of pattern variables and "c" a vector of target variables (treated as constants) where the
* final constant "pi" is treated modulo 2.
*
* The variables in "t" are treated as constants when judging whether a solution exists.
*
- * @param mat A 2D Vector of rational numbers
- * @param line The position of the line between LHS and RHS matrices, i.e. the number of free variables
+ * @param mat A 2D Vector of rational numbers
+ * @param line The position of the line between LHS and RHS matrices, i.e. the number of free variables
* @param constModulo Modulus to apply to the constant or -1 for no modulus.
*/
-class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val constModulo : Int = 2) {
+class RationalMatrix(val mat: Vector[Vector[Rational]], val line: Int, val constModulo: Option[Int] = Some(2)) {
+
import RationalMatrix._
- def numRows : Int = mat.length
- def numCols : Int = if (mat.isEmpty) 0 else mat(0).length
- override def equals(that: Any) = that match {
- case m1 : RationalMatrix =>
+ def numRows: Int = mat.length
+
+ override def equals(that: Any): Boolean = that match {
+ case m1: RationalMatrix =>
mat == m1.mat && line == line && constModulo == constModulo
case _ => false
}
- def apply(i : Int): Vector[Rational] = mat(i)
- def rows = mat
+ def apply(i: Int): Vector[Rational] = mat(i)
- //def insertVar = new RationalMatrix(mat.map { row => ins(row, line) }, line+1, constModulo)
- //def insertConst = new RationalMatrix(mat.map { row => ins(row, row.length-1) }, line, constModulo)
+ def rows: Vector[Vector[Rational]] = mat
- def padTo(vCols: Int, cCols: Int) = {
+ def padTo(vCols: Int, cCols: Int): RationalMatrix = {
val m = vCols - line
val n = cCols - (numCols - line - 1)
if (m >= 0 && n >= 0)
new RationalMatrix(mat.map { row =>
row.slice(0, line) ++ Vector.fill(m)(Rational(0)) ++
- row.slice(line, row.length - 1) ++ Vector.fill(n)(Rational(0)) :+
- row(row.length - 1)
+ row.slice(line, row.length - 1) ++ Vector.fill(n)(Rational(0)) :+
+ row(row.length - 1)
}, math.max(line, vCols), constModulo)
else this
}
-// private def ins(row : Vector[Rational], i : Int) : Vector[Rational] =
-// row.take(i) ++ (Rational(0) +: row.takeRight(row.length - i))
+ //def insertVar = new RationalMatrix(mat.map { row => ins(row, line) }, line+1, constModulo)
+ //def insertConst = new RationalMatrix(mat.map { row => ins(row, row.length-1) }, line, constModulo)
+
+ def numCols: Int = if (mat.isEmpty) 0 else mat(0).length
+
+ // private def ins(row : Vector[Rational], i : Int) : Vector[Rational] =
+ // row.take(i) ++ (Rational(0) +: row.takeRight(row.length - i))
def isReduced: Boolean = {
var p = -1
@@ -59,8 +63,8 @@ class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val cons
// translate matrix to echelon form. returns None if there is an inconsistent row, i.e. a row of the form
// (0..0|v!=0..0)
- def gauss : Option[RationalMatrix] = {
- val empty : Option[RationalMatrix] = Some(new RationalMatrix(Vector(), line, constModulo))
+ def gauss: Option[RationalMatrix] = {
+ val empty: Option[RationalMatrix] = Some(new RationalMatrix(Vector(), line, constModulo))
mat.foldLeft(empty) {
case (Some(m), r) => m.gaussUpdate(r)
case _ => None
@@ -68,7 +72,7 @@ class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val cons
}
// add a new row, keeping the matrix in echelon form. Returns None if new row introduces an inconsistency.
- def gaussUpdate(row : Vector[Rational]) : Option[RationalMatrix] = {
+ def gaussUpdate(row: Vector[Rational]): Option[RationalMatrix] = {
var r = row
// gaussian reduce the new row
@@ -89,8 +93,8 @@ class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val cons
} else {
// otherwise use the new row to further reduce existing rows, and insert into the correct position
r = normaliseRow(r, p)
- val (mat1, inserted) = mat.foldLeft((Vector[Vector[Rational]](),false)) {
- case ((rows, ins),row1) =>
+ val (mat1, inserted) = mat.foldLeft((Vector[Vector[Rational]](), false)) {
+ case ((rows, ins), row1) =>
if (findPivot(row1) > p) {
if (!ins) (rows :+ r :+ row1, true)
else (rows :+ row1, ins)
@@ -104,22 +108,22 @@ class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val cons
}
// multiply row by a scalar to make pivot = 1
- private def normaliseRow(row : Vector[Rational], p : Int) = {
+ private def normaliseRow(row: Vector[Rational], p: Int) = {
val r = row(p).inv
Vector.tabulate(row.length) { i =>
- if (constModulo != -1 && i == row.length) (row(i) * r) mod constModulo
+ if (constModulo.nonEmpty && i == row.length) (row(i) * r) mod constModulo.get
else row(i) * r
}
}
// subtract a multiple of row2 (with pivot p2) from row1 to make row1(p2) == 0. Assumes row2 is
// normalised.
- private def reduceWith(row1 : Vector[Rational], row2 : Vector[Rational], p2 : Int) = {
+ private def reduceWith(row1: Vector[Rational], row2: Vector[Rational], p2: Int) = {
val n = row1(p2)
if (!n.isZero) {
Vector.tabulate(row1.length) { i =>
val n1 = row1(i) - (n * row2(i))
- if (constModulo != -1 && i == row1.length) n1 mod constModulo
+ if (constModulo.nonEmpty && i == row1.length) n1 mod constModulo.get
else n1
}
} else {
@@ -127,11 +131,11 @@ class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val cons
}
}
- override def toString = mat.foldLeft("\n") { (s,row) =>
+ override def toString: String = mat.foldLeft("\n") { (s, row) =>
var t = "["
for (i <- row.indices) {
t += " " + row(i).toString
- if (i == line-1) t += " |"
+ if (i == line - 1) t += " |"
else t += " "
}
s + t + "]\n"
@@ -139,7 +143,7 @@ class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val cons
}
object RationalMatrix {
- def findPivot(row : Vector[Rational]) : Int =
+ def findPivot(row: Vector[Rational]): Int =
row.indexWhere(!_.isZero)
}
diff --git a/scala/src/main/scala/quanto/util/RichCubicCurve.scala b/scala/src/main/scala/quanto/util/RichCubicCurve.scala
index 80b23501..07446fd6 100644
--- a/scala/src/main/scala/quanto/util/RichCubicCurve.scala
+++ b/scala/src/main/scala/quanto/util/RichCubicCurve.scala
@@ -3,21 +3,23 @@ package quanto.util
import java.awt.geom.CubicCurve2D
class RichCubicCurve(curve: CubicCurve2D) {
+
import RichCubicCurve._
- def pointAt(dist: Double) = (
+
+ def pointAt(dist: Double): (Double, Double) = (
bezierInterpolate(dist, curve.getX1, curve.getCtrlX1, curve.getCtrlX2, curve.getX2),
bezierInterpolate(dist, curve.getY1, curve.getCtrlY1, curve.getCtrlY2, curve.getY2)
)
}
object RichCubicCurve {
- def bezierInterpolate(dist: Double, c0: Double, c1: Double, c2: Double, c3: Double) = {
- val distp = 1 - dist
+ def bezierInterpolate(dist: Double, c0: Double, c1: Double, c2: Double, c3: Double): Double = {
+ val distP = 1 - dist
- (distp*distp*distp) * c0 +
- 3.0 * (distp*distp) * dist * c1 +
- 3.0 * (dist*dist) * distp * c2 +
- (dist*dist*dist) * c3
+ (distP * distP * distP) * c0 +
+ 3.0 * (distP * distP) * dist * c1 +
+ 3.0 * (dist * dist) * distP * c2 +
+ (dist * dist * dist) * c3
}
implicit def cubicCurveToRichCubicCurve(curve: CubicCurve2D): RichCubicCurve = new RichCubicCurve(curve)
diff --git a/scala/src/main/scala/quanto/util/Scripting.scala b/scala/src/main/scala/quanto/util/Scripting.scala
index 10fd3ef9..2d999811 100644
--- a/scala/src/main/scala/quanto/util/Scripting.scala
+++ b/scala/src/main/scala/quanto/util/Scripting.scala
@@ -11,8 +11,6 @@ import quanto.util.json._
import scala.collection.JavaConverters._
import scala.concurrent.duration._
-
-
// object providing functions specifically for python scripting
object Scripting {
@@ -56,6 +54,12 @@ object Scripting {
listToPyList(pyListToList[String](ss).map(load_rule))
}
+ def include_inverses(rs: PyList) : PyList = {
+ listToPyList(pyListToList[Rule](rs).flatMap(r => {
+ List(r, r.inverse)
+ }).distinct)
+ }
+
def plug(g1: Graph, g2: Graph, b1: String, b2: String) =
g1.plugGraph(g2, VName(b1), VName(b2))
@@ -94,8 +98,13 @@ object Scripting {
// }
}
+ def new_graph_from_json(jsonString: String): Unit = {
+ val doc = QuantoDerive.newGraph()
+ doc.replaceJson(Json.parse(jsonString))
+ }
+
class derivation(start : Graph) {
- var d = Derivation(theory, start)
+ var d = Derivation(start)
def rewrite(r: (String, Rule)) = {
false
@@ -135,7 +144,7 @@ object Scripting {
}
def vertex_angle_is(g: Graph, v: VName, a: String) = g.vdata(v) match {
- case nv: NodeV => nv.angle == AngleExpression.parse(a)
+ case nv: NodeV => nv.phaseData.values == CompositeExpression.parseKnowingTypes(a, nv.phaseData.valueTypes)
case _ => false
}
@@ -148,6 +157,39 @@ object Scripting {
// python wrappers for simproc combinators
val EMPTY = Simproc.EMPTY
+
+ //takes in a python function that accepts a json version of the graph
+ // and outputs new json content for a new graph
+ def JSON_REWRITE(func: PyFunction) = new Simproc{
+ override def simp(g: Graph): Iterator[(Graph, Rule)] = {
+ val input = g.toJson().toString()
+ val output = Py.tojava(func.__call__(Py.java2py(input)),classOf[String])
+ Iterator.single((Graph.fromJson(Json.parse(output),project.theory),Rule(Graph(),Graph())))
+ }
+ }
+
+ def JSON_REWRITE_STEPS(starter: PyMethod, step_getter: PyMethod, name_getter: PyMethod) = new Simproc{
+ override def simp(g: Graph): Iterator[(Graph, Rule)] = {
+ val input = g.toJson().toString()
+ val total_steps = Py.tojava(starter.__call__(Py.java2py(input)),classOf[Integer])
+ var index = 0
+ new Iterator[(Graph, Rule)] {
+ override def length = total_steps
+ override def hasNext: Boolean = index < total_steps
+ override def next(): (Graph, Rule) = {
+ if (index < total_steps) {
+ val name = Py.tojava(name_getter.__call__(Py.java2py(index)), classOf[String])
+ println("Adding rewrite step: " + name)
+ val js = Py.tojava(step_getter.__call__(Py.java2py(index)), classOf[String])
+ val rule = Rule(Graph(), Graph(), None, RuleDesc(name))
+ index += 1
+ (Graph.fromJson(Json.parse(js), project.theory), rule)
+ }
+ else null
+ }
+ }
+ }
+ }
def REWRITE(o: Object) = o match {
case list: PyList => Simproc.REWRITE(pyListToList(list))
case _ => Simproc.REWRITE(List(o.asInstanceOf[Rule]))
@@ -162,6 +204,15 @@ object Scripting {
Simproc.REWRITE_METRIC(rules, {g => Py.py2int(metric.__call__(Py.java2py(g)))})
}
+ def REWRITE_METRIC_TO(o: Object, metric: PyFunction, target: Int) = {
+ val rules = o match {
+ case list: PyList => pyListToList(list)
+ case _ => List(o.asInstanceOf[Rule])
+ }
+
+ Simproc.REWRITE_METRIC(rules, {g => Py.py2int(metric.__call__(Py.java2py(g)))}, target)
+ }
+
def REWRITE_WEAK_METRIC(o: Object, metric: PyFunction) = {
val rules = o match {
case list: PyList => pyListToList(list)
@@ -180,14 +231,33 @@ object Scripting {
})
}
+ def ANNEAL(r: PyList, steps: Int, dilation: Double) = Simproc.ANNEAL(pyListToList[Rule](r), steps, dilation)
+
+ def LOG(graphEval: PyFunction) = Simproc.LOG(
+ g => {
+ val pyReturn = graphEval.__call__(Py.java2py(g))
+ if (pyReturn.isInstanceOf[PyString]) pyReturn.toString else ""
+ }
+ )
+
+ def REWRITE_TARGET_LIST(rule: Rule, v: String, tlist: PyList) = Simproc.REWRITE_TARGET_LIST(rule, VName(v), pyListToList(tlist))
+
def REPEAT(s: Simproc) = Simproc.REPEAT(s)
// REDUCE_XXX(-) := REPEAT(REWRITE_XXX(-))
def REDUCE(o: Object) = REPEAT(REWRITE(o))
def REDUCE_TARGETED(rule: Rule, v: String, targ: PyFunction) = REPEAT(REWRITE_TARGETED(rule, v, targ))
def REDUCE_METRIC(o: Object, metric: PyFunction) = REPEAT(REWRITE_METRIC(o, metric))
+ def REDUCE_METRIC_TO(o: Object, metric: PyFunction, target: Int) = REPEAT(REWRITE_METRIC_TO(o, metric, target))
def REDUCE_WEAK_METRIC(o: Object, metric: PyFunction) = REPEAT(REWRITE_WEAK_METRIC(o, metric))
+ private def register_simproc(simprocName: String, simproc: Simproc, sourceFile: String): Unit = {
+ simproc.sourceFile = sourceFile
+ project.simprocs += simprocName -> simproc
+ }
+
+ def register_simproc(s: String, sp: Simproc): Unit = {
+ register_simproc(s, sp, project.lastRunPythonFilePath.getOrElse(""))
+ }
- def register_simproc(s: String, sp: Simproc) { project.simprocs += s -> sp }
}
diff --git a/scala/src/main/scala/quanto/util/SignallingStreamRedirector.scala b/scala/src/main/scala/quanto/util/SignallingStreamRedirector.scala
index 61fbc8de..3c54c011 100644
--- a/scala/src/main/scala/quanto/util/SignallingStreamRedirector.scala
+++ b/scala/src/main/scala/quanto/util/SignallingStreamRedirector.scala
@@ -8,12 +8,19 @@ import java.io._
//case class InterruptedSignal(id: Int) extends Signal(id)
sealed abstract class MessagePart
+
case class CodePart(c: Char) extends MessagePart
+
case class IntPart(i: Int) extends MessagePart
+
case class StringPart(s: String) extends MessagePart
case class StreamMessage(parts: MessagePart*) {
+ def writeTo(out: OutputStream) {
+ writeTo(new OutputStreamWriter(out, "ISO-8859-1"))
+ }
+
def writeTo(out: Writer) {
parts.foreach {
case CodePart(c: Char) =>
@@ -28,15 +35,11 @@ case class StreamMessage(parts: MessagePart*) {
out.flush()
}
- def writeTo(out: OutputStream) {
- writeTo(new OutputStreamWriter(out, "ISO-8859-1"))
- }
-
- def stripCodes = parts.filter{ case _: CodePart => false ; case _ => true }
+ def stripCodes: Seq[MessagePart] = parts.filter { case _: CodePart => false; case _ => true }
}
object StreamMessage {
- def compileMessage(id: Int, fileName: String, code: String) = {
+ def compileMessage(id: Int, fileName: String, code: String): StreamMessage = {
new StreamMessage(
CodePart('R'),
IntPart(id), CodePart(','),
@@ -52,85 +55,24 @@ object StreamMessage {
}
/**
- * A simple stream redirector that allows bits of code to listen for signals coming over the stream, all of
- * this form: <<[S](id)>>, <<[F](id)>>, <<[I](id)>> indicating success, failure, and interrupt (along with
- * some identifying code).
- *
- */
+ * A simple stream redirector that allows bits of code to listen for signals coming over the stream, all of
+ * this form: <<[S](id)>>, <<[F](id)>>, <<[I](id)>> indicating success, failure, and interrupt (along with
+ * some identifying code).
+ *
+ */
class SignallingStreamRedirector(from: InputStream, to: Option[OutputStream] = None)
-extends Thread("Signalling Stream Redirector") {
+ extends Thread("Signalling Stream Redirector") {
+ private val listeners = collection.mutable.Map[Int, List[(StreamMessage => Any)]]()
+ private val outputStreams = collection.mutable.Buffer[OutputStream]()
private var state = 0
private var currentCode: Option[Char] = None
private var currentId = 0
private var buf = new StringBuffer
private var msgParts = Seq[MessagePart]()
- private val listeners = collection.mutable.Map[Int, List[(StreamMessage => Any)]]()
- private val outputStreams = collection.mutable.Buffer[OutputStream]()
to.map { s => outputStreams += s }
- private def fire() {
- msgParts match {
- case (_ :: IntPart(id) :: _) =>
- val msg = StreamMessage(msgParts: _*)
- listeners.synchronized {
- listeners.remove(id).map { _.reverse.foreach( f => f(msg)) }
- }
- case _ =>
- println("Got bad message: " + msgParts)
- }
- }
-
- // process string via tiny state machine
-// private def processChar(c: Char) = state match {
-// case 0 => if (c == '<') state = 1
-// case 1 => if (c == '<') state = 2 else state = 0
-// case 2 => if (c == '[') state = 3 else state = 0
-// case 3 => if (c == 'S' || c == 'F' || c == 'I') { currentCode = c ; state = 4 }
-// else { currentCode = 'X' ; state = 0 }
-// case 4 => if (c == ']') state = 5
-// else { currentCode = 'X' ; state = 0 }
-// case 5 => if (c.isDigit) currentId = (10 * currentId) + c.toString.toInt
-// else if (c == '>') state = 6
-// else { currentCode = 'X' ; currentId = 0; state = 0 }
-// case 6 => if (c == '>') fire(currentCode, currentId)
-// currentCode = 'X' ; currentId = 0; state = 0
-// }
-
- private def resetState() {
- msgParts = List()
- buf = new StringBuffer
- currentCode = None
- currentId = 0
- state = 0
- }
-
- private def processChar(c: Char) =
- state match {
- case 0 =>
- if (c == '\u001B') { state = 1 ; false }
- else { state = 0; true }
- case 1 =>
- msgParts :+= CodePart(c)
- currentCode = Some(c.toLower)
- state = 2
- false
- case 2 =>
- if (c.isDigit) { currentId = (10 * currentId) + c.toString.toInt }
- else if (c == '\u001B') { msgParts :+= IntPart(currentId); state = 3 }
- else { resetState() }
- false
- case 3 => msgParts :+= CodePart(c)
- if (Some(c) == currentCode) { fire(); resetState() }
- else { state = 4 }
- false
- case 4 =>
- if (c == '\u001B') { msgParts :+= StringPart(buf.toString); buf = new StringBuffer; state = 3 }
- else { buf.append(c) }
- false
- }
-
- def addListener(id: Int)(f : StreamMessage => Any) {
+ def addListener(id: Int)(f: StreamMessage => Any) {
listeners.synchronized {
listeners.put(id, listeners.get(id) match {
case Some(list) => f :: list
@@ -139,13 +81,29 @@ extends Thread("Signalling Stream Redirector") {
}
}
+ // process string via tiny state machine
+ // private def processChar(c: Char) = state match {
+ // case 0 => if (c == '<') state = 1
+ // case 1 => if (c == '<') state = 2 else state = 0
+ // case 2 => if (c == '[') state = 3 else state = 0
+ // case 3 => if (c == 'S' || c == 'F' || c == 'I') { currentCode = c ; state = 4 }
+ // else { currentCode = 'X' ; state = 0 }
+ // case 4 => if (c == ']') state = 5
+ // else { currentCode = 'X' ; state = 0 }
+ // case 5 => if (c.isDigit) currentId = (10 * currentId) + c.toString.toInt
+ // else if (c == '>') state = 6
+ // else { currentCode = 'X' ; currentId = 0; state = 0 }
+ // case 6 => if (c == '>') fire(currentCode, currentId)
+ // currentCode = 'X' ; currentId = 0; state = 0
+ // }
+
def addOutputStream(out: OutputStream) {
outputStreams.synchronized {
outputStreams += out
}
}
- def removeOutputStream(out : OutputStream) {
+ def removeOutputStream(out: OutputStream) {
outputStreams.synchronized {
outputStreams -= out
}
@@ -154,12 +112,12 @@ extends Thread("Signalling Stream Redirector") {
override def run() {
try {
val buffer: Array[Byte] = new Array[Byte](200)
- val buffer1 : Array[Byte] = new Array[Byte](200)
+ val buffer1: Array[Byte] = new Array[Byte](200)
var count: Int = from.read(buffer)
while (count != -1) {
var j = 0
- for (i <- 0 to count - 1) {
+ for (i <- 0 until count) {
if (processChar(buffer(i).toChar)) {
buffer1(j) = buffer(i)
j += 1
@@ -181,4 +139,69 @@ extends Thread("Signalling Stream Redirector") {
ex.printStackTrace()
}
}
+
+ private def processChar(c: Char) =
+ state match {
+ case 0 =>
+ if (c == '\u001B') {
+ state = 1; false
+ }
+ else {
+ state = 0; true
+ }
+ case 1 =>
+ msgParts :+= CodePart(c)
+ currentCode = Some(c.toLower)
+ state = 2
+ false
+ case 2 =>
+ if (c.isDigit) {
+ currentId = (10 * currentId) + c.toString.toInt
+ }
+ else if (c == '\u001B') {
+ msgParts :+= IntPart(currentId); state = 3
+ }
+ else {
+ resetState()
+ }
+ false
+ case 3 => msgParts :+= CodePart(c)
+ if (currentCode.contains(c)) {
+ fire(); resetState()
+ }
+ else {
+ state = 4
+ }
+ false
+ case 4 =>
+ if (c == '\u001B') {
+ msgParts :+= StringPart(buf.toString); buf = new StringBuffer; state = 3
+ }
+ else {
+ buf.append(c)
+ }
+ false
+ }
+
+ private def fire() {
+ msgParts match {
+ case (_ :: IntPart(id) :: _) =>
+ val msg = StreamMessage(msgParts: _*)
+ listeners.synchronized {
+ listeners.remove(id).foreach {
+ _.reverse.foreach(f => f(msg))
+ }
+ }
+ case _ =>
+ println("Got bad message: " + msgParts)
+ }
+ }
+
+ private def resetState() {
+ msgParts = List()
+ buf = new StringBuffer
+ currentCode = None
+ currentId = 0
+ state = 0
+ }
}
diff --git a/scala/src/main/scala/quanto/util/SwingTimer.scala b/scala/src/main/scala/quanto/util/SwingTimer.scala
index e824d848..1267e128 100644
--- a/scala/src/main/scala/quanto/util/SwingTimer.scala
+++ b/scala/src/main/scala/quanto/util/SwingTimer.scala
@@ -3,7 +3,7 @@ package quanto.util
object SwingTimer {
def apply(interval: Int, repeats: Boolean = true)(op: => Unit) {
val timeOut = new javax.swing.AbstractAction() {
- def actionPerformed(e : java.awt.event.ActionEvent) = op
+ def actionPerformed(e: java.awt.event.ActionEvent): Unit = op
}
val t = new javax.swing.Timer(interval, timeOut)
t.setRepeats(repeats)
diff --git a/scala/src/main/scala/quanto/util/TreeSeq.scala b/scala/src/main/scala/quanto/util/TreeSeq.scala
index d705325c..25e192e6 100644
--- a/scala/src/main/scala/quanto/util/TreeSeq.scala
+++ b/scala/src/main/scala/quanto/util/TreeSeq.scala
@@ -3,29 +3,31 @@ package quanto.util
class TreeSeqFormatException(msg: String) extends Exception(msg)
/**
- * An abstract class representing a tree whose elements also have a sequential ordering, as in a branching history.
- */
+ * An abstract class representing a tree whose elements also have a sequential ordering, as in a branching history.
+ */
abstract class TreeSeq[A] {
+
import TreeSeq._
+
def toSeq: Seq[A]
+
def indexOf(a: A): Int
+
def parent(a: A): Option[A]
- def children(a: A): Seq[A]
- private def padding(size: Int) =
- Seq.fill[WhiteSpace[A]](size) { WhiteSpace[A](collapseBottom = false, collapseTop = true) }
+ def children(a: A): Seq[A]
def flatten: Seq[(Seq[Decoration[A]], A)] =
toSeq.foldLeft(Seq[(Seq[Decoration[A]], A)]()) { (rows, a) =>
val node = NodeLink(parent(a), children(a))
val prev = if (rows.isEmpty) Seq[Decoration[A]]()
- else rows.last._1
+ else rows.last._1
var inserted = false
var pad = 0
// traverse the previous decoration list from left to right and construct the current decoration list
- val current = prev.foldLeft(Seq[Decoration[A]]()) { (cols,col) =>
+ val current = prev.foldLeft(Seq[Decoration[A]]()) { (cols, col) =>
col match {
case (WireLink(dest)) =>
if (inserted) cols :+ WireLink(dest)
@@ -43,7 +45,7 @@ abstract class TreeSeq[A] {
pad += 1
cols :+ WhiteSpace[A](collapseBottom = true, collapseTop = false)
}
- else outs.foldLeft(cols) { (outCols,out) =>
+ else outs.foldLeft(cols) { (outCols, out) =>
if (inserted) outCols :+ WireLink(out)
else {
if (out == a) {
@@ -55,17 +57,17 @@ abstract class TreeSeq[A] {
}
}
}
- case WhiteSpace(false,_) =>
+ case WhiteSpace(false, _) =>
pad += 1
cols :+ WhiteSpace[A](collapseBottom = true, collapseTop = false)
- case WhiteSpace(true,_) => cols
+ case WhiteSpace(true, _) => cols
}
}
rows :+ (if (!inserted) {
node.input match {
- case Some (p) =>
+ case Some(p) =>
throw new TreeSeqFormatException("Node '" + a.toString + "' occurs before its parent '" + p + "'")
case None => current :+ node
}
@@ -73,19 +75,17 @@ abstract class TreeSeq[A] {
current
}, a)
}
+
+ private def padding(size: Int) =
+ Seq.fill[WhiteSpace[A]](size) {
+ WhiteSpace[A](collapseBottom = false, collapseTop = true)
+ }
}
object TreeSeq {
- sealed abstract class Decoration[A]
- case class NodeLink[A](input: Option[A], outputs: Seq[A]) extends Decoration[A]
- case class WireLink[A](dest: A) extends Decoration[A]
-
- // placeholder for whitespace. if collapse is false, the space propagates to the next rank. if it is true, it
- // disappears at the next rank.
- case class WhiteSpace[A](collapseBottom: Boolean, collapseTop: Boolean) extends Decoration[A]
// figure out how wide a given decoration sequence is
- def decorationWidth[A](dec: Seq[Decoration[A]]) = {
+ def decorationWidth[A](dec: Seq[Decoration[A]]): Int = {
var topIndex = 0
var bottomIndex = 0
var sz = 0
@@ -93,11 +93,11 @@ object TreeSeq {
case WireLink(_) =>
topIndex += 1
bottomIndex += 1
- sz = math.max(topIndex,bottomIndex)
+ sz = math.max(topIndex, bottomIndex)
case NodeLink(inputOpt, outputs) =>
topIndex += inputOpt.size
bottomIndex += math.max(1, outputs.size)
- sz = math.max(topIndex,bottomIndex)
+ sz = math.max(topIndex, bottomIndex)
case WhiteSpace(collapseBottom, collapseTop) =>
if (!collapseBottom) bottomIndex += 1
if (!collapseTop) topIndex += 1
@@ -105,4 +105,14 @@ object TreeSeq {
sz
}
+
+ sealed abstract class Decoration[A]
+
+ case class NodeLink[A](input: Option[A], outputs: Seq[A]) extends Decoration[A]
+
+ case class WireLink[A](dest: A) extends Decoration[A]
+
+ // placeholder for whitespace. if collapse is false, the space propagates to the next rank. if it is true, it
+ // disappears at the next rank.
+ case class WhiteSpace[A](collapseBottom: Boolean, collapseTop: Boolean) extends Decoration[A]
}
diff --git a/scala/src/main/scala/quanto/util/UserAlerts.scala b/scala/src/main/scala/quanto/util/UserAlerts.scala
index e21c0ec0..d5a1f8ef 100644
--- a/scala/src/main/scala/quanto/util/UserAlerts.scala
+++ b/scala/src/main/scala/quanto/util/UserAlerts.scala
@@ -1,9 +1,11 @@
package quanto.util
-import java.util.{Calendar, Date, UUID}
+import java.io.File
+import java.util.concurrent.TimeUnit
+import java.util.{Calendar, Date}
-import scala.swing.{Color, Dialog, Publisher}
import scala.swing.event.Event
+import scala.swing.{Color, Dialog, Publisher}
// Universal system for alerting the user
@@ -12,13 +14,79 @@ import scala.swing.event.Event
// Listen to events via AlertPublisher
object UserAlerts {
+ var ongoingProcesses: List[UserStartedProcess] = List()
+ var alerts: List[Alert] = List()
+
+ def leastCompleteProcess: Option[UserStartedProcess] = {
+ if (ongoingProcesses.isEmpty) None else {
+ val indeterminate = ongoingProcesses.find(op => !op.determinate)
+ if (indeterminate.nonEmpty) indeterminate else {
+ Some(ongoingProcesses.minBy(op => op.value))
+ }
+ }
+ }
+
+ def latestMessage: Alert = {
+ if (alerts.headOption.nonEmpty) alerts.head else {
+ alert("Quantomatic starting up")
+ latestMessage
+ }
+ }
+
+ def alert(message: Any): Unit = alert(message.toString, Elevation.NOTICE)
+
+ def debug(message: String): Unit = alert(message, Elevation.DEBUG)
+
+ def errorBox(message: String): Unit = {
+ alert(message, Elevation.ERROR)
+ Dialog.showMessage(
+ title = "Error",
+ message = message,
+ messageType = Dialog.Message.Error)
+ }
+
+ def alert(message: String, elevation: Elevation.Elevation): Unit = {
+ val newAlert = Alert(Calendar.getInstance().getTime, elevation, message)
+ println(newAlert.toString)
+ alerts = newAlert :: alerts
+ AlertPublisher.publish(UserAlertEvent(newAlert))
+ writeToLogFile(newAlert)
+ }
+
+ def writeToLogFile(alert: Alert, force: Boolean = false): Unit = {
+ val elevation = alert.elevationText match {
+ case "" => ""
+ case e => s"[$e]"
+ }
+ if (logFile.nonEmpty && (UserOptions.logging || force)) {
+ FileHelper.printToFile(logFile.get)(
+ p => p.println(UserOptions.preferredTimeFormat.format(alert.time) + ": " + elevation + alert.message)
+ )
+
+ }
+ }
+
+ def registerLogFile(optionFile: Option[File]) : Unit = {
+ _logFile = optionFile
+ }
+
+ private var _logFile : Option[File] = None
+
+ def logFile: Option[File] = _logFile
+
case class UserAlertEvent(alert: Alert) extends Event
- case class UserProcessUpdate(ongoingProcess: UserStartedProcess) extends Event
- object AlertPublisher extends Publisher
+ case class UserProcessUpdate(ongoingProcess: UserStartedProcess) extends Event
class SelfAlertingProcess(name: String) extends UserStartedProcess(name) {
alert(name + ": Started")
+ val startTime: Long = Calendar.getInstance().getTimeInMillis
+
+
+ override def halt(): Unit = {
+ super.halt()
+ alert(name + ": Halted", Elevation.NOTICE)
+ }
override def fail(): Unit = {
super.fail()
@@ -27,26 +95,21 @@ object UserAlerts {
override def finish(): Unit = {
super.finish()
- alert(name + ": Finished")
+ val timeTaken = TimeUnit.SECONDS.toSeconds(Calendar.getInstance().getTimeInMillis - startTime)
+ alert(name + s": Finished ${timeTaken / 1000.0}s")
}
}
class UserStartedProcess(val name: String) {
//private val uuid : UUID = UUID.randomUUID() //Will need for log files
private var _determinate: Boolean = false
- private var _value : Int = 0
- private var _failed : Boolean = false
+ private var _value: Int = 0
+ private var _failed: Boolean = false
- def failed : Boolean = _failed
+ def failed: Boolean = _failed
- def determinate : Boolean = _determinate
+ def determinate: Boolean = _determinate
- def value: Int = _value
- def value_=(newValue: Int) : Unit = {
- _value = newValue
- _determinate = true
- AlertPublisher.publish(UserProcessUpdate(this))
- }
def setIndeterminate(): Unit = {
_value = 0
_determinate = false
@@ -58,29 +121,31 @@ object UserAlerts {
value = 100
}
+ def value: Int = _value
+
+ def value_=(newValue: Int): Unit = {
+ _value = newValue
+ _determinate = true
+ AlertPublisher.publish(UserProcessUpdate(this))
+ }
+
def finish(): Unit = {
value = 100
}
- ongoingProcesses = this :: ongoingProcesses
- AlertPublisher.publish(UserProcessUpdate(this))
- }
-
- var ongoingProcesses : List[UserStartedProcess] = List()
-
- def leastCompleteProcess : Option[UserStartedProcess] = {
- if (ongoingProcesses.isEmpty) None else {
- val indeterminate = ongoingProcesses.find(op => !op.determinate)
- if (indeterminate.nonEmpty) indeterminate else {
- Some(ongoingProcesses.minBy(op => op.value))
- }
+ def halt(): Unit = {
+ _failed = false
+ value = 0
}
- }
+ ongoingProcesses = this :: ongoingProcesses
+ AlertPublisher.publish(UserProcessUpdate(this))
+ }
- case class Alert(time : Date, elevation: Elevation.Elevation, message: String) {
+ case class Alert(time: Date, elevation: Elevation.Elevation, message: String) {
override def toString: String = UserOptions.preferredTimeFormat.format(time) + ": " + message
- def color : Color= {
+
+ def color: Color = {
elevation match {
case Elevation.ERROR => new Color(150, 0, 0) // Something broke
case Elevation.ALERT => new Color(150, 150, 0) // Something soon to break
@@ -89,36 +154,22 @@ object UserAlerts {
case Elevation.NOTICE => new Color(0, 0, 150) // Nothing is broken
}
}
- }
- object Elevation extends Enumeration {
- type Elevation = Value
- val ALERT, ERROR, WARNING, NOTICE, DEBUG = Value
- }
-
- var Alerts : List[Alert] = List()
-
- def latestMessage : Alert = {
- if (Alerts.headOption.nonEmpty) Alerts.head else {
- alert("Quantomatic starting up")
- latestMessage
+ def elevationText: String = {
+ elevation match {
+ case Elevation.ERROR => "ERROR" // Something broke
+ case Elevation.ALERT => "ALERT" // Something soon to break
+ case Elevation.WARNING => "WARNING" // That would have caused something to break
+ case Elevation.DEBUG => "DEBUG" // I want to know how it would have broken
+ case Elevation.NOTICE => "" // Nothing is broken
+ }
}
}
- def alert(message: String) : Unit = alert(message, Elevation.NOTICE)
-
- def errorbox(message: String) : Unit = {
- alert(message, Elevation.ERROR)
- Dialog.showMessage(
- title = "Error",
- message = message,
- messageType = Dialog.Message.Error)
- }
+ object AlertPublisher extends Publisher
- def alert(message: String, elevation: Elevation.Elevation): Unit ={
- val newAlert = Alert(Calendar.getInstance().getTime, elevation, message)
- println(newAlert.toString)
- Alerts = newAlert :: Alerts
- AlertPublisher.publish(UserAlertEvent(newAlert))
+ object Elevation extends Enumeration {
+ type Elevation = Value
+ val ALERT, ERROR, WARNING, NOTICE, DEBUG = Value
}
}
diff --git a/scala/src/main/scala/quanto/util/UserOptions.scala b/scala/src/main/scala/quanto/util/UserOptions.scala
index fd0bfeb5..6719791d 100644
--- a/scala/src/main/scala/quanto/util/UserOptions.scala
+++ b/scala/src/main/scala/quanto/util/UserOptions.scala
@@ -15,16 +15,42 @@ class UserOptions {
object UserOptions {
+ val prefs: Preferences = Preferences.userRoot().node(this.getClass.getName)
+ // A scaling of 1 corresponds to a font size of 14
+ // changing the font size will also change the rest of the scaling
+ private var _uiScale: Double = prefs.getDouble("uiScale", 1)
+ private var _fontDecoration = Font.PLAIN
+ private var _fontFamily = "defaultFont"
+ private var _logging: Boolean = prefs.getBoolean("logging", false)
+ private var _graphScale: Double = 1
+ private var _preferredTimeFormat: SimpleDateFormat = new SimpleDateFormat("HH:mm:ss")
+ private var _preferredDateTimeFormat: SimpleDateFormat = new SimpleDateFormat("yy-MM-dd.HH:mm:ss")
- case class UIRedrawRequest() extends Event
+ def scaleInt(d: Double): Int = math.floor(scale(d)).toInt
- object OptionsChanged extends Publisher
+ def scale(d: Double): Double = {
+ d * uiScale
+ }
+
+ def resetScale(): Unit = {
+ uiScale = 1
+ }
+
+ def uiScale: Double = _uiScale
+
+ def uiScale_=(d: Double) {
+ _uiScale = d
+ _uiScale = math.max(_uiScale, 0.5) // Limit scaling to equivalent of 7pt font
+ _uiScale = math.min(_uiScale, 4) // Limit scaling to equivalent of 56pt font
+ setUIFont(new FontUIResource(_fontFamily, _fontDecoration, fontSize))
+ prefs.putDouble("uiScale", _uiScale)
+ requestUIRefresh()
+ }
private def requestUIRefresh(): Unit = {
OptionsChanged.publish(UIRedrawRequest())
}
- // Changes the default font but doesn't request redraw
private def setUIFont(f: FontUIResource): Unit = {
val keys = UIManager.getDefaults.keys
while ( {
@@ -36,67 +62,70 @@ object UserOptions {
}
}
- // A scaling of 1 corresponds to a font size of 14
- // changing the font size will also change the rest of the scaling
- private var _uiScale : Double = 1
- def uiScale : Double = _uiScale
- def uiScale_=(d: Double){
- _uiScale = math.max(d, 0.5) // Limit scaling to equivalent of 7pt font
- setUIFont(new FontUIResource(_fontFamily, _fontDecoration, fontSize))
- val prefs = Preferences.userRoot().node(this.getClass.getName)
- prefs.putDouble("uiScale", _uiScale)
- requestUIRefresh()
+ def font: Font = {
+ new Font(_fontFamily, _fontDecoration, fontSize)
}
- def scale(d: Double) : Double = {
- d * uiScale
+ def fontSize: Int = {
+ Math.floor(14 * _uiScale).toInt
}
- def scaleInt(d: Double) : Int = math.floor(scale(d)).toInt
-
- def resetScale(): Unit ={
- uiScale = 1
+ def fontSize_=(n: Int) {
+ uiScale = n / 14.0 //changing uiScale triggers redraw and font size changes
}
- def font : Font = {
- new Font(_fontFamily, _fontDecoration, fontSize)
- }
+ def fontDecoration: Int = _fontDecoration
- private var _fontDecoration = Font.PLAIN
- def fontDecoration : Int = _fontDecoration
- def fontDecoration_=(n: Int){
+ def fontDecoration_=(n: Int) {
_fontDecoration = n
}
- private var _fontFamily = "defaultFont"
- def fontFamily : String = _fontFamily
- def fontFamily_=(s: String): Unit ={
+ def fontFamily: String = _fontFamily
+
+ def fontFamily_=(s: String): Unit = {
_fontFamily = s
}
- def fontSize : Int = {
- Math.floor(14*_uiScale).toInt
- }
- def fontSize_=(n: Int){
- uiScale = n / 14.0 //changing uiScale triggers redraw and font size changes
+ def logging: Boolean = _logging
+
+ def logging_=(b: Boolean): Unit = {
+ prefs.putBoolean("logging", b)
+ _logging = b
}
- private var _graphScale : Double = 1
- def graphScale : Double = _graphScale
- def graphScale_=(n: Double): Unit ={
+ def graphScale: Double = _graphScale
+
+ def graphScale_=(n: Double): Unit = {
_graphScale = n
}
- private var _preferredTimeFormat : SimpleDateFormat = new SimpleDateFormat("HH:mm:ss")
- def preferredTimeFormat : SimpleDateFormat = _preferredTimeFormat
+ def preferredTimeFormat: SimpleDateFormat = _preferredTimeFormat
+
def preferredTimeFormat_=(format: String): Unit = {
try {
_preferredTimeFormat = new SimpleDateFormat(format)
}
- catch {
- case e: Exception =>
- e.printStackTrace()
- }
+ catch {
+ case e: Exception =>
+ e.printStackTrace()
+ }
+ }
+
+ def preferredDateTimeFormat: SimpleDateFormat = _preferredDateTimeFormat
+
+ def preferredDateTimeFormat_=(format: String): Unit = {
+ try {
+ _preferredDateTimeFormat = new SimpleDateFormat(format)
+ }
+ catch {
+ case e: Exception =>
+ e.printStackTrace()
+ }
}
+ case class UIRedrawRequest() extends Event
+
+ object OptionsChanged extends Publisher
+
+
}
diff --git a/scala/src/main/scala/quanto/util/WebHelper.scala b/scala/src/main/scala/quanto/util/WebHelper.scala
index 8f8ad34c..f3c3ee63 100644
--- a/scala/src/main/scala/quanto/util/WebHelper.scala
+++ b/scala/src/main/scala/quanto/util/WebHelper.scala
@@ -1,6 +1,7 @@
package quanto.util
-import java.net.URL
+
import java.awt.Desktop
+import java.net.URL
object WebHelper {
diff --git a/scala/src/test/scala/quanto/cosy/test/BlockEnumerationSpec.scala b/scala/src/test/scala/quanto/cosy/test/BlockEnumerationSpec.scala
index c1047ac1..04c165b6 100644
--- a/scala/src/test/scala/quanto/cosy/test/BlockEnumerationSpec.scala
+++ b/scala/src/test/scala/quanto/cosy/test/BlockEnumerationSpec.scala
@@ -2,13 +2,16 @@ package quanto.cosy.test
import quanto.cosy._
import org.scalatest.FlatSpec
-import quanto.data.Rule
+import quanto.data.{Graph, GraphTikz, Rule}
import quanto.rewrite.Matcher
+import quanto.cosy.BlockRowMaker._
+import quanto.util.FileHelper
/**
* Created by hector on 28/06/17.
*/
+
class BlockEnumerationSpec extends FlatSpec {
implicit def quickList(n: Int): List[Int] = {
@@ -22,31 +25,31 @@ class BlockEnumerationSpec extends FlatSpec {
behavior of "Block Enumeration"
it should "build a small ZW row" in {
- var rowsAllowed = BlockRowMaker(1, allowedBlocks = BlockRowMaker.ZW)
+ var rowsAllowed = BlockRowMaker(1, allowedBlocks = BlockGenerators.ZW)
println(rowsAllowed)
}
it should "build bigger ZW rows" in {
- var rowsAllowed = BlockRowMaker(2, allowedBlocks = BlockRowMaker.ZW)
+ var rowsAllowed = BlockRowMaker(2, allowedBlocks = BlockGenerators.ZW)
println(rowsAllowed)
assert(rowsAllowed.length == 11 * 11 + 11)
}
it should "stack rows" in {
- var rowsAllowed = BlockRowMaker(1, allowedBlocks = BlockRowMaker.ZW)
+ var rowsAllowed = BlockRowMaker(1, allowedBlocks = BlockGenerators.ZW)
var stacks = BlockStackMaker(2, rowsAllowed)
println(stacks)
}
it should "limit wires" in {
- var rowsAllowed = BlockRowMaker(2, allowedBlocks = BlockRowMaker.ZW, maxInOut = Option(2))
+ var rowsAllowed = BlockRowMaker(2, allowedBlocks = BlockGenerators.ZW, maxInOut = Option(2))
var stacks = BlockStackMaker(2, rowsAllowed)
println(stacks)
assert(stacks.forall(s => (s.inputs.length <= 2) && (s.outputs.length <= 2)))
}
it should "compute tensors" in {
- var rowsAllowed = BlockRowMaker(2, allowedBlocks = BlockRowMaker.ZW)
+ var rowsAllowed = BlockRowMaker(2, allowedBlocks = BlockGenerators.ZW)
var stacks = BlockStackMaker(2, rowsAllowed)
for (elem <- stacks) {
println("---\n" + elem.toString + " = \n" + elem.tensor)
@@ -66,7 +69,7 @@ class BlockEnumerationSpec extends FlatSpec {
it should "find wire identities" in {
var rowsAllowed = BlockRowMaker(1, allowedBlocks = List(
- Block(1, 1, " 1 ", Tensor.idWires(1)),
+ Block(1, 1, " 1 ", Tensor.idWires(1), new Graph()),
Block(1, 1, " w ", new Tensor(Array(Array[Complex](1, 0), Array[Complex](0, -1)))),
Block(1, 1, " b ", new Tensor(Array(Array[Complex](0, 1), Array[Complex](1, 0))))
))
@@ -109,42 +112,42 @@ class BlockEnumerationSpec extends FlatSpec {
behavior of "Stack to Graph"
it should "convert a block to a graph" in {
- var b = BlockRowMaker.Bian2Qubit(2) // T-gate
- var g = BlockRowMaker.Bian2QubitToGraph(b)
+ var b = BlockGenerators.Bian2Qubit(2) // T-gate
+ var g = BlockGenerators.Bian2QubitToGraph(b)
println(g.toString)
}
it should "convert a row to a graph" in {
- var B2 = BlockRowMaker.Bian2Qubit
+ var B2 = BlockGenerators.Bian2Qubit
var r = new BlockRow(List(B2(2), B2(3))) // T x H
- var g = BlockRowMaker.rowToGraph(r, BlockRowMaker.Bian2QubitToGraph)
+ var g = BlockRowMaker.predicateRowToGraph(r, BlockGenerators.Bian2QubitToGraph)
println(g.toString)
}
it should "convert a stack to a graph" in {
- var B2 = BlockRowMaker.Bian2Qubit
+ var B2 = BlockGenerators.Bian2Qubit
var r = new BlockRow(List(B2(2), B2(3))) // T x H
- var g = BlockRowMaker.stackToGraph(
+ var g = BlockRowMaker.predicateStackToGraph(
new BlockStack(List(r, r)),
- BlockRowMaker.Bian2QubitToGraph)
+ BlockGenerators.Bian2QubitToGraph)
println(g)
}
behavior of "qutrits and qudits"
it should "generate enough qutrit generators" in {
- assert(BlockRowMaker.ZXQutrit(9).length == (10 + 2 * 81))
+ assert(BlockGenerators.ZXQutrit(9).length == (10 + 2 * 81))
}
it should "generate enough qudit generators" in {
- assert(BlockRowMaker.ZXQudit(3, 9).length == (10 + 2 * 81))
+ assert(BlockGenerators.ZXQudit(3, 9).length == (10 + 2 * 81))
// And check it is the correct swap tensor:
- assert(BlockRowMaker.ZXQudit(3, 9)(1).tensor == BlockRowMaker.ZXQutrit(9)(1).tensor)
- assert(BlockRowMaker.ZXQudit(4, 8).length == (10 + 2 * math.pow(8, 4 - 1)).toInt)
+ assert(BlockGenerators.ZXQudit(3, 9)(1).tensor == BlockGenerators.ZXQutrit(9)(1).tensor)
+ assert(BlockGenerators.ZXQudit(4, 8).length == (10 + 2 * math.pow(8, 4 - 1)).toInt)
}
it should "have spider rules for qudits" in {
- var Q4 = BlockRowMaker.ZXQudit(4, 8)
+ var Q4 = BlockGenerators.ZXQudit(4, 8)
var r760 = Q4.find(p => p.name == "r|7|6|0")
var r230 = Q4.find(p => p.name == "r|2|3|0")
var r110 = Q4.find(p => p.name == "r|1|1|0")
@@ -157,7 +160,7 @@ class BlockEnumerationSpec extends FlatSpec {
behavior of "Bell Simple"
it should "display quantum teleportation" in {
- var BSRow = BlockRowMaker(2, BlockRowMaker.BellTeleportation, Option(3))
+ var BSRow = BlockRowMaker(2, BlockGenerators.BellTeleportation, Option(3))
var BSStacks = BlockStackMaker(4, BSRow)
var tp = BSStacks.
//filterNot(x => x.toString.matches(raw".*\(w\d \).*")).
@@ -169,4 +172,59 @@ class BlockEnumerationSpec extends FlatSpec {
tp.foreach(x => println("---- \n " + x.toString + "\n" + x.tensor))
assert(tp.length == 4)
}
+
+ behavior of "ZX Clifford"
+
+ it should "Find CZ gate" in {
+ var BSRow = BlockRowMaker(2, BlockGenerators.ZXClifford, Option(2))
+ var BSStacks = BlockStackMaker(3, BSRow)
+ var tp = BSStacks.
+ filter(x=> x.tensor.isSameShapeAs(Tensor.idWires(2))).
+ filter(x => x.tensor.isRoughlyUpToScalar(
+ Tensor(Array(Array(1, 0, 0, 0), Array(0, 1, 0, 0), Array(0, 0, 1, 0), Array(0, 0, 0, -1))))
+ )
+ tp.foreach(x => println("---- \n " + x.toString + "\n" + x.tensor))
+ assert(tp.nonEmpty)
+ }
+
+ behavior of "graph generation"
+
+ val ZXClifford = BlockGenerators.ZXClifford
+
+ it should "make a pair horizontally" in {
+ var row = new BlockRow(List(ZXClifford(1), ZXClifford(3)))
+ var g = row.graph
+ assert(g.verts.size == 6)
+ }
+
+ it should "make a pair vertically" in {
+ var row = new BlockRow(List(ZXClifford(0)))
+ var g = BlockStack(List(row, row)).graph
+ assert(g.verts.size == 4)
+ assert(g.edges.size == 3)
+ assert(g.minimise.edges.size == 1)
+ }
+
+ it should "make a 2x2" in {
+ var row = new BlockRow(List(ZXClifford(5), ZXClifford(1)))
+ var row2 = new BlockRow(List(ZXClifford(1), ZXClifford(5)))
+ var g = BlockStack(List(row, row2)).graph
+ assert(g.verts.size == 12)
+ }
+
+ it should "cache rows" in {
+ var row = new BlockRow(List(ZXClifford(0), ZXClifford(1)))
+ var row2 = new BlockRow(List(ZXClifford(5), ZXClifford(0)))
+ var row3 = new BlockRow(List(ZXClifford(0), ZXClifford(5)))
+ var b1 = BlockStack(List(row, row2, row3))
+ var b2 = BlockStack(List(row2, row3)).append(row)
+ var b3 = BlockStack(List(row3)).append(row2).append(row)
+ var g1 = b1.graph
+ var g2 = b2.graph
+ assert(b1.tensor == b2.tensor)
+ assert(b1.tensor == b3.tensor)
+ assert(g1.edges.size == g2.edges.size)
+ assert(g1.verts.toList.sorted == g2.verts.toList.sorted)
+ }
+
}
\ No newline at end of file
diff --git a/scala/src/test/scala/quanto/cosy/test/BlockGeneratorsSpec.scala b/scala/src/test/scala/quanto/cosy/test/BlockGeneratorsSpec.scala
new file mode 100644
index 00000000..c23bcf53
--- /dev/null
+++ b/scala/src/test/scala/quanto/cosy/test/BlockGeneratorsSpec.scala
@@ -0,0 +1,70 @@
+package quanto.cosy.test
+
+import org.scalatest.FlatSpec
+import quanto.cosy.BlockGenerators._
+import quanto.cosy._
+
+import scala.concurrent.duration.Duration
+import java.io.File
+
+import quanto.data._
+import quanto.rewrite.Matcher
+import quanto.util.{FileHelper, Rational}
+import quanto.util.json.Json
+import quanto.data.Names._
+
+import scala.util.Random
+
+class BlockGeneratorsSpec extends FlatSpec {
+
+
+ behavior of "ZX generators"
+
+ it should "make the right number of generators" in {
+ // CNOTs start at width 2
+ assert(zxCNOTs(2).size == 1)
+ assert(zxCNOTs(4).size == 3)
+ assert(zxTONCs(4).size == 3)
+ // Make X and Z twists
+ val twists = zxQubitTwists(8)
+ assert(twists.size == 16)
+ }
+
+ it should "make the right angles" in {
+ val twists = zxQubitTwists(8)
+ assert(twists(2).graph.vdata("X").asInstanceOf[NodeV].phaseData.values.head.constant == Rational(1, 4))
+ assert(twists(15).graph.vdata("Z").asInstanceOf[NodeV].phaseData.values.head.constant == Rational(7, 4))
+ }
+
+ it should "make hadamards" in {
+ val twists = zxQubitTwists(8)
+ assert((zxQubitHadamard o twists(2) o zxQubitHadamard).isRoughlyUpToScalar(twists(3)))
+ }
+
+ it should "make correct CNOTs and TNOCs" in {
+ assert((zxCNOT(4) o zxCNOT(4)).isRoughlyUpToScalar(Tensor.idWires(4)))
+ assert((zxTONC(4) o zxTONC(4)).isRoughlyUpToScalar(Tensor.idWires(4)))
+ assert((zxCNOT(4) o zxTONC(4) o zxCNOT(4)).isRoughlyUpToScalar(Tensor.swap(List(3, 1, 2, 0))))
+ val c4 = zxCNOT(4)
+ val c4graphInterpretation = Interpreter.interpretZXGraph(c4.graph,
+ List("i-0", "i-1", "i-2", "i-3"),
+ List("o-0", "o-1", "o-2", "o-3"))
+ assert(c4.tensor.isRoughlyUpToScalar(c4graphInterpretation))
+
+ val t4 = zxTONC(4)
+ val t4graphInterpretation = Interpreter.interpretZXGraph(t4.graph,
+ List("i-0", "i-1", "i-2", "i-3"),
+ List("o-0", "o-1", "o-2", "o-3"))
+ assert(t4.tensor.isRoughlyUpToScalar(t4graphInterpretation))
+
+ }
+
+ it should "not have string edges, only rails" in {
+ val blocks: List[Block] = BlockGenerators.ZXGates(4, 1)
+ val rows: List[BlockRow] = BlockRowMaker.makeRowsUpToSize(1, blocks, Some(1))
+ val stacks = BlockStackMaker.makeStacksOfSize(2, rows)
+ val graphs = stacks.map(_.graph)
+ val graphsWithStrings = graphs.filter(g => g.edata.values.exists(ed => ed.typ == "string"))
+ assert(graphsWithStrings.isEmpty)
+ }
+}
diff --git a/scala/src/test/scala/quanto/cosy/test/CoSyRunSpec.scala b/scala/src/test/scala/quanto/cosy/test/CoSyRunSpec.scala
new file mode 100644
index 00000000..e4b2cebb
--- /dev/null
+++ b/scala/src/test/scala/quanto/cosy/test/CoSyRunSpec.scala
@@ -0,0 +1,263 @@
+package quanto.cosy.test
+
+import org.scalatest.FlatSpec
+import quanto.cosy.CoSyRuns._
+import quanto.cosy._
+
+import scala.concurrent.duration.Duration
+import java.io.File
+
+import quanto.data.Theory.VertexDesc
+import quanto.data._
+import quanto.rewrite.Matcher
+import quanto.util.FileHelper
+import quanto.util.json.{Json, JsonObject}
+
+import scala.util.Random
+
+/**
+ * Created by hector on 24/05/17.
+ */
+class CoSyRunSpec extends FlatSpec {
+
+ behavior of "ZX"
+
+ it should "do a small run" in {
+ var theory = Theory.fromFile("red_green")
+ var CR = new CoSyRuns.CoSyZX(duration = Duration.Inf,
+ numBoundaries = List(0, 1, 2),
+ outputDir = None,
+ scalars = false,
+ numVertices = 2,
+ rulesDir = new File("./cosy/"), theory = theory,
+ numAngles = 4)
+
+ def interpret(g: Graph) = Interpreter.interpretZXGraph(g, g.verts.filter(g.isBoundary).toList, List())
+ // Don't test this here
+ // It isn't a standard feature
+ // And pollutes the filesystem
+ // Ask hmillerbakewell@gmail.com for more information
+
+ // CR.begin()
+ /*
+ val reduced = RuleSynthesis.minimiseRuleset(CR.reductionRules, theory, new Random(1))
+ //reduced.foreach(r => FileHelper.printJson(s"./cosy/${r.lhs.hashCode}-${r.rhs.hashCode}.qrule", Rule.toJson(r, theory)))
+ reduced.foreach(r =>
+ assert(interpret(r.lhs).isRoughly(interpret(r.rhs))))
+ */
+
+ }
+
+
+ behavior of "ZX with bool"
+
+ it should "do a small run" in {
+ val theory = Theory.fromJson(
+ """
+ |{
+ | "name": "ZXH",
+ | "core_name": "zxh",
+ | "vertex_types": {
+ | "Z": {
+ | "value": {
+ | "type": "angle_expr, bool",
+ | "latex_constants": true,
+ | "validate_with_core": false
+ | },
+ | "style": {
+ | "label": {
+ | "position": "center",
+ | "fg_color": [
+ | 0.0,
+ | 0.0,
+ | 0.0
+ | ]
+ | },
+ | "stroke_color": [
+ | 0.0,
+ | 0.0,
+ | 0.0
+ | ],
+ | "fill_color": [
+ | 0.0,
+ | 1.0,
+ | 0.0
+ | ],
+ | "shape": "circle"
+ | },
+ | "default_data": {
+ | "type": "Z",
+ | "value": "0, 0"
+ | }
+ | },
+ | "X": {
+ | "value": {
+ | "type": "angle_expr, bool",
+ | "latex_constants": true,
+ | "validate_with_core": false
+ | },
+ | "style": {
+ | "label": {
+ | "position": "center",
+ | "fg_color": [
+ | 0.0,
+ | 0.0,
+ | 0.0
+ | ]
+ | },
+ | "stroke_color": [
+ | 0.0,
+ | 0.0,
+ | 0.0
+ | ],
+ | "fill_color": [
+ | 1.0,
+ | 0.0,
+ | 0.0
+ | ],
+ | "shape": "circle"
+ | },
+ | "default_data": {
+ | "type": "X",
+ | "value": "0, 0"
+ | }
+ | },
+ | "hadamard": {
+ | "value": {
+ | "type": "empty",
+ | "latex_constants": true,
+ | "validate_with_core": false
+ | },
+ | "style": {
+ | "label": {
+ | "position": "center",
+ | "fg_color": [
+ | 0.0,
+ | 0.0,
+ | 0.0
+ | ]
+ | },
+ | "stroke_color": [
+ | 0.0,
+ | 0.0,
+ | 0.0
+ | ],
+ | "fill_color": [
+ | 1.0,
+ | 1.0,
+ | 0.0
+ | ],
+ | "shape": "rectangle"
+ | },
+ | "default_data": {
+ | "type": "hadamard",
+ | "value": ""
+ | }
+ | },
+ | "dummyBoundary": {
+ | "value": {
+ | "type": "empty",
+ | "latex_constants": true,
+ | "validate_with_core": false
+ | },
+ | "style": {
+ | "label": {
+ | "position": "center",
+ | "fg_color": [
+ | 0.0,
+ | 0.0,
+ | 0.0
+ | ]
+ | },
+ | "stroke_color": [
+ | 0.0,
+ | 0.0,
+ | 0.0
+ | ],
+ | "fill_color": [
+ | 0.0,
+ | 1.0,
+ | 1.0
+ | ],
+ | "shape": "rectangle"
+ | },
+ | "default_data": {
+ | "type": "dummyBoundary",
+ | "value": ""
+ | }
+ | }
+ | },
+ | "default_vertex_type": "Z",
+ | "default_edge_type": "plain",
+ | "edge_types": {
+ | "plain": {
+ | "value": {
+ | "type": "empty",
+ | "latex_constants": false,
+ | "validate_with_core": false
+ | },
+ | "style": {
+ | "stroke_color": [
+ | 0.0,
+ | 0.0,
+ | 0.0
+ | ],
+ | "stroke_width": 1,
+ | "label": {
+ | "position": "auto",
+ | "fg_color": [
+ | 0.0,
+ | 0.0,
+ | 0.0
+ | ]
+ | }
+ | },
+ | "default_data": {
+ | "type": "plain"
+ | }
+ | }
+ | }
+ |}
+ """.stripMargin)
+ var CR = new CoSyRuns.CoSyZXBool(duration = Duration.Inf,
+ numBoundaries = List(0, 1, 2),
+ outputDir = None,
+ scalars = false,
+ numVertices = 3,
+ rulesDir = new File("./cosy/"), theory = theory,
+ numAngles = 4)
+
+ // Don't test this here
+ // It isn't a standard feature
+ // And pollutes the filesystem
+ // Ask hmillerbakewell@gmail.com for more information
+
+ //CR.begin()
+ assert(1 == 1)
+ /*
+ val reduced = RuleSynthesis.minimiseRuleset(CR.reductionRules, theory, new Random(1))
+ //reduced.foreach(r => FileHelper.printJson(s"./cosy/${r.lhs.hashCode}-${r.rhs.hashCode}.qrule", Rule.toJson(r, theory)))
+ reduced.foreach(r =>
+ assert(interpret(r.lhs).isRoughly(interpret(r.rhs))))
+ */
+
+ }
+
+ behavior of "ZX Circuit"
+
+ it should "do a small run" in {
+ val ZXRails = Theory.fromFile("ZXRails")
+
+ // THIS WILL RUN UNTIL THE TIME RUNS OUT.
+ var CR = new CoSyRuns.CoSyCircuit(duration = Duration(4, "minutes"), numBoundaries = 3,
+ outputDir = Some(new File("./cosy_synth/")),
+ rulesDir = new File("./cosy_synth/"), theory = ZXRails)
+ //var rules = CR.begin()
+ /*
+ val reduced = RuleSynthesis.minimiseRuleset(CR.reductionRules, theory, new Random(1))
+ reduced.foreach(r => FileHelper.printJson(s"./cosy/${r.lhs.hashCode}-${r.rhs.hashCode}.qrule", Rule.toJson(r, theory)))
+ */
+ assert(1 == 1)
+ }
+
+}
diff --git a/scala/src/test/scala/quanto/cosy/test/ColbournReadEnumSpec.scala b/scala/src/test/scala/quanto/cosy/test/ColbournReadEnumSpec.scala
index 1d2ea1cf..7f78cf95 100644
--- a/scala/src/test/scala/quanto/cosy/test/ColbournReadEnumSpec.scala
+++ b/scala/src/test/scala/quanto/cosy/test/ColbournReadEnumSpec.scala
@@ -2,6 +2,8 @@ package quanto.cosy.test
import org.scalatest._
import quanto.cosy._
+import quanto.data.{Graph, NodeV, Theory}
+import quanto.util.json.JsonObject
class ColbournReadEnumSpec extends FlatSpec {
@@ -198,4 +200,28 @@ class ColbournReadEnumSpec extends FlatSpec {
assert(numBi(6) === 1 + 1 + 1 + 1 + 3 + 5 + 17)
assert(numBi(7) === 1 + 1 + 1 + 1 + 3 + 5 + 17 + 44)
}
+
+ behavior of "conversion to graph"
+
+ it should "convert small amats into ZX graphs" in {
+ val rg = Theory.fromFile("red_green")
+
+ val pi = math.Pi
+ val rdata = Vector(
+ NodeV(data = JsonObject("type" -> "X", "value" -> "0"), theory = rg),
+ NodeV(data = JsonObject("type" -> "X", "value" -> pi.toString), theory = rg),
+ NodeV(data = JsonObject("type" -> "X", "value" -> (0.5 * pi).toString), theory = rg),
+ NodeV(data = JsonObject("type" -> "X", "value" -> (-0.5 * pi).toString), theory = rg)
+ )
+ val gdata = Vector(
+ NodeV(data = JsonObject("type" -> "Z", "value" -> "0"), theory = rg),
+ NodeV(data = JsonObject("type" -> "Z", "value" -> pi.toString), theory = rg),
+ NodeV(data = JsonObject("type" -> "Z", "value" -> (0.5 * pi).toString), theory = rg),
+ NodeV(data = JsonObject("type" -> "Z", "value" -> (-0.5 * pi).toString), theory = rg)
+ )
+
+ var one = Complex.one
+ def quickGraph(amat: AdjMat) : Graph = Graph.fromAdjMat(amat, rdata, gdata)
+ ColbournReadEnum.enumerate(2,2,2,2).map(quickGraph)
+ }
}
diff --git a/scala/src/test/scala/quanto/cosy/test/EQCAnalysisSpec.scala b/scala/src/test/scala/quanto/cosy/test/EQCAnalysisSpec.scala
index e7e70076..69e6a623 100644
--- a/scala/src/test/scala/quanto/cosy/test/EQCAnalysisSpec.scala
+++ b/scala/src/test/scala/quanto/cosy/test/EQCAnalysisSpec.scala
@@ -68,7 +68,7 @@ class EQCAnalysisSpec extends FlatSpec {
assert(!adjacencyMatrix._2(bIndex)(vIndex))
- val ghostedErrors = GraphAnalysis.bypassSpecial(GraphAnalysis.detectErrors)(targetGraph, adjacencyMatrix)
+ val ghostedErrors = GraphAnalysis.bypassSpecial(GraphAnalysis.detectPiNodes)(targetGraph, adjacencyMatrix)
assert(ghostedErrors._2(bIndex)(eIndex))
assert(ghostedErrors._2(bIndex)(vIndex))
}
diff --git a/scala/src/test/scala/quanto/cosy/test/EquivalenceClassesSpec.scala b/scala/src/test/scala/quanto/cosy/test/EquivalenceClassesSpec.scala
index 7c3a2947..d4cca9a2 100644
--- a/scala/src/test/scala/quanto/cosy/test/EquivalenceClassesSpec.scala
+++ b/scala/src/test/scala/quanto/cosy/test/EquivalenceClassesSpec.scala
@@ -6,6 +6,7 @@ import java.nio.file.Paths
import org.scalatest._
import quanto.cosy._
import quanto.data._
+import quanto.util.FileHelper
import quanto.util.json.{JsonArray, JsonObject}
import scala.util.parsing.json.JSON
@@ -123,18 +124,6 @@ class EquivalenceClassesSpec extends FlatSpec {
behavior of "IO"
- it should "save run results to file" in {
- var results = EquivClassRunAdjMat(numAngles = 4,
- tolerance = EquivClassRunAdjMat.defaultTolerance,
- rulesList = emptyRuleList,
- theory = rg)
- var testFile = new File("test_run_output.qrun")
- quanto.util.FileHelper.printToFile(testFile, append = false)(
- p => p.println(results.toJSON.toString())
- )
- assert(testFile.delete())
- }
-
it should "output to and input from JSON" in {
var results = EquivClassRunAdjMat(numAngles = 4,
tolerance = EquivClassRunAdjMat.defaultTolerance,
@@ -161,26 +150,6 @@ class EquivalenceClassesSpec extends FlatSpec {
behavior of "batch runner"
- it should "create an output qrun file" in {
- EquivClassBatchRunner(4, 2, 2, "test.qrun")
- var testFile = new File(EquivClassBatchRunner.outputPath + "/" + "test.qrun")
- assert(testFile.exists())
- assert(testFile.delete())
- }
-
- it should "allow outputs to home directory" in {
- EquivClassBatchRunner.outputPath = Paths.get(System.getProperty("user.home"), "cosy_synth").toString
- println(EquivClassBatchRunner.outputPath)
- EquivClassBatchRunner.outputPath = "cosy_synth" // reset to avoid problems in later tests
- }
-
- it should "create an output qtensor file" in {
- TensorBatchRunner(1, 2, 2)
- var testFile = new File(TensorBatchRunner.outputPath + "/" + "tensors-1-2-2.qtensor")
- assert(testFile.exists())
- assert(testFile.delete())
- }
-
it should "be writing legible JSON" in {
var results = EquivClassRunAdjMat(numAngles = 4,
tolerance = EquivClassRunAdjMat.defaultTolerance,
@@ -200,7 +169,7 @@ class EquivalenceClassesSpec extends FlatSpec {
it should "put stacks into equivalence classes" in {
var allowedStacks = BlockStackMaker(maxRows = 2,
- BlockRowMaker(maxBlocks = 1, BlockRowMaker.ZX(8), maxInOut = Option(2)))
+ BlockRowMaker(maxBlocks = 1, BlockGenerators.ZXGates(8), maxInOut = Option(2)))
var eqc = new EquivClassRunBlockStack()
allowedStacks.foreach(s => eqc.add(s))
println(eqc.equivalenceClassesNormalised.foreach(
@@ -212,7 +181,7 @@ class EquivalenceClassesSpec extends FlatSpec {
it should "put stacks into equivalence classes" in {
var allowedStacks = BlockStackMaker(maxRows = 2,
- BlockRowMaker(maxBlocks = 2, BlockRowMaker.ZW, maxInOut = Option(2)))
+ BlockRowMaker(maxBlocks = 2, BlockGenerators.ZW, maxInOut = Option(2)))
var eqc = new EquivClassRunBlockStack()
allowedStacks.foreach(s => eqc.add(s))
println(eqc.equivalenceClassesNormalised.foreach(
diff --git a/scala/src/test/scala/quanto/cosy/test/GraphAnalysisSpec.scala b/scala/src/test/scala/quanto/cosy/test/GraphAnalysisSpec.scala
index c04fe5a9..f510b6dd 100644
--- a/scala/src/test/scala/quanto/cosy/test/GraphAnalysisSpec.scala
+++ b/scala/src/test/scala/quanto/cosy/test/GraphAnalysisSpec.scala
@@ -3,11 +3,15 @@ package quanto.cosy.test
import java.io.File
import org.scalatest.FlatSpec
+import quanto.cosy.BlockGenerators.QuickGraph
import quanto.cosy.RuleSynthesis._
+import quanto.cosy.GraphAnalysis._
import quanto.cosy._
import quanto.data._
import quanto.util.Rational
+import scala.util.matching.Regex
+
/**
* Test files for the Graph Analysis Object
* It is used to calculate various graph properties
@@ -25,13 +29,15 @@ class GraphAnalysisSpec extends FlatSpec {
behavior of "Graph Analysis"
+
+ private val errorGates = quanto.util.FileHelper.readFile[Graph](
+ new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"),
+ Graph.fromJson(_, rg)
+ )
+
it should "compute adjacency matrices" in {
- val targetGraph = quanto.util.FileHelper.readFile[Graph](
- new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"),
- Graph.fromJson(_, rg)
- )
- val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(targetGraph)
+ val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(errorGates)
// boundary b2, error v2, next gate vertex v3
val bIndex = adjacencyMatrix._1.indexOf(VName("b2"))
val eIndex = adjacencyMatrix._1.indexOf(VName("v2"))
@@ -40,18 +46,14 @@ class GraphAnalysisSpec extends FlatSpec {
assert(!adjacencyMatrix._2(bIndex)(vIndex))
- val ghostedErrors = GraphAnalysis.bypassSpecial(GraphAnalysis.detectErrors)(targetGraph, adjacencyMatrix)
+ val ghostedErrors = GraphAnalysis.bypassSpecial(GraphAnalysis.detectPiNodes)(errorGates, adjacencyMatrix)
assert(ghostedErrors._2(bIndex)(eIndex))
assert(ghostedErrors._2(bIndex)(vIndex))
}
it should "calculate distance from ends" in {
- val targetGraph = quanto.util.FileHelper.readFile[Graph](
- new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"),
- Graph.fromJson(_, rg)
- )
- val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(targetGraph)
+ val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(errorGates)
val errorName = VName("v2")
val leftBoundary = VName("b2")
@@ -73,11 +75,7 @@ class GraphAnalysisSpec extends FlatSpec {
it should "find neighbours" in {
- val targetGraph = quanto.util.FileHelper.readFile[Graph](
- new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"),
- Graph.fromJson(_, rg)
- )
- val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(targetGraph)
+ val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(errorGates)
val errorName = VName("v2")
@@ -87,12 +85,8 @@ class GraphAnalysisSpec extends FlatSpec {
it should "calculate distance of a given, ignored set from ends" in {
- val targetGraph = quanto.util.FileHelper.readFile[Graph](
- new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"),
- Graph.fromJson(_, rg)
- )
- val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(targetGraph)
- val ghostedErrors = GraphAnalysis.bypassSpecial(GraphAnalysis.detectErrors)(targetGraph, adjacencyMatrix)
+ val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(errorGates)
+ val ghostedErrors = GraphAnalysis.bypassSpecial(GraphAnalysis.detectPiNodes)(errorGates, adjacencyMatrix)
val errorName = VName("v2")
val leftBoundary = VName("b2")
@@ -119,9 +113,163 @@ class GraphAnalysisSpec extends FlatSpec {
// Now check with the simproc methods
- val eDistances = SimplificationProcedure.PullErrors.errorsDistance(rightBoundaries)(targetGraph, Set(errorName))
+ val eDistances = SimplificationProcedure.PullErrors.errorsDistance(rightBoundaries)(errorGates, Set(errorName))
assert(eDistances.get == 2.0)
}
+ behavior of "Graph comparison"
+
+ it should "compare all these graphs" in {
+ def compare(a: Graph, b: Graph) = GraphAnalysis.zxGraphCompare(a, b)
+
+ implicit def stackToGraph(s: BlockStack) : Graph = s.graph
+ implicit def rowToGraph(s: BlockRow) : Graph = s.graph
+ implicit def blockToGraph(s: Block) : Graph = s.graph
+
+ def zx(s: String) : Graph = BlockGenerators.ZXClifford.filter(b => b.name == s).head
+
+ val hadamard = zx(" H ")
+ val id = zx(" 1 ")
+
+ assert(compare(hadamard, id) > 0)
+ }
+
+ behavior of "Connectiviy analysis"
+
+ it should "leave empty graph disconnected" in {
+ var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1)
+ amat = amat.addVertex(Vector())
+ amat = amat.addVertex(Vector(false))
+ amat = amat.addVertex(Vector(false, false))
+ val cc = connectionClasses(amat)
+ assert(cc == (0 until 3).toVector)
+ }
+
+
+ it should "connect all in line graph" in {
+ var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1)
+ amat = amat.addVertex(Vector())
+ amat = amat.addVertex(Vector(true))
+ amat = amat.addVertex(Vector(false, true))
+ val cc = connectionClasses(amat)
+ assert(cc == Vector(0,0,0))
+ }
+
+
+ it should "find two classes" in {
+ var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1)
+ amat = amat.addVertex(Vector())
+ amat = amat.addVertex(Vector(false))
+ amat = amat.addVertex(Vector(true, false))
+ val cc = connectionClasses(amat)
+ assert(cc == Vector(0,1,0))
+ }
+
+ it should "detect no scalars in all boundaries" in {
+ var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1)
+ amat = amat.addVertex(Vector())
+ amat = amat.addVertex(Vector(false))
+ amat = amat.addVertex(Vector(true, false))
+ assert(!containsScalars(amat))
+ }
+
+ it should "detect no scalars with some non-boundaries" in {
+ var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1)
+ amat = amat.addVertex(Vector())
+ amat = amat.addVertex(Vector(false))
+ amat = amat.addVertex(Vector(true, false))
+ amat = amat.nextType.get
+ amat = amat.addVertex(Vector(false, true, false))
+ assert(!containsScalars(amat))
+ }
+
+
+ it should "detect scalars with some non-boundaries" in {
+ var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1)
+ amat = amat.addVertex(Vector())
+ amat = amat.addVertex(Vector(false))
+ amat = amat.addVertex(Vector(true, false))
+ amat = amat.nextType.get
+ amat = amat.addVertex(Vector(false, false, false))
+ amat = amat.addVertex(Vector(false, false, false, true))
+ assert(containsScalars(amat))
+ }
+
+ behavior of "circuit analysis"
+
+ // Only works for circuits generated inside CoSy
+ // Note that this will not magically give you reductions - it is not graph-invariant by circuit-invariant
+ // Turning a graph upside down gives very a very different measure
+
+ it should "distill circuit placement from name via regex" in {
+ val example = "r-2-bl-1-h-1"
+ val output = CircuitPlacementParser.p(example)
+ assert (output == (2,1,"h"))
+ }
+
+ it should "bias circuits to the left" in {
+ val blocks: List[Block] = BlockGenerators.ZXGates(1)
+ val rows: List[BlockRow] = BlockRowMaker.makeRowsUpToSize(2, blocks, Some(2))
+ val stacks = BlockStackMaker.makeStacksOfSize(1, rows)
+ val e1 = stacks.find(_.toString == "( 1 x 0Z1)").get
+ val e2 = stacks.find(_.toString == "(0Z1 x 1 )").get
+ assert(zxCircuitCompare(e1.graph, e2.graph) > 0)
+ }
+
+ it should "bias circuits down" in {
+ val blocks: List[Block] = BlockGenerators.ZXGates(1)
+ val rows: List[BlockRow] = BlockRowMaker.makeRowsUpToSize(1, blocks, Some(1))
+ val stacks = BlockStackMaker.makeStacksOfSize(2, rows)
+ val e1 = stacks.find(_.toString == "( 1 ) o (0Z1)").get
+ val e2 = stacks.find(_.toString == "(0Z1) o ( 1 )").get
+ assert(zxCircuitCompare(e1.graph, e2.graph) < 0)
+ }
+
+ behavior of "Checking isomorphisms"
+
+ // Boundaries have names like i-20 or o-1
+ val boundaryRegex : Option[Regex] = Some(raw"""(i|o)-(\d+)""".r)
+ def ZXiso(left: Graph, right: Graph): Boolean = checkIsomorphic(rg, boundaryRegex)(left, right)
+
+ it should "not match empty onto non-empty" in {
+ val g = QuickGraph(rg)
+ assert(!ZXiso(g, g.node("Z", "0")))
+ }
+
+ it should "match discrete graphs onto only themselves" in {
+ var g = QuickGraph(rg)
+ val graphs = (for (_ <- 1 to 5) yield {
+ g = g.node("Z", "0")
+ g
+ }).toList.zipWithIndex
+ for(i <- graphs; j <- graphs) {
+ // Check we have made alrge enough graphs; index is out by 1
+ assert(i._1.verts.size == (i._2 + 1))
+ if(i._2 == j._2){
+ assert(ZXiso(i._1, j._1))
+ } else {
+ assert(!ZXiso(i._1, j._1))
+ }
+ }
+ }
+
+ it should "not match different data" in {
+ val g = QuickGraph(rg)
+ assert(!ZXiso(g.node("Z","0"), g.node("Z", "1")))
+ assert(!ZXiso(g.node("Z","0"), g.node("X", "0")))
+ }
+
+ it should "match outputs with the same name" in {
+ val g1 = QuickGraph(rg).node("Z", angle = "0", nodeName = "v").addInput().join("v", "i-0")
+ val g2 = QuickGraph(rg).node("Z", angle = "0", nodeName = "w").addInput().join("w", "i-0")
+ assert(ZXiso(g1,g2))
+ }
+
+ it should "not match boundaries with different names" in {
+ val g1 = QuickGraph(rg).node("Z", angle = "0", nodeName = "v").addInput().join("v", "i-0")
+ val g2 = QuickGraph(rg).node("Z", angle = "0", nodeName = "v").addOutput().join("v", "o-0")
+ assert(!ZXiso(g1,g2))
+ }
+
}
diff --git a/scala/src/test/scala/quanto/cosy/test/InterpreterSpec.scala b/scala/src/test/scala/quanto/cosy/test/InterpreterSpec.scala
index 3dd95070..ddea9606 100644
--- a/scala/src/test/scala/quanto/cosy/test/InterpreterSpec.scala
+++ b/scala/src/test/scala/quanto/cosy/test/InterpreterSpec.scala
@@ -1,32 +1,186 @@
package quanto.cosy.test
import org.scalatest.FlatSpec
-import quanto.cosy.Interpreter.AngleMap
-import quanto.cosy.{AdjMat, Complex, Interpreter, Tensor}
-import quanto.data.{NodeV, Theory}
+import quanto.cosy.BlockGenerators.QuickGraph
+import quanto.cosy.Interpreter._
+import quanto.cosy._
+import quanto.data.Theory.ValueType
+import quanto.data._
import quanto.util.json.JsonObject
/**
* Created by hector on 24/05/17.
*/
class InterpreterSpec extends FlatSpec {
+
+ behavior of "Connecting graphs"
+
+
+ implicit def vname(str: String): VName = VName(str)
+
+ implicit def vname2(strs: (String, String)): (VName, VName) = (vname(strs._1), vname(strs._2))
+
+ implicit def ename(str: String): EName = EName(str)
+
+
+ val cap : Tensor = Tensor(Array(Array[Complex](1, 0, 0, 1)))
+
+ it should "make two id wires" in {
+ var g = new Graph().
+ addVertex(vname("v0"), WireV()).
+ addVertex(vname("v1"), WireV()).
+ addVertex(vname("v2"), WireV()).
+ addVertex(vname("v3"), WireV()).
+ addEdge("e0", UndirEdge(), "v0" -> "v1").
+ addEdge("e1", UndirEdge(), "v2" -> "v3")
+ val tensor = stringGraph(g, cap, List("v0", "v2"), List("v1", "v3"))
+ assert(tensor == Tensor.idWires(2))
+ }
+
+ it should "make two caps" in {
+ var g = new Graph().
+ addVertex(vname("v0"), WireV()).
+ addVertex(vname("v1"), WireV()).
+ addVertex(vname("v2"), WireV()).
+ addVertex(vname("v3"), WireV()).
+ addEdge("e0", UndirEdge(), "v0" -> "v2").
+ addEdge("e1", UndirEdge(), "v1" -> "v3")
+ val tensor = stringGraph(g, cap, List("v0", "v2", "v1","v3"), List())
+ assert(tensor == (cap x cap))
+ }
+
+
+ it should "make cap and wire" in {
+ var g = new Graph().
+ addVertex(vname("v0"), WireV()).
+ addVertex(vname("v1"), WireV()).
+ addVertex(vname("v2"), WireV()).
+ addVertex(vname("v3"), WireV()).
+ addEdge("e0", UndirEdge(), "v0" -> "v3").
+ addEdge("e1", UndirEdge(), "v1" -> "v2")
+ val tensor = stringGraph(g, cap, List("v1", "v2", "v0"), List("v3"))
+ assert(tensor == ( cap x Tensor.idWires(1)))
+ }
+
+
+ it should "make wire and cap" in {
+ var g = new Graph().
+ addVertex(vname("v0"), WireV()).
+ addVertex(vname("v1"), WireV()).
+ addVertex(vname("v2"), WireV()).
+ addVertex(vname("v3"), WireV()).
+ addEdge("e0", UndirEdge(), "v0" -> "v3").
+ addEdge("e1", UndirEdge(), "v1" -> "v2")
+ val tensor = stringGraph(g, cap, List("v0", "v1", "v2"), List("v3"))
+ assert(tensor == (Tensor.idWires(1) x cap))
+ }
+
+
+ it should "make cup and wire" in {
+ var g = new Graph().
+ addVertex(vname("v0"), WireV()).
+ addVertex(vname("v1"), WireV()).
+ addVertex(vname("v2"), WireV()).
+ addVertex(vname("v3"), WireV()).
+ addEdge("e0", UndirEdge(), "v0" -> "v3").
+ addEdge("e1", UndirEdge(), "v1" -> "v2")
+ val tensor = stringGraph(g, cap, List("v0"), List("v1", "v2", "v3"))
+ assert(tensor == ( cap.transpose x Tensor.idWires(1)))
+ }
+
+
+
+ it should "make wire and cup" in {
+ var g = new Graph().
+ addVertex(vname("v0"), WireV()).
+ addVertex(vname("v1"), WireV()).
+ addVertex(vname("v2"), WireV()).
+ addVertex(vname("v3"), WireV()).
+ addEdge("e0", UndirEdge(), "v0" -> "v3").
+ addEdge("e1", UndirEdge(), "v1" -> "v2")
+ val tensor = stringGraph(g, cap, List("v0"), List("v3", "v1", "v2"))
+ assert(tensor == (Tensor.idWires(1) x cap.transpose))
+ }
+
+
+
+ it should "make crossing" in {
+ var g = new Graph().
+ addVertex(vname("v0"), WireV()).
+ addVertex(vname("v1"), WireV()).
+ addVertex(vname("v2"), WireV()).
+ addVertex(vname("v3"), WireV()).
+ addEdge("e0", UndirEdge(), "v0" -> "v3").
+ addEdge("e1", UndirEdge(), "v1" -> "v2")
+ val tensor = stringGraph(g, cap, List("v0", "v1"), List("v2", "v3"))
+ assert(tensor == Tensor.swap(List(1,0)))
+ }
+
+
+
+ it should "make cap and cup" in {
+ var g = new Graph().
+ addVertex(vname("v0"), WireV()).
+ addVertex(vname("v1"), WireV()).
+ addVertex(vname("v2"), WireV()).
+ addVertex(vname("v3"), WireV()).
+ addEdge("e0", UndirEdge(), "v0" -> "v3").
+ addEdge("e1", UndirEdge(), "v1" -> "v2")
+ val tensor = stringGraph(g, cap, List("v0", "v3"), List("v1", "v2"))
+ assert(tensor == Tensor(Array(Array(1,0,0,1),Array(0,0,0,0), Array(0,0,0,0), Array(1,0,0,1))))
+ }
+
+
+
+ it should "make (id x s) o (s x id)" in {
+ var g = QuickGraph.apply(BlockGenerators.ZXTheory).addInput(3).addOutput(3)
+ .join("i-0", "o-2")
+ .join("i-1", "o-0")
+ .join("i-2", "o-1")
+ val tensor = stringGraph(g, cap, List("i-0","i-1","i-2"), List("o-0","o-1","o-2"))
+ assert(tensor == Tensor.swap(List(2, 0, 1)))
+ }
+
+
+ it should "make (s x id) o (id x s)" in {
+ var g = QuickGraph.apply(BlockGenerators.ZXTheory).addInput(3).addOutput(3)
+ .join("i-0", "o-1")
+ .join("i-1", "o-2")
+ .join("i-2", "o-0")
+ val tensor = stringGraph(g, cap, List("i-0","i-1","i-2"), List("o-0","o-1","o-2"))
+ assert(tensor == Tensor.swap(List(1,2,0)))
+ }
+
behavior of "ZX"
val rg = Theory.fromFile("red_green")
val pi = math.Pi
val rdata = Vector(
NodeV(data = JsonObject("type" -> "X", "value" -> "0"), theory = rg),
- NodeV(data = JsonObject("type" -> "X", "value" -> pi.toString), theory = rg),
- NodeV(data = JsonObject("type" -> "X", "value" -> (0.5 * pi).toString), theory = rg),
- NodeV(data = JsonObject("type" -> "X", "value" -> (-0.5 * pi).toString), theory = rg)
+ NodeV(data = JsonObject("type" -> "X", "value" -> "1"), theory = rg),
+ NodeV(data = JsonObject("type" -> "X", "value" -> "1/2"), theory = rg),
+ NodeV(data = JsonObject("type" -> "X", "value" -> "-1/2"), theory = rg)
)
val gdata = Vector(
NodeV(data = JsonObject("type" -> "Z", "value" -> "0"), theory = rg),
- NodeV(data = JsonObject("type" -> "Z", "value" -> pi.toString), theory = rg),
- NodeV(data = JsonObject("type" -> "Z", "value" -> (0.5 * pi).toString), theory = rg),
- NodeV(data = JsonObject("type" -> "Z", "value" -> (-0.5 * pi).toString), theory = rg)
+ NodeV(data = JsonObject("type" -> "Z", "value" -> "1"), theory = rg),
+ NodeV(data = JsonObject("type" -> "Z", "value" -> "1/2"), theory = rg),
+ NodeV(data = JsonObject("type" -> "Z", "value" -> "-1/2"), theory = rg)
)
+ //Change the last number for larger tests
+ //Don't include boundaries as the methods can give permutations of each other's answers
+ val smallAdjMats: Stream[AdjMat] = ColbournReadEnum.enumerate(2, 2, 2, 2)
+
+ implicit def quickGraph(amat: AdjMat): Graph = Graph.fromAdjMat(amat, rdata, gdata)
+
+ implicit def stringToPhase(s: String): PhaseExpression = {
+ PhaseExpression.parse(s, ValueType.AngleExpr)
+ }
+
var one = Complex.one
+ var zero = Complex.zero
+
+ def amatToZXTensor(adjMat: AdjMat) = Interpreter.interpretZXAdjMat(adjMat, rdata, gdata)
it should "make hadamards" in {
// Via "new"
@@ -40,39 +194,39 @@ class InterpreterSpec extends FlatSpec {
}
it should "make Green Spiders" in {
- def gs(angle: Double, in: Int, out: Int) = Interpreter.interpretZXSpider(true, angle, in, out)
+ def gs(angle: String, in: Int, out: Int) = zxSpider(true, angle, in, out)
- var g12a0 = Interpreter.interpretZXSpider(true, 0, 1, 2)
- var g21aPi = Interpreter.interpretZXSpider(true, math.Pi, 2, 1)
- var g11aPi2 = Interpreter.interpretZXSpider(true, math.Pi / 2.0, 1, 1)
- var t = gs(math.Pi / 4, 1, 1)
- var g2 = gs(math.Pi / 2, 1, 1)
+ var g12a0 = zxSpider(true, "0", 1, 2)
+ var g21aPi = zxSpider(true, "pi", 2, 1)
+ var g11aPi2 = zxSpider(true, "pi/2", 1, 1)
+ var t = gs("pi/4", 1, 1)
+ var g2 = gs("pi/2", 1, 1)
assert((t o t).isRoughly(g2))
- assert((g2 o g2).isRoughly(gs(math.Pi, 1, 1)))
+ assert((g2 o g2).isRoughly(gs("pi", 1, 1)))
}
it should "make red spiders" in {
- def rs(angle: Double, in: Int, out: Int) = Interpreter.interpretZXSpider(false, angle, in, out)
+ def rs(angle: String, in: Int, out: Int) = zxSpider(false, angle, in, out)
- assert(rs(math.Pi, 1, 1).isRoughly(Tensor(Array(Array(0, 1), Array(1, 0)))))
- var rp2 = rs(math.Pi / 2, 1, 1)
- assert((rp2 o rp2).isRoughly(rs(math.Pi, 1, 1)))
+ assert(rs("pi", 1, 1).isRoughlyUpToScalar(Tensor(Array(Array(0, 1), Array(1, 0)))))
+ var rp2 = rs("pi/2", 1, 1)
+ assert((rp2 o rp2).isRoughly(rs("pi", 1, 1)))
assert((rp2 o rp2 o rp2 o rp2).isRoughly(Tensor.id(2)))
- assert(Interpreter.cached.contains("ZX:false:3.141592653589793:1:1"))
+ assert(Interpreter.cached.contains("ZX:red:1:1:1"))
}
it should "respect spider law" in {
- var g1 = Interpreter.interpretZXSpider(true, math.Pi / 8, 1, 2)
- var g2 = Interpreter.interpretZXSpider(true, math.Pi / 8, 1, 1)
- var g3 = Interpreter.interpretZXSpider(true, math.Pi / 4, 1, 2)
+ var g1 = zxSpider(true, "pi/8", 1, 2)
+ var g2 = zxSpider(true, "pi/8", 1, 1)
+ var g3 = zxSpider(true, "pi/4", 1, 2)
assert(((Tensor.id(2) x g2) o g1).isRoughly(g3))
}
it should "apply Hadamards" in {
var h1 = Tensor.hadamard
var h2 = h1 x h1
- var g = Interpreter.interpretZXSpider(true, math.Pi / 4, 1, 2)
- var r = Interpreter.interpretZXSpider(false, math.Pi / 4, 1, 2)
+ var g = zxSpider(true, "pi/4", 1, 2)
+ var r = zxSpider(false, "pi/4", 1, 2)
assert((h2 o r o h1).isRoughly(g))
}
@@ -83,7 +237,9 @@ class InterpreterSpec extends FlatSpec {
amat = amat.nextType.get
amat = amat.addVertex(Vector(true, true))
println(amat)
- val i1 = Interpreter.interpretZXAdjMat(amat, redAM = rdata, greenAM = gdata)
+ val i1 = amatToZXTensor(amat)
+ val i2 = amatToGraphToZXTensor(amat)
+ assert(i1.isRoughly(i2))
assert(i1.isRoughly(Tensor(Array(Array(1, 0, 0, 1)))))
}
@@ -97,11 +253,23 @@ class InterpreterSpec extends FlatSpec {
amat = amat.nextType.get
// Red pi
amat = amat.addVertex(Vector(true, true))
+ amat = amat.nextType.get
+ // Green 0
+ amat = amat.nextType.get
+ // Green pi
println(amat)
- val i1 = Interpreter.interpretZXAdjMat(amat, redAM = rdata, greenAM = gdata)
- assert(i1.isRoughly(Interpreter.interpretZXSpider(green = false, math.Pi, 2, 0)))
+ val i1 = amatToZXTensor(amat)
+ val i2 = amatToGraphToZXTensor(amat)
+ assert(i1.isRoughly(i2))
+ assert(i1.isRoughly(zxSpider(false, "pi", 2, 0)))
}
- var zero = Complex.zero
+
+ def amatToGraphToZXTensor(adjMat: AdjMat) = {
+ val asGraph = quickGraph(adjMat)
+ Interpreter.interpretZXGraph(asGraph,
+ asGraph.verts.filter(asGraph.isTerminalWire).toList.sortBy(vn => vn.s), List())
+ }
+
it should "process red spider law" in {
// Simple red and green identities
var amat = new AdjMat(numRedTypes = 4, numGreenTypes = 4)
@@ -117,9 +285,13 @@ class InterpreterSpec extends FlatSpec {
amat = amat.addVertex(Vector(false, true, true))
amat = amat.nextType.get
//red -pi/2
- val i1 = Interpreter.interpretZXAdjMat(amat, redAM = rdata, greenAM = gdata)
- assert(i1.isRoughly(Interpreter.interpretZXSpider(false, -pi / 2, 2, 0)))
+ amat = amat.nextType.get
+ val i1 = amatToZXTensor(amat)
+ val i2 = amatToGraphToZXTensor(amat)
+ assert(i1.isRoughly(i2))
+ assert(i1.isRoughly(zxSpider(false, "-pi / 2", 2, 0)))
}
+
it should "satisfy the Euler identity" in {
// Euler identity
var amat = new AdjMat(numRedTypes = 4, numGreenTypes = 4)
@@ -137,55 +309,55 @@ class InterpreterSpec extends FlatSpec {
amat = amat.nextType.get
amat = amat.nextType.get
amat = amat.addVertex(Vector(false, false, true, true))
+ amat = amat.nextType.get
println(amat)
- val i2 = Interpreter.interpretZXAdjMat(amat, redAM = rdata, greenAM = gdata)
- val i3 = Interpreter.interpretZXSpider(true, 0, 2, 0) o (Tensor.id(2) x Tensor.hadamard)
- println(i3.scaled(i2.contents(0)(0) / i3.contents(0)(0)))
- assert(i2.isRoughly(i3.scaled(i2.contents(0)(0) / i3.contents(0)(0))))
- assert(i2.isRoughlyUpToScalar(Tensor(Array(Array(1, 1, 1, -1)))))
+ val i1 = Interpreter.interpretZXAdjMat(amat, redAM = rdata, greenAM = gdata)
+ val i2 = amatToGraphToZXTensor(amat)
+ assert(i2.isRoughly(i2))
+ val i3 = zxSpider(true, "0", 2, 0) o (Tensor.id(2) x Tensor.hadamard)
+ println(i3.scaled(i1.contents(0)(0) / i3.contents(0)(0)))
+ assert(i1.isRoughly(i3.scaled(i1.contents(0)(0) / i3.contents(0)(0))))
+ assert(i1.isRoughlyUpToScalar(Tensor(Array(Array(1, 1, 1, -1)))))
}
behavior of "ZW"
it should "Evaluate the four-output GHZ spider" in {
- var t = Interpreter.interpretZWSpider(black = true, 4)
- assert(t == Tensor(Array(Array(1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0))).transpose)
+ val t1 = Interpreter.interpretZWSpiderNoInputs(black = true, 4)
+ val g = QuickGraph(Theory.fromFile("ZW")).addOutput(4).node("w", nodeName = "w")
+ .join("w", "o-0")
+ .join("w", "o-1")
+ .join("w", "o-2")
+ .join("w", "o-3")
+
+ val t2 = Interpreter.interpretSpiderGraph(zwSpiderInterpreter)(g, List(), List("o-0", "o-1", "o-2", "o-3"))
+ val t3 = Tensor(Array(Array(1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0))).transpose
+ assert(t1 == t3)
+ assert(t2 == t3)
}
- /* Too intensive!
- it should "agree on rule 5d" in {
- var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1)
- amat = amat.addVertex(Vector())
- amat = amat.addVertex(Vector(false))
- amat = amat.nextType.get
- // Blacks (reds)
- amat = amat.addVertex(Vector(true, false))
- amat = amat.addVertex(Vector(false, false, true))
- amat = amat.addVertex(Vector(false, false, false, true))
- amat = amat.addVertex(Vector(false, true, false, false, true))
- amat = amat.nextType.get
- // Whites (greens)
- amat = amat.addVertex(Vector(false, false, false, true, true, false))
+ it should "make w spiders" in {
+ val g = QuickGraph(Theory.fromFile("ZW")).addInput().addOutput().node("w", nodeName = "w")
+ .join("i-0", "w").join("w", "o-0")
- var lhs = amat.copy()
- amat = new AdjMat(numBoundaries = 2, numRedTypes = 1, numGreenTypes = 1)
- amat = amat.addVertex(Vector())
- amat = amat.addVertex(Vector(false))
- amat = amat.nextType.get
- // Blacks (reds)
- amat = amat.addVertex(Vector(true, false))
- amat = amat.addVertex(Vector(false, false, true))
- amat = amat.addVertex(Vector(false, false, false, false))
- amat = amat.addVertex(Vector(false, true, false, false, true))
+ val t1 = Interpreter.interpretSpiderGraph(zwSpiderInterpreter)(g, List("i-0"), List("o-0"))
+ assert(t1.isRoughly(Tensor(Array(Array(0, 1), Array(1, 0)))))
+
+
+ val t2 = Interpreter.interpretSpiderGraph(zwSpiderInterpreter)(g, List("i-0", "o-0"), List())
+ assert(t2.isRoughly(Tensor(Array(Array(0, 1, 1, 0)))))
- var t1 = Interpreter.interpretZWAdjMat(lhs)
- var t2 = Interpreter.interpretZWAdjMat(amat)
- println(t1)
+
+ val g3 = QuickGraph(Theory.fromFile("ZW")).addInput().addOutput(2).node("w", nodeName = "w")
+ .join("i-0", "w").join("w", "o-0").join("w", "o-1")
+
+
+ val t3 = Interpreter.interpretSpiderGraph(zwSpiderInterpreter)(g3, List("i-0"), List("o-0", "o-1"))
+ assert(t3.isRoughly(Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 0))).transpose))
}
- */
- it should "agree on rule 3a" in {
+ it should "agree on rule nat_c^n" in {
var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1)
amat = amat.addVertex(Vector())
@@ -214,23 +386,155 @@ class InterpreterSpec extends FlatSpec {
amat = amat.addVertex(Vector(false, false, false, true, true))
amat = amat.addVertex(Vector(true, false, false, false, false, true))
- var t1 = Interpreter.interpretZWAdjMat(rhs)
- var t2 = Interpreter.interpretZWAdjMat(amat)
- assert(t1 == t2)
+ var t1 = Interpreter.interpretZWAdjMat(rhs, List("v0"), List("v1", "v2"))
+ var t2 = Interpreter.interpretZWAdjMat(amat, List("v0"), List("v1", "v2"))
+ assert(t1.isRoughly(t2))
}
- it should "agree on rule 5c" in {
+ it should "agree on rule inv" in {
+
+ val g = QuickGraph(Theory.fromFile("ZW")).addInput().addOutput().node("w",nodeName = "w1").node("w",nodeName = "w2")
+ .join("i-0", "w1").join("w1","w2").join("w2", "o-0")
+
+
+ var t1 = Interpreter.interpretSpiderGraph(zwSpiderInterpreter)(g, List("i-0"),List("o-0"))
+ assert(t1.isRoughly(Tensor.idWires(1)))
+ }
+
+ behavior of "block stacks and graph interpreters"
+
+ def zxSpider(isGreen: Boolean, angle: String, inputs: Int, outputs: Int): Tensor =
+ Interpreter.interpretZXSpider(ZXAngleData(isGreen, angle), inputs, outputs)
+
+ it should "interpret Z pi/4" in {
+
+ val cnotGraph = QuickGraph(BlockGenerators.ZXTheory).addInput(1).addOutput(1).node("Z", nodeName = "z", angle = "1/4")
+ .join("i-0", "z")
+ .join("o-0", "z")
+
+ val gt = Interpreter.interpretZXGraph(cnotGraph, List("i-0"), List("o-0"))
+ assert(gt.isRoughly(zxSpider(true, "1/4", 1, 1)))
+ }
+
+
+ it should "interpret X pi/4" in {
+
+ val cnotGraph = QuickGraph(BlockGenerators.ZXTheory).addInput(1).addOutput(1).node("X", nodeName = "x", angle = "1/4")
+ .join("i-0", "x")
+ .join("o-0", "x")
+
+ val gt = Interpreter.interpretZXGraph(cnotGraph, List("i-0"), List("o-0"))
+ assert(gt.isRoughly(zxSpider(false, "1/4", 1, 1)))
+ }
- var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1)
- amat = amat.nextType.get
- // Blacks (reds)
- amat = amat.addVertex(Vector())
- amat = amat.addVertex(Vector(true))
- amat = amat.addVertex(Vector(false, true))
- amat = amat.addVertex(Vector(false, false, true))
- amat = amat.nextType.get
- var t1 = Interpreter.interpretZWAdjMat(amat)
- assert(t1 == Tensor.id(1))
+ it should "interpret X pi" in {
+
+ val cnotGraph = QuickGraph(BlockGenerators.ZXTheory).addInput(1).addOutput(1).node("X", nodeName = "x", angle = "1")
+ .join("i-0", "x")
+ .join("o-0", "x")
+
+ val gt = Interpreter.interpretZXGraph(cnotGraph, List("i-0"), List("o-0"))
+ assert(gt.isRoughly(zxSpider(false, "1", 1, 1)))
+ }
+
+
+ it should "interpret CNOT" in {
+ val cnotGraph = QuickGraph(BlockGenerators.ZXTheory).addInput(2).addOutput(2).node("Z", nodeName = "z").node("X", xCoord = 1, nodeName = "x")
+ .join("i-0", "z").join("i-1", "x")
+ .join("o-0", "z").join("o-1", "x")
+ .join("z", "x")
+
+
+ assert(BlockGenerators.ZXClifford.filter(_.name == "CNT").head.tensor.isRoughlyUpToScalar(
+ Tensor(Array(Array(1, 0, 0, 0), Array(0, 1, 0, 0), Array(0, 0, 0, 1), Array(0, 0, 1, 0)))))
+
+ assert(Interpreter.interpretZXGraph(cnotGraph, List("i-0", "i-1"), List("o-0", "o-1")).isRoughlyUpToScalar(
+ Tensor(Array(Array(1, 0, 0, 0), Array(0, 1, 0, 0), Array(0, 0, 0, 1), Array(0, 0, 1, 0)))))
+
}
+
+ it should "interpret CNOT x id" in {
+ val cnotGraph = QuickGraph(BlockGenerators.ZXTheory).addInput(3).addOutput(3).node("Z", nodeName = "z").node("X", xCoord = 1, nodeName = "x")
+ .join("i-0", "z").join("i-1", "x")
+ .join("o-0", "z").join("o-1", "x")
+ .join("z", "x")
+ .join("i-2", "o-2")
+
+ val height = 1
+ val width = 3
+ val rows = BlockRowMaker.makeRowsUpToSize(width, BlockGenerators.ZXCNOT, Some(width))
+ val stacks = BlockStackMaker.makeStacksOfSize(height, rows)
+ var gate = stacks.filter(_.toString == "(CNT x 1 )")
+
+ val breakdown = Tensor.swap(List(2,1,0)) o
+ (Tensor.idWires(2) x Tensor(Array(Array(1,0,0,0), Array(0,0,0,1)))) o
+ Tensor.swap(List(2,0,1,3)) o
+ (Tensor.idWires(2) x Tensor(Array(Array(1,0,0,1), Array(0,1,1,0))).transpose) o
+ Tensor.swap(List(0,2,1))
+
+
+ val breakdown2 = Tensor(Array(Array(1,0,0,0), Array(0,1,0,0), Array(0,0,0,1), Array(0,0,1,0))) x Tensor.idWires(1)
+
+ assert(breakdown.isRoughlyUpToScalar(breakdown2))
+
+
+ assert(Interpreter.interpretZXGraph(cnotGraph, List("i-0", "i-1","i-2"), List("o-0", "o-1","o-2")).isRoughlyUpToScalar(
+ breakdown
+ ))
+
+
+ assert(Interpreter.interpretZXGraph(cnotGraph, List("i-0", "i-1","i-2"), List("o-0", "o-1","o-2")).isRoughlyUpToScalar(
+ gate.head.tensor
+ ))
+
+ }
+
+
+ it should "interpret green strings" in {
+ val cnotGraph = QuickGraph(BlockGenerators.ZXTheory).addInput(3).addOutput(3)
+ .node("Z", nodeName = "z2")
+ .node("Z", nodeName = "z1")
+ .node("Z", nodeName = "z0")
+ .join("i-0", "z0").join("z0", "o-0")
+ .join("i-1", "z1").join("z1", "o-1")
+ .join("i-2", "z2").join("z2", "o-2")
+
+ assert(Interpreter.interpretZXGraph(cnotGraph, List("i-0", "i-1", "i-2"), List("o-0", "o-1", "o-2")).isRoughlyUpToScalar(
+ Tensor.idWires(3)))
+
+ }
+
+
+ it should "agree on small adjmats" in {
+ val errors: List[AdjMat] = smallAdjMats.filterNot(adj => {
+ val am = amatToZXTensor(adj)
+ val gr = amatToGraphToZXTensor(adj)
+ am.isRoughly(gr)
+ }).toList
+ assert(errors.isEmpty)
+ }
+
+ it should "agree between block stacks and spiders" in {
+ // Will only work up to width 10 unless you change how boundaries are sorted
+ val height = 1
+ val width = 3
+ val rows = BlockRowMaker.makeRowsUpToSize(width, BlockGenerators.ZXClifford.filterNot(b => b.toString == " H "), Some(width))
+ val stacks = BlockStackMaker.makeStacksOfSize(height, rows)
+ stacks.foreach(bs => {
+ val asGraph = bs.graph
+ val tensorFromGraph = Interpreter.interpretZXGraph(asGraph,
+ asGraph.verts.filter(_.s.matches(raw"r-0-i-\d+")).toList.sortBy(vn => vn.s),
+ asGraph.verts.filter(_.s.matches(raw"r-"+(height-1)+raw"-o-\d+")).toList.sortBy(vn => vn.s))
+ assert(bs.tensor.isRoughlyUpToScalar(tensorFromGraph))
+ val asGraph2 = bs.graph
+ val tensorFromGraph2 = Interpreter.interpretZXGraph(asGraph,
+ asGraph.verts.filter(_.s.matches(raw"r-0-i-\d+")).toList.sortBy(vn => vn.s),
+ asGraph.verts.filter(_.s.matches(raw"r-"+(height-1)+raw"-o-\d+")).toList.sortBy(vn => vn.s))
+ assert(bs.tensor.isRoughlyUpToScalar(tensorFromGraph))
+ }
+ )
+ }
+
+
}
diff --git a/scala/src/test/scala/quanto/cosy/test/RuleSynthesisSpec.scala b/scala/src/test/scala/quanto/cosy/test/RuleSynthesisSpec.scala
index eae9e6ec..072166f6 100644
--- a/scala/src/test/scala/quanto/cosy/test/RuleSynthesisSpec.scala
+++ b/scala/src/test/scala/quanto/cosy/test/RuleSynthesisSpec.scala
@@ -6,10 +6,11 @@ import quanto.cosy._
import org.scalatest.FlatSpec
import quanto.data._
import quanto.rewrite.{Matcher, Rewriter}
-import quanto.data.Derivation.DerivationWithHead
import quanto.cosy.RuleSynthesis._
import quanto.cosy.AutoReduce._
+import quanto.cosy.BlockGenerators.QuickGraph
import quanto.data
+import quanto.util.json.JsonObject
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
@@ -81,7 +82,8 @@ class RuleSynthesisSpec extends FlatSpec {
)
var r1 = ruleList.head
var m = Matcher.findMatches(r1.lhs, r1.lhs)
- var shrunkRules = RuleSynthesis.discardDirectlyReducibleRules(rules = ruleList, rg, seed = new Random(1))
+ var shrunkRules = RuleSynthesis.discardDirectlyReducibleRules(
+ comparison = basicGraphComparison, rules = ruleList, seed = new Random(1))
println(shrunkRules)
assert(ruleList.length > shrunkRules.length)
}
@@ -90,7 +92,7 @@ class RuleSynthesisSpec extends FlatSpec {
it should "find small rules" in {
var results = EquivClassRunBlockStack(1e-14)
- var rowsAllowed = BlockRowMaker(2, BlockRowMaker.Bian2Qubit, maxInOut = Option(2))
+ var rowsAllowed = BlockRowMaker(2, BlockGenerators.Bian2Qubit, maxInOut = Option(2))
var stacks = BlockStackMaker(2, rowsAllowed)
stacks.foreach(s => results.add(s))
results.equivalenceClassesNormalised
@@ -102,7 +104,7 @@ class RuleSynthesisSpec extends FlatSpec {
it should "find small rules" in {
var results = EquivClassRunBlockStack(1e-14)
- var rowsAllowed = BlockRowMaker(1, BlockRowMaker.ZXQutrit(3), maxInOut = Option(2))
+ var rowsAllowed = BlockRowMaker(1, BlockGenerators.ZXQutrit(3), maxInOut = Option(2))
var stacks = BlockStackMaker(2, rowsAllowed)
stacks.foreach(s => results.add(s))
results.equivalenceClassesNormalised
@@ -115,7 +117,7 @@ class RuleSynthesisSpec extends FlatSpec {
it should "find small rules" in {
var results = EquivClassRunBlockStack(1e-14)
- var rowsAllowed = BlockRowMaker(1, BlockRowMaker.ZXQudit(4, 2), maxInOut = Option(2))
+ var rowsAllowed = BlockRowMaker(1, BlockGenerators.ZXQudit(4, 2), maxInOut = Option(2))
var stacks = BlockStackMaker(2, rowsAllowed)
stacks.foreach(s => results.add(s))
results.equivalenceClasses
@@ -130,7 +132,7 @@ class RuleSynthesisSpec extends FlatSpec {
// Pick out S1, S2 and REDUCIBLE
var smallRules = ctRules.filter(_.name.matches(raw"S\d|RED.*"))
var reducibleGraph = smallRules.filter(_.name.matches(raw"RED.*")).head.lhs
- var resultingDerivation = greedyReduce(RuleSynthesis.graphToDerivation(reducibleGraph, rg), smallRules)
+ var resultingDerivation = greedyReduce(basicGraphComparison, graphToDerivation(reducibleGraph), smallRules)
// println(resultingDerivation.stepsTo(resultingDerivation.firstHead))
assert(Derivation.derivationHeadPairToGraph(resultingDerivation).verts.size < reducibleGraph.verts.size)
}
@@ -138,10 +140,13 @@ class RuleSynthesisSpec extends FlatSpec {
it should "automatically reduce" in {
var ctRules = ZXRules
// Pick out S1, S2 and REDUCIBLE
- var smallRules = ctRules.filter(_.name.matches(raw"S\d.*"))
- var minimisedRules = RuleSynthesis.minimiseRuleset(smallRules ::: smallRules.map(_.inverse), rg)
+ def compare(left: Graph, right: Graph): Int = GraphAnalysis.zxGraphCompare(left, right)
+ var smallRules : List[Rule] =
+ ctRules.filter(_.name.matches(raw"REDUCIBLE")) :::
+ ctRules.filter(_.name.matches(raw"S[12]"))
+ var minimisedRules = RuleSynthesis.greedyReduceRules(compare)(smallRules)
minimisedRules.foreach(println)
- assert(minimisedRules.exists(_.name.matches(raw".*reduced.*")))
+ assert(minimisedRules.head.lhs.verts.size == 2)
}
it should "make a long derivation from annealing" in {
@@ -149,21 +154,152 @@ class RuleSynthesisSpec extends FlatSpec {
var target = ctRules.filter(_.name.matches(raw"RED.*")).head.lhs
var remaining = ctRules.filterNot(_.name.matches(raw"RED.*"))
var annealed = annealingReduce(
- RuleSynthesis.graphToDerivation(target, rg),
+ GraphAnalysis.zxGraphCompare,
+ graphToDerivation(target),
remaining ::: remaining.map(_.inverse),
100,
3,
new Random(3),
None)
assert(annealed._1.steps.size > target.verts.size)
+ assert(quanto.rewrite.Simproc.fromDerivationWithHead(annealed).hasNext)
}
it should "randomly apply appropriate rules" in {
var ctRules = ZXRules
var target = ctRules.filter(_.name.matches(raw"RED.*")).head.lhs
var remaining = ctRules.filter(_.name.matches(raw"S\d+.*"))
- val reducedDerivation = randomApply((new Derivation(rg, target), None),
+ val reducedDerivation = randomApply((new Derivation(target), None),
remaining, 100, alwaysTrue, new Random(1))
assert(reducedDerivation._1.steps(reducedDerivation._2.get).graph < target)
}
+ behavior of "reducing rulesets"
+
+ it should "greedily reduce one of these rules" in {
+ // Starting with rules a-> b and b -> c
+ // End with a -> c and b -> c
+ def reduce(listRules : List[Rule]): List[Rule] = greedyReduceRules(GraphAnalysis.zxGraphCompare)(listRules)
+
+ val theory = BlockGenerators.ZXTheory
+ val Z = NodeV(data = JsonObject("type" -> "Z"), theory = theory)
+ val r1 = new Rule(
+ QuickGraph(theory).addVertex(VName("v1"), Z)
+ .addVertex(VName("v2"), Z)
+ .addEdge(EName("e1"), UndirEdge(), (VName("v1"), VName("v2"))),
+ QuickGraph(theory).addVertex(VName("v1"), Z))
+ var r2 = new Rule(
+ QuickGraph(theory).addVertex(VName("v1"), Z),
+ QuickGraph(theory)
+ )
+ val start: List[Rule] = List(r1, r2)
+ val end : List[Rule] = reduce(start)
+ assert(start.toSet.intersect(end.toSet).size == 1)
+
+ // Check that it only runs rules forwards:
+ val startFlipped = start.map(_.inverse)
+ val endFlipped = reduce(startFlipped)
+ assert(startFlipped.toSet.intersect(endFlipped.toSet).size == 2)
+ // Two identical rules should not annihilate each other
+ // (One should be reduced, the other left as-is)
+ val duplicate = List(r1,r1.copy(description = "Different"))
+ val duplicateReduced = reduce(duplicate)
+ assert(duplicateReduced.toSet.intersect(duplicate.toSet).size == 1)
+ }
+
+ behavior of "colour swapping"
+
+ it should "Send X to Z" in {
+ val theory = rg
+ val g = QuickGraph(theory).addInput().
+ node(nodeType = "Z",nodeName = "z").
+ node(nodeType = "X",nodeName = "x").
+ node(nodeType = "hadamard",nodeName = "h").
+ join("x", "z")
+ val r = new Rule(g,g)
+ val r2 = r.colourSwap(Map("Z" -> "X", "X" -> "Z"))
+ assert(r2.lhs.vdata(VName("z")).typ == "X")
+ assert(r2.lhs.vdata(VName("x")).typ == "Z")
+ assert(r2.lhs.vdata(VName("h")).typ == "hadamard")
+ }
+
+ behavior of "extending rules"
+
+ it should "not try and extend the following rules" in {
+ val theory = rg
+ val r1l = QuickGraph(theory).addInput().node(nodeType = "Z",nodeName = "v").join("v","i-0")
+ val r1r = QuickGraph(theory).addInput().node(nodeType = "X",nodeName = "v").join("v","i-0")
+ val r1 = Rule(r1l, r1r)
+ assert(extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex) == r1)
+ }
+
+ it should "try and extend the following" in {
+ val theory = rg
+ val Z = NodeV(data = JsonObject("type" -> "Z"), theory = theory)
+ val X = NodeV(data = JsonObject("type" -> "X"), theory = theory)
+ val r1l = QuickGraph(theory).addInput().node(nodeType = "Z",nodeName = "v").join("v","i-0")
+ val r1 = Rule(r1l, r1l)
+ val extended = extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex)
+ assert(extended != r1)
+ assert(extended.hasBBoxes)
+ assert(extended.lhs.bboxesContaining(VName("i-0")).nonEmpty)
+ }
+
+ it should "satisfy one-input, one-node expansion" in {
+ val g = QuickGraph(rg)
+ val g1 = g.addInput().node(nodeType = "Z", nodeName = "z").join("i-0","z")
+ val r1 = Rule(g1,g1)
+ val ext = extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex)
+ assert(ext.lhs.bboxesContaining(VName("i-0")).nonEmpty)
+ assert(ext.lhs.bboxesContaining(VName("z")).isEmpty)
+ }
+
+ it should "satisfy two-input, two-node expansion" in {
+ val g = QuickGraph(rg)
+ val g1 = g.addInput(2).
+ node(nodeType = "Z", nodeName = "z").join("i-0","z").
+ node(nodeType = "X", nodeName = "x").join("i-1","x").
+ join("z","x")
+ val g2 = g.addInput(2).
+ node(nodeType = "Z", nodeName = "z2").join("i-0","z2").
+ node(nodeType = "X", nodeName = "x2").join("i-1","x2").
+ join("z2", "x2")
+ val r1 = Rule(g1,g2)
+ val ext = extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex)
+ assert(ext.lhs.bboxesContaining(VName("i-0")).nonEmpty)
+ assert(ext.lhs.bboxesContaining(VName("i-1")).nonEmpty)
+ assert(ext.lhs.bboxesContaining(VName("z")).isEmpty)
+ assert(ext.lhs.bboxesContaining(VName("x")).isEmpty)
+ }
+
+
+ it should "satisfy two-input, one-node expansion" in {
+ val g = QuickGraph(rg)
+ val g1 = g.addInput(2).
+ node(nodeType = "Z", nodeName = "z").
+ join("i-0","z").
+ join("i-1","z")
+ val r1 = Rule(g1,g1)
+ val ext = extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex)
+ assert(ext.lhs.adjacentNodesAndBoundaries(VName("z")).size == 1)
+ assert(ext.lhs.bboxesContaining(VName("z")).isEmpty)
+ }
+
+
+ it should "satisfy one-input, two-node expansion" in {
+ val g = QuickGraph(rg)
+ val g1 = g.addInput().node(nodeType = "Z", angle = "pi", nodeName = "z").join("i-0","z").
+ node(nodeType = "X", nodeName = "x").join("z","x")
+ val r1 = Rule(g1,g1)
+ val ext = extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex)
+ assert(ext.lhs.bboxesContaining(VName ("i-0")).nonEmpty)
+ }
+
+ it should "not match different colours, one-input, one-node" in {
+ val g = QuickGraph(rg)
+ val g1 = g.addInput().node(nodeType = "Z", nodeName = "v").join("i-0","v")
+ val g2 = g.addInput().node(nodeType = "X", nodeName = "v").join("i-0","v")
+ val r1 = Rule(g1,g2)
+ assert(extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex) == r1)
+ }
+
}
\ No newline at end of file
diff --git a/scala/src/test/scala/quanto/cosy/test/SimplificationProcedureSpec.scala b/scala/src/test/scala/quanto/cosy/test/SimplificationProcedureSpec.scala
index 1cc9b9ac..d559bb60 100644
--- a/scala/src/test/scala/quanto/cosy/test/SimplificationProcedureSpec.scala
+++ b/scala/src/test/scala/quanto/cosy/test/SimplificationProcedureSpec.scala
@@ -43,7 +43,7 @@ class SimplificationProcedureSpec extends FlatSpec {
3,
None)
val simplificationProcedure = new SimplificationProcedure[State](
- (new Derivation(rg, target), None),
+ (new Derivation(target), None),
initialState,
step,
progress,
@@ -77,7 +77,7 @@ class SimplificationProcedureSpec extends FlatSpec {
val t1 = raw"\beta"
val targetString: String = t1.replaceAll(raw"\\", raw"\\\\")
val replacementString: String = raw"\pi".replaceAll(raw"\\", raw"\\\\")
- val initialDerivation = graphToDerivation(target, rg)
+ val initialDerivation = graphToDerivation(target)
import SimplificationProcedure.Evaluation._
if (targetString.length > 0) {
@@ -125,7 +125,7 @@ class SimplificationProcedureSpec extends FlatSpec {
new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"),
Graph.fromJson(_, rg)
)
- val initialDerivation: DerivationWithHead = graphToDerivation(targetGraph, rg)
+ val initialDerivation: DerivationWithHead = graphToDerivation(targetGraph)
val graph = Derivation.derivationHeadPairToGraph(initialDerivation)
val boundaries = graph.verts.filter(v => graph.vdata(v).isBoundary)
import SimplificationProcedure.LTEByWeight._
@@ -175,7 +175,7 @@ class SimplificationProcedureSpec extends FlatSpec {
new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"),
Graph.fromJson(_, rg)
)
- val initialDerivation: DerivationWithHead = graphToDerivation(targetGraph, rg)
+ val initialDerivation: DerivationWithHead = graphToDerivation(targetGraph)
val graph = Derivation.derivationHeadPairToGraph(initialDerivation)
val boundaries = graph.verts.filter(v => graph.vdata(v).isBoundary)
import SimplificationProcedure.PullErrors._
@@ -213,7 +213,7 @@ class SimplificationProcedureSpec extends FlatSpec {
case Success(d) =>
println("Success")
println(data.Derivation.derivationHeadPairToGraph(d).vdata)
- assert(errorsDistance(targets.toSet)(d, GraphAnalysis.detectErrors(d)).get < 1)
+ assert(errorsDistance(targets.toSet)(d, GraphAnalysis.detectPiNodes(d)).get < 1)
case Failure(_) =>
assert(false)
}
@@ -221,61 +221,6 @@ class SimplificationProcedureSpec extends FlatSpec {
}
}
- // AK: This one takes too long to run as a unit test.
- // TODO: find a better solution for one-off tests like this.
-
- ignore should "perform large error pull simplifications" in {
- val allowedRules = ZXErrorRules
- val targetGraph = quanto.util.FileHelper.readFile[Graph](
- new File(examplesDirectory + "ZX_errors/Huge_With_Error.qgraph"),
- Graph.fromJson(_, rg)
- )
- val initialDerivation: DerivationWithHead = graphToDerivation(targetGraph, rg)
- val graph = Derivation.derivationHeadPairToGraph(initialDerivation)
- val boundaries = graph.verts.filter(v => graph.vdata(v).isBoundary)
- import SimplificationProcedure.PullErrors._
- val targets = boundaries.filter(t => t.toString.matches(raw"b(1[1-9]|2\d)")).toList
- if (targets.nonEmpty) {
- val initialState: State = State(
- allowedRules.filterNot(r => r.name.matches(raw".*g_ann")),
- 0,
- None,
- new Random(),
- weightFunction = errorsDistance(targets.toSet),
- Some(ZXErrorRules.filter(r => r.name.matches(raw".*g_ann"))),
- None,
- heldVertices = None,
- None
- )
- val simplificationProcedure = new SimplificationProcedure[State](
- initialDerivation,
- initialState,
- step,
- progress,
- (_, state) => state.currentStep == state.maxSteps.getOrElse(-1) || state.currentDistance.getOrElse(2.0) < 1
- )
- var returningDerivation = simplificationProcedure.initialDerivation
- val backgroundDerivation: Future[DerivationWithHead] = Future[DerivationWithHead] {
- //println("future started")
- while (!simplificationProcedure.stopped) {
- //println("futured loop")
- simplificationProcedure.step()
- returningDerivation = simplificationProcedure.current
- }
- simplificationProcedure.current
- }
- backgroundDerivation onComplete {
- case Success(_) =>
- println("Success")
- println(simplificationProcedure.state.currentDistance)
- println(data.Derivation.derivationHeadPairToGraph(returningDerivation).vdata)
- assert(true)
- case Failure(_) =>
- assert(false)
- }
- Await.result(backgroundDerivation, Duration(waitTime, "seconds"))
- }
- }
}
\ No newline at end of file
diff --git a/scala/src/test/scala/quanto/cosy/test/TensorSpec.scala b/scala/src/test/scala/quanto/cosy/test/TensorSpec.scala
index f13938a8..71c7f4a5 100644
--- a/scala/src/test/scala/quanto/cosy/test/TensorSpec.scala
+++ b/scala/src/test/scala/quanto/cosy/test/TensorSpec.scala
@@ -80,13 +80,16 @@ class TensorSpec extends FlatSpec {
assert((p2 o l1.t) == l2.t)
}
- it should "construct swap matrices" in {
+ it should "construct 2-wire swap matrices" in {
var s1 = Tensor.swap(2, x => 1 - x)
assert(s1.toString == "1 0 0 0\n0 0 1 0\n0 1 0 0\n0 0 0 1")
- var s2 = Tensor.swap(List(0, 2, 1))
- assert(s2.toStringSparse ==
- "1 . . . . . . .\n. . 1 . . . . .\n. 1 . . . . . .\n. . . 1 . . . ." +
- "\n. . . . 1 . . .\n. . . . . . 1 .\n. . . . . 1 . .\n. . . . . . . 1")
+ }
+
+ it should "make swaps the right way up" in {
+ val id = Tensor.idWires(1)
+ val s1 = Tensor.swap(List(1,0))
+ val s2 = Tensor.swap(List(2,0,1))
+ assert(((id x s1) o (s1 x id)) == s2)
}
it should "plug tensors into other tensors" in {
@@ -119,8 +122,8 @@ class TensorSpec extends FlatSpec {
}
it should "create direct sums" in {
- var t1 = Tensor.diagonal(Array[Complex](1,2))
- var t2 = Tensor(Array(Array(3,4)))
+ var t1 = Tensor.diagonal(Array[Complex](1, 2))
+ var t2 = Tensor(Array(Array(3, 4)))
assert((t1 sum t2) == Tensor(Array(Array(1, 0, 0, 0), Array(0, 2, 0, 0), Array(0, 0, 3, 4))))
assert((t1 sum t2.transpose) == Tensor(Array(Array(1, 0, 0), Array(0, 2, 0), Array(0, 0, 3), Array(0, 0, 4))))
}
@@ -161,4 +164,18 @@ class TensorSpec extends FlatSpec {
var t3 = new Tensor(Array(Array(Complex(0, 1), zero), Array(zero, Complex(1, 0))))
assert(!t1.isRoughlyUpToScalar(t3))
}
+
+
+ it should "compare zeroes" in {
+ var t1 = new Tensor(Array(Array(zero, zero)))
+ var t2 = Tensor(Array(Array(zero,Complex(0.000001,0))))
+ assert(!t1.isRoughlyUpToScalar(t2))
+ assert(t1.isRoughlyUpToScalar(t2.scaled(Complex(0.00000000001,0))))
+ }
+
+ it should "compare zero and one" in {
+ var t1 = new Tensor(Array(Array(Complex(0, 0.00000000000000001))))
+ var t2 = Tensor(Array(Array(one)))
+ assert(!t1.isRoughlyUpToScalar(t2))
+ }
}
diff --git a/scala/src/test/scala/quanto/data/test/AngleExpressionSpec.scala b/scala/src/test/scala/quanto/data/test/AngleExpressionSpec.scala
deleted file mode 100644
index c5cfd6af..00000000
--- a/scala/src/test/scala/quanto/data/test/AngleExpressionSpec.scala
+++ /dev/null
@@ -1,166 +0,0 @@
-package quanto.data.test
-
-import org.scalatest._
-import quanto.data._
-import AngleExpression._
-import quanto.util.Rational
-
-class AngleExpressionSpec extends FlatSpec {
- behavior of "A rational number"
-
- it should "add correctly" in {
- assert(Rational(1,2) + Rational(2,3) === Rational(7,6))
- }
-
- behavior of "An angle expression"
-
- def testReparse(e : AngleExpression) {
- assert(e === parse(e.toString))
- }
-
- it should "compare expressions" in {
- val a = AngleExpression(Rational(0), Map("a" -> Rational(1)))
- val b = AngleExpression(Rational(0), Map("b" -> Rational(1)))
- assert(ZERO === AngleExpression(Rational(0)))
- assert(ONE_PI === AngleExpression(Rational(1)))
- assert(a + b === b + a)
- assert(ZERO !== ONE_PI)
- assert(ONE_PI + (a * Rational(1,2)) + (b * 4) === (a * Rational(1,2)) + ONE_PI + (b * 4))
- assert((a * Rational(1,2)) + (a * Rational(2,3)) === (a * Rational(7,6)))
- }
-
- it should "parse '0'" in {
- testReparse(ZERO)
- assert(parse("") === ZERO)
- assert(parse("0") === ZERO)
- }
-
- it should "parse 'PI'" in {
- testReparse(ONE_PI)
- assert(parse("\\pi") === ONE_PI)
- assert(parse("1\\pi") === ONE_PI)
- assert(parse("1*\\pi") === ONE_PI)
- assert(parse("1/1\\pi") === ONE_PI)
- assert(parse("1/1*\\pi") === ONE_PI)
- assert(parse("1\\pi/1") === ONE_PI)
- assert(parse("1*\\pi/1") === ONE_PI)
- assert(parse("\\pi/1") === ONE_PI)
-
- assert(parse("pi") === ONE_PI)
- assert(parse("1pi") === ONE_PI)
- assert(parse("1*pi") === ONE_PI)
- assert(parse("1/1pi") === ONE_PI)
- assert(parse("1/1*pi") === ONE_PI)
- assert(parse("1pi/1") === ONE_PI)
- assert(parse("1*pi/1") === ONE_PI)
- assert(parse("pi/1") === ONE_PI)
-
- assert(parse("PI") === ONE_PI)
- assert(parse("1PI") === ONE_PI)
- assert(parse("1*PI") === ONE_PI)
- assert(parse("1/1PI") === ONE_PI)
- assert(parse("1/1*PI") === ONE_PI)
- assert(parse("1PI/1") === ONE_PI)
- assert(parse("1*PI/1") === ONE_PI)
- assert(parse("PI/1") === ONE_PI)
- }
-
- it should "parse 'a'" in {
- val a = AngleExpression(Rational(0), Map("a" -> Rational(1)))
- testReparse(a)
- assert(parse("a") === a)
- assert(parse("1a") === a)
- assert(parse("1*a") === a)
- assert(parse("1/1a") === a)
- assert(parse("1/1*a") === a)
- assert(parse("1a/1") === a)
- assert(parse("1*a/1") === a)
- assert(parse("a/1") === a)
- }
-
- it should "parse '-PI'" in {
- val minusPI = ONE_PI * -1
- testReparse(minusPI)
- assert(parse("-pi") === minusPI)
- assert(parse("-1pi") === minusPI)
- assert(parse("-1*pi") === minusPI)
- assert(parse("-1/1pi") === minusPI)
- assert(parse("-1/1*pi") === minusPI)
- assert(parse("-1pi/1") === minusPI)
- assert(parse("-1*pi/1") === minusPI)
- assert(parse("-pi/1") === minusPI)
- }
-
- it should "parse '-a'" in {
- val minusA = AngleExpression(Rational(0), Map("a" -> Rational(-1)))
- testReparse(minusA)
- assert(parse("-a") === minusA)
- assert(parse("-1a") === minusA)
- assert(parse("-1*a") === minusA)
- assert(parse("-1/1a") === minusA)
- assert(parse("-1/1*a") === minusA)
- assert(parse("-1a/1") === minusA)
- assert(parse("-1*a/1") === minusA)
- assert(parse("-a/1") === minusA)
- }
-
- it should "parse '+-3/4 PI'" in {
- val tfPI = AngleExpression(Rational(3,4))
- testReparse(tfPI)
- assert(parse("3/4") === tfPI)
- assert(parse("3/4pi") === tfPI)
- assert(parse("3/4*pi") === tfPI)
- assert(parse("3pi/4") === tfPI)
- assert(parse("3*pi/4") === tfPI)
-
- val mtfPI = AngleExpression(Rational(-3,4))
- testReparse(mtfPI)
- assert(parse("-3/4") === mtfPI)
- assert(parse("-3/4pi") === mtfPI)
- assert(parse("-3/4*pi") === mtfPI)
- assert(parse("-3pi/4") === mtfPI)
- assert(parse("-3*pi/4") === mtfPI)
- }
-
- it should "parse '+-1/4 PI' and '+-1/4 a'" in {
- val fPI = AngleExpression(Rational(1,4))
- val mfPI = fPI * -1
- val fA = AngleExpression(Rational(0), Map("a" -> Rational(1,4)))
- val mfA = fA * -1
- assert(parse("pi/4") === fPI)
- assert(parse("-pi/4") === mfPI)
- assert(parse("a/4") === fA)
- assert(parse("-a/4") === mfA)
- }
-
- it should "parse addition and subtraction correctly" in {
- val a = AngleExpression(Rational(0), Map("a" -> Rational(1)))
- val b = AngleExpression(Rational(0), Map("b" -> Rational(1)))
- testReparse(a + b)
- testReparse(a - b)
- testReparse((a * -1) - b)
- assert(parse("a + b") === a + b)
- assert(parse("a - b") === a - b)
- assert(parse("-a + b") === b - a)
- assert(parse("- a - b") === (a * -1) - b)
- assert(parse("-(a + b)") === (a * -1) - b)
- assert(parse("-(a - b)") === b - a)
- }
-
- it should "throw an exception on failed parse" in {
- intercept[AngleParseException] { parse("x + ") }
- intercept[AngleParseException] { parse("%") }
- }
-
- it should "do substitutions correctly" in {
- val e1 = parse("x - 2 y")
- val e2 = parse("a + b - c")
- assert(e1.subst("x", e2) === parse("a + b - c - 2y"))
- assert(e1.subst("y", e2) === parse("x - 2a - 2b + 2c"))
- }
-
- it should "evaluate a polynomial" in {
- val e1 = parse("2x + 3/4")
- assert(Math.abs(e1.evaluate(Map("x" -> 1.0/8.0)) - 1) < 1e-15)
- }
-}
diff --git a/scala/src/test/scala/quanto/data/test/BooleanExpressionSpec.scala b/scala/src/test/scala/quanto/data/test/BooleanExpressionSpec.scala
new file mode 100644
index 00000000..552061df
--- /dev/null
+++ b/scala/src/test/scala/quanto/data/test/BooleanExpressionSpec.scala
@@ -0,0 +1,122 @@
+package quanto.data.test
+
+import org.scalatest._
+import quanto.data.Theory.ValueType
+import quanto.data.{PhaseParseException, PhaseExpression}
+import quanto.util.Rational
+
+class BooleanExpressionSpec extends FlatSpec {
+ behavior of "A boolean expression"
+
+ def BooleanExpression(constant: Rational) = PhaseExpression(constant, Map(), ValueType.Boolean)
+
+ def BooleanExpression(constant: Int, coefficients: Map[String, Rational]) =
+ PhaseExpression(constant, coefficients, ValueType.Boolean)
+
+ def BOOL_FALSE = PhaseExpression.zero(ValueType.Boolean)
+
+ def BOOL_TRUE = PhaseExpression.one(ValueType.Boolean)
+
+ def testReparse(e: PhaseExpression) {
+ assert(e === parse(e.toString))
+ }
+
+ def parse(s: String): PhaseExpression = PhaseExpression.parse(s, ValueType.Boolean)
+
+
+ val a = BooleanExpression(0,Map("a"-> 1))
+ val b = BooleanExpression(0,Map("b"-> 1))
+
+ it should "output to string" in {
+ assert(BOOL_FALSE.toString == "\\False")
+ assert(BOOL_TRUE.toString == "\\True")
+ assert(BooleanExpression(0,Map("a"-> 1)).toString == "a")
+ assert(BooleanExpression(1,Map("a"-> 1)).toString == "\\True + a")
+ assert(BooleanExpression(1,Map("a"-> 1, "b"-> 1)).toString == "\\True + a + b")
+ assert(BooleanExpression(0,Map("a"-> 1, "b"-> 1)).toString == "a + b")
+ }
+
+ it should "compare expressions" in {
+ assert(BOOL_FALSE === BooleanExpression(0))
+ assert(BOOL_TRUE === BooleanExpression(1))
+ assert(a + b === b + a)
+ assert(BOOL_FALSE !== BOOL_TRUE)
+ assert(BOOL_TRUE + (a * 1) + (b * 4) === (a * 1) + BOOL_TRUE + (b * 4))
+ assert((a * 1) + (a * 1) === (a * 0))
+ }
+
+ it should "parse '0'" in {
+ assert(BOOL_FALSE.toString == "\\False")
+ testReparse(BOOL_FALSE)
+ assert(parse("") === BOOL_FALSE)
+ assert(parse("0") === BOOL_FALSE)
+ }
+
+ it should "parse 't'" in {
+ testReparse(BOOL_TRUE)
+ assert(parse("\\t") === BOOL_TRUE)
+ assert(parse("1*\\t") === BOOL_TRUE)
+
+ assert(parse("t") === BOOL_TRUE)
+ assert(parse("1*t") === BOOL_TRUE)
+
+ assert(parse("True") === BOOL_TRUE)
+ assert(parse("1*True") === BOOL_TRUE)
+ }
+
+ it should "parse 'a'" in {
+ testReparse(a)
+ assert(parse("a") === a)
+ assert(parse("1*a") === a)
+ }
+
+ it should "parse '-T'" in {
+ val minusTrue = BOOL_TRUE * -1
+ assert(minusTrue === BOOL_TRUE)
+ testReparse(minusTrue)
+ assert(parse("-t") === minusTrue)
+ assert(parse("-1*t") === minusTrue)
+ }
+
+ it should "parse '-a'" in {
+ val minusA = a * -1
+ testReparse(minusA)
+ assert(parse("-a") === minusA)
+ assert(parse("-1*a") === minusA)
+ }
+
+ it should "parse '+-1'" in {
+ val t = BooleanExpression(1)
+ testReparse(t)
+ assert(parse("-1") === t)
+ assert(parse("-1*true") === t)
+ assert(parse("+-1") === t)
+ }
+
+ it should "parse addition and subtraction correctly" in {
+ testReparse(a + b)
+ testReparse(a - b)
+ testReparse((a * -1) - b)
+ assert(parse("a + b") === a + b)
+ assert(parse("a - b") === a - b)
+ assert(parse("-a + b") === b - a)
+ assert(parse("- a - b") === (a * -1) - b)
+ assert(parse("-(a + b)") === (a * -1) - b)
+ assert(parse("-(a - b)") === b - a)
+ }
+
+
+ it should "do substitutions correctly" in {
+ val e1 = parse("x - 2*y")
+ val e2 = parse("a + b - c")
+ assert(e1.subst("x", e2) === parse("a + b - c - 2y"))
+ assert(e1.subst("y", e2) === parse("x - 2a - 2b + 2c"))
+ }
+
+ it should "evaluate a sum" in {
+ val e1 = parse("x + 1")
+ assert(e1.evaluate(Map("x" -> 1)) == 0)
+ val e2 = parse("x + y + z")
+ assert(e2.evaluate(Map("x" -> 1, "y" ->1, "z" ->0)) == 0)
+ }
+}
diff --git a/scala/src/test/scala/quanto/data/test/CompositeExpressionSpec.scala b/scala/src/test/scala/quanto/data/test/CompositeExpressionSpec.scala
new file mode 100644
index 00000000..3aad80c1
--- /dev/null
+++ b/scala/src/test/scala/quanto/data/test/CompositeExpressionSpec.scala
@@ -0,0 +1,41 @@
+package quanto.data.test
+
+import org.scalatest._
+import quanto.data.CompositeExpression._
+import quanto.data.Theory.ValueType
+import quanto.data._
+import quanto.util.Rational
+
+class CompositeExpressionSpec extends FlatSpec {
+
+ behavior of "Type Parsing"
+
+ val vs : List[ValueType] = ValueType.values.toList
+
+ it should "parse singletons" in {
+ assert(parseTypes("Angle") == Vector(ValueType.AngleExpr))
+ assert(parseTypes("string") == Vector(ValueType.String))
+ assert(parseTypes("String") == Vector(ValueType.String))
+ assert(parseTypes("angle_expr") == Vector(ValueType.AngleExpr))
+ assert(parseTypes("long") == Vector(ValueType.Long))
+ assert(parseTypes("Empty") == Vector(ValueType.Empty))
+ assert(parseTypes("empty") == Vector(ValueType.Empty))
+ }
+
+ it should "parse pairs" in {
+ var pairs = vs.flatMap(v => vs.map(w => (v,w)))
+ pairs.foreach(p => {
+ var v1 = p._1
+ var v2 = p._2
+ var combined = s"$v1, $v2"
+ assert(parseTypes(combined) === Vector(v1, v2))
+ }
+ )
+ }
+
+ it should "parse all in a row" in {
+ var all = vs.mkString("(", ", ", ")")
+ assert(parseTypes(all) === vs.toVector)
+ }
+
+}
\ No newline at end of file
diff --git a/scala/src/test/scala/quanto/data/test/GraphAdjmatSpec.scala b/scala/src/test/scala/quanto/data/test/GraphAdjmatSpec.scala
index 019b6744..248ef8e9 100644
--- a/scala/src/test/scala/quanto/data/test/GraphAdjmatSpec.scala
+++ b/scala/src/test/scala/quanto/data/test/GraphAdjmatSpec.scala
@@ -65,9 +65,13 @@ class GraphAdjMatSpec extends FlatSpec {
|}
""".stripMargin), rg)
- assert(g.isBoundary("v0"))
- assert(g.isBoundary("v1"))
- assert(g === g1)
+ assert(g.isTerminalWire("v0"))
+ assert(g.isTerminalWire("v1"))
+ var g2 = g1.copy()
+ g1.verts.foreach(vn => g2 = g2.updateVData(vn) { vd => vd.withCoord(0, 0) })
+ var g3 = g.copy()
+ g.verts.foreach(vn => g3 = g3.updateVData(vn) { vd => vd.withCoord(0, 0) })
+ assert(g3 === g2)
// LAYOUT A GRAPH LIKE THIS:
// val layoutProc = new ForceLayout
diff --git a/scala/src/test/scala/quanto/data/test/GraphSpec.scala b/scala/src/test/scala/quanto/data/test/GraphSpec.scala
index 76b721a9..92f9adda 100644
--- a/scala/src/test/scala/quanto/data/test/GraphSpec.scala
+++ b/scala/src/test/scala/quanto/data/test/GraphSpec.scala
@@ -3,12 +3,15 @@ package quanto.data.test
import org.scalatest._
import quanto.data._
import quanto.data.Names._
+import quanto.data.Theory.{EdgeDesc, EdgeStyleDesc, ValueDesc, ValueType}
import quanto.util.json._
+
import scala.collection.immutable.TreeSet
class GraphSpec extends FlatSpec with GivenWhenThen {
- val rg = Theory.fromFile("red_green")
+ val rg : Theory = Theory.fromFile("red_green")
+ val composite_thy : Theory = Theory.fromFile("composite")
behavior of "A graph"
var g : Graph = _
@@ -70,9 +73,34 @@ class GraphSpec extends FlatSpec with GivenWhenThen {
it should "be equal to its copy" in {
val g1 = g.copy()
+ assert(g1 != null)
assert(g1 === g)
}
+
+ it should "store and retrieve non-default edge types" in {
+
+ val eDesc = EdgeDesc(
+ value = ValueDesc(typ = Vector(ValueType.Empty)),
+ style = EdgeStyleDesc(),
+ defaultData = JsonObject("type" -> "recorded")
+ )
+
+ val TwoWireTheory = Theory.DefaultTheory.mixin(Map(), Map("recorded" -> eDesc), Some("TwoWireTheory"))
+
+ val eData = UndirEdge(eDesc.defaultData, JsonObject(), TwoWireTheory)
+
+ val g = new Graph()
+ .addVertex(VName("v1"), NodeV())
+ .addVertex(VName("v2"), NodeV())
+ .addEdge(EName("e"), eData, ("v1", "v2"))
+ .addEdge(EName("f"), UndirEdge(), ("v1", "v2"))
+ val json = g.toJson(TwoWireTheory)
+ assert (! (json / "undir_edges" / "e").isEmpty)
+ assert (! (json / "undir_edges" / "e" /"data").isEmpty)
+ assert (((json / "undir_edges" / "f") ? "data").isEmpty)
+ }
+
behavior of "Another graph"
var otherG : Graph = _
@@ -160,6 +188,96 @@ class GraphSpec extends FlatSpec with GivenWhenThen {
assert(jsonGraphShouldBe === Graph.fromJson(Graph.toJson(jsonGraphShouldBe)))
}
+ behavior of "Normalisation"
+
+ it should "Normalise a single wire vertex to a single wire vertex" in {
+ val g1 = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "wire_vertices": ["w0"],
+ | "undir_edges": {
+ | }
+ |}
+ """.stripMargin))
+ assert(g1.normalise === g1)
+ }
+
+
+ it should "Normalise K2 to itself" in {
+ val g1 = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "wire_vertices": ["w0", "w1"],
+ | "undir_edges": {
+ | "e0": {"src": "w0", "tgt": "w1"}
+ | }
+ |}
+ """.stripMargin))
+ assert(g1.normalise === g1)
+ }
+
+ it should "Normalise P3 to itself" in {
+ val g1 = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "wire_vertices": ["w0", "w1", "w2"],
+ | "undir_edges": {
+ | "e0": {"src": "w0", "tgt": "w1"},
+ | "e1": {"src": "w1", "tgt": "w2"}
+ | }
+ |}
+ """.stripMargin))
+ assert(g1.normalise === g1)
+ }
+
+
+ it should "Normalise P4 to P3" in {
+ val g1 = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "wire_vertices": ["w0", "w1", "w2","w3"],
+ | "undir_edges": {
+ | "e0": {"src": "w0", "tgt": "w1"},
+ | "e1": {"src": "w1", "tgt": "w2"},
+ | "e2": {"src": "w2", "tgt": "w3"}
+ | }
+ |}
+ """.stripMargin))
+ val g2 = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "wire_vertices": ["w0", "w1", "w3"],
+ | "undir_edges": {
+ | "e0": {"src": "w0", "tgt": "w1"},
+ | "e2": {"src": "w1", "tgt": "w3"}
+ | }
+ |}
+ """.stripMargin))
+ assert(g1.normalise === g2)
+ }
+
+
+ it should "Not normalise self loop" in {
+ val g1 = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "wire_vertices": ["w0"],
+ | "undir_edges": {
+ | "e0": {"src": "w0", "tgt": "w0"}
+ | }
+ |}
+ """.stripMargin))
+ val g2 = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "wire_vertices": ["w0"],
+ | "undir_edges": {
+ | }
+ |}
+ """.stripMargin))
+ assert(g1.normalise !== g2)
+ }
+
behavior of "Some more graphs"
it should "normalise" in {
@@ -301,6 +419,18 @@ class GraphSpec extends FlatSpec with GivenWhenThen {
assert(twobb1.inBBox.domf(v1) === bs)
}
+ it should "rename correctly" in {
+ var renamed = twobb.addEdge(twobb.edges.fresh, UndirEdge(), "v0" -> "v0").rename(Map("v0" -> "v1"))
+ assert(renamed.verts.contains("v1"))
+ assert(renamed.inBBox.domf("v1").nonEmpty)
+ }
+
+ it should "add correctly" in {
+ var added = twobb.appendGraph(twobb.renameAvoiding(twobb))
+ assert(added.verts.size == 2)
+ println(added)
+ }
+
behavior of "A graph with angles"
it should "return the free variables" in {
@@ -314,10 +444,114 @@ class GraphSpec extends FlatSpec with GivenWhenThen {
|}
""".stripMargin), thy = rg)
- assert(g.freeVars === Set("x", "y", "z"))
+ assert(g.freeVars === Set(
+ (ValueType.AngleExpr, "x"),
+ (ValueType.AngleExpr, "y"),
+ (ValueType.AngleExpr, "z")))
+ }
+
+ behavior of "A graph with composite values"
+
+ it should "accept a composite-valued graph" in {
+ val g = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "node_vertices": {
+ | "v0": {"data": {"type": "Z", "value": "x + y, true"}},
+ | "v1": {"data": {"type": "X", "value": "z + pi, false"}}
+ | }
+ |}
+ """.stripMargin), thy = composite_thy)
+ var pi_commute = Graph.fromJson(Json.parse(
+ """
+ |{"wire_vertices":{
+ | "b1":{"annotation":{"boundary":true,"coord":[-0.75,2.0]}},
+ | "b0":{"annotation":{"boundary":true,"coord":[-11.5,1.5]}}
+ | },
+ |"node_vertices":{
+ | "v1":{"data":{"type":"Z","value":"\\pi/4,false"},"annotation":{"coord":[-3.0,0.5]}},
+ | "v0":{"data":{"type":"Z","value":"\\pi/2,false"},"annotation":{"coord":[-8.5,1.75]}},
+ | "v2":{"data":{"type":"X","value":"\\pi, \\true"},"annotation":{"coord":[-5.5,2.5]}}},
+ | "undir_edges":{
+ | "e0":{"src":"b0","tgt":"v0"},
+ | "e1":{"src":"v0","tgt":"v2"},
+ | "e2":{"src":"v2","tgt":"v1"},
+ | "e3":{"src":"v1","tgt":"b1"}
+ | }
+ |}""".stripMargin
+ ), thy = composite_thy)
+
+ assert(pi_commute.freeVars === Set())
+ assert(pi_commute.vdata("v2").asInstanceOf[NodeV].phaseData ===
+ new CompositeExpression(
+ Vector(ValueType.AngleExpr, ValueType.Boolean),
+ Vector(PhaseExpression(1, Map(), ValueType.AngleExpr), PhaseExpression(1, Map(), ValueType.Boolean)))
+ )
}
// it should "support Graph.Flavor clipboard flavor" in {
// assert(Graph().isDataFlavorSupported(Graph.Flavor))
// }
+
+ behavior of "vertex cutting"
+ // Create a graph with left, middle and right vertices
+ // Join middle and right,
+ // Experiment with cutting out the left vertex
+ // Note that we are normalising the graphs, which can cause unintuitive behaviour
+ var cut_g = new Graph()
+ cut_g = cut_g.addVertex("v0", NodeV()).
+ addVertex("v1", NodeV()).
+ addVertex("v2", NodeV()).
+ addEdge("e", DirEdge(), "v1" -> "v2")
+
+ it should "count arities" in {
+ val g0 = cut_g.normalise
+ assert(g0.arity("v0") == 0)
+ assert(g0.arity("v1") == 1)
+ assert(g0.arity("v2") == 1)
+ }
+
+
+ it should "find end-points of wires" in {
+ val g0 = cut_g.normalise
+ assert(g0.edgeEndPoints("e0")._1 == Set("v1", "v2").map(VName))
+ }
+
+ it should "find no end-points in a loop" in {
+ val g0 = new Graph().
+ addVertex("v0", WireV()).
+ addVertex("v1", WireV()).
+ addEdge("e0", DirEdge(), "v0"->"v1").
+ addEdge("e1", DirEdge(), "v1"->"v0")
+ assert(g0.edgeEndPoints("e0") == (Set(), Set("e0","e1").map(EName), Set("v0","v1").map(VName)))
+ }
+
+ it should "cut out a lone vertex" in {
+ val(g0,nb, rb) = cut_g.
+ normalise.cutVertex("v0")
+ assert(nb.isEmpty)
+ assert(g0.nodesThatAreNotWires.size == 2)
+ assert(g0.edges.size == 2) //double expected because normalisation
+ }
+
+ it should "cut out a vertex attached to another vertex" in {
+ val (g0,nb, rb) = cut_g.
+ normalise.addEdge("ee0", DirEdge(), "v0" -> "v1").cutVertex("v0")
+ assert(nb.size == 1)
+ assert(g0.nodesThatAreNotWires.size == 2)
+ assert(g0.edges.size == 3) // One going left to boundary, two between v1 and v2
+ }
+
+
+ it should "cut out a vertex attached to another vertex and a boundary" in {
+ val (g0,nb, rb) = cut_g.
+ normalise.
+ addEdge("ee0", DirEdge(), "v0" -> "v1").
+ addVertex("b0", WireV()).
+ addEdge("ee1", DirEdge(), "b0" -> "v0").
+ cutVertex("v0")
+ assert(nb.size == 1)
+ assert(g0.nodesThatAreNotWires.size == 2)
+ assert(g0.edges.size == 3) // One going left to boundary, two between v1 and v2
+ }
}
diff --git a/scala/src/test/scala/quanto/data/test/PhaseExpressionSpec.scala b/scala/src/test/scala/quanto/data/test/PhaseExpressionSpec.scala
new file mode 100644
index 00000000..181486fe
--- /dev/null
+++ b/scala/src/test/scala/quanto/data/test/PhaseExpressionSpec.scala
@@ -0,0 +1,224 @@
+package quanto.data.test
+
+import org.scalatest._
+import quanto.data.Theory.ValueType
+import quanto.data._
+import quanto.util.Rational
+
+class PhaseExpressionSpec extends FlatSpec {
+ behavior of "A rational number"
+
+ it should "add correctly" in {
+ assert(Rational(1, 2) + Rational(2, 3) === Rational(7, 6))
+ }
+
+ behavior of "An angle expression"
+
+
+ def AngleExpression(constant: Rational) = PhaseExpression(constant, Map(), ValueType.AngleExpr)
+
+ def AngleExpression(constant: Rational, coefficients: Map[String, Rational]) =
+ PhaseExpression(constant, coefficients, ValueType.AngleExpr)
+
+ def zero = PhaseExpression.zero(ValueType.AngleExpr)
+
+ def one = PhaseExpression.one(ValueType.AngleExpr)
+
+ def testReparse(e: PhaseExpression) {
+ assert(e === parse(e.toString))
+ }
+
+ def parse(s: String): PhaseExpression = PhaseExpression.parse(s, ValueType.AngleExpr)
+
+ it should "create expressions" in {
+ val a1 = AngleExpression(Rational(1,2), Map("a" -> Rational(2,1)))
+ assert(a1.constant == Rational(1,2))
+ assert(a1.coefficients.size == 1)
+ testReparse(a1)
+ }
+
+
+ it should "compare expressions" in {
+ val a = AngleExpression(Rational(0), Map("a" -> Rational(1)))
+ val b = AngleExpression(Rational(0), Map("b" -> Rational(1)))
+ assert(zero === AngleExpression(Rational(0)))
+ assert(one === AngleExpression(Rational(1)))
+ assert(a + b === b + a)
+ assert(zero !== one)
+ assert(one + (a * Rational(1, 2)) + (b * 4) === (a * Rational(1, 2)) + one + (b * 4))
+ assert((a * Rational(1, 2)) + (a * Rational(2, 3)) === (a * Rational(7, 6)))
+ }
+
+ it should "parse '0'" in {
+ testReparse(zero)
+ assert(parse("") === zero)
+ assert(parse("0") === zero)
+ }
+
+ it should "parse 'PI'" in {
+ testReparse(one)
+ assert(parse("\\pi") === one)
+ assert(parse("1\\pi") === one)
+ assert(parse("1*\\pi") === one)
+ assert(parse("1/1\\pi") === one)
+ assert(parse("1/1*\\pi") === one)
+ assert(parse("1\\pi/1") === one)
+ assert(parse("1*\\pi/1") === one)
+ assert(parse("\\pi/1") === one)
+
+ assert(parse("pi") === one)
+ assert(parse("1pi") === one)
+ assert(parse("1*pi") === one)
+ assert(parse("1/1pi") === one)
+ assert(parse("1/1*pi") === one)
+ assert(parse("1pi/1") === one)
+ assert(parse("1*pi/1") === one)
+ assert(parse("pi/1") === one)
+
+ assert(parse("PI") === one)
+ assert(parse("1PI") === one)
+ assert(parse("1*PI") === one)
+ assert(parse("1/1PI") === one)
+ assert(parse("1/1*PI") === one)
+ assert(parse("1PI/1") === one)
+ assert(parse("1*PI/1") === one)
+ assert(parse("PI/1") === one)
+ }
+
+ it should "parse 'a'" in {
+ val a = AngleExpression(Rational(0), Map("a" -> Rational(1)))
+ testReparse(a)
+ assert(parse("a") === a)
+ assert(parse("1a") === a)
+ assert(parse("1*a") === a)
+ assert(parse("1/1a") === a)
+ assert(parse("1/1*a") === a)
+ assert(parse("1a/1") === a)
+ assert(parse("1*a/1") === a)
+ assert(parse("a/1") === a)
+ }
+
+ it should "parse '-PI'" in {
+ val minusPI = one * -1
+ testReparse(minusPI)
+ assert(parse("-pi") === minusPI)
+ assert(parse("-1pi") === minusPI)
+ assert(parse("-1*pi") === minusPI)
+ assert(parse("-1/1pi") === minusPI)
+ assert(parse("-1/1*pi") === minusPI)
+ assert(parse("-1pi/1") === minusPI)
+ assert(parse("-1*pi/1") === minusPI)
+ assert(parse("-pi/1") === minusPI)
+ }
+
+ it should "parse '-a'" in {
+ val minusA = AngleExpression(Rational(0), Map("a" -> Rational(-1)))
+ testReparse(minusA)
+ assert(parse("-a") === minusA)
+ assert(parse("-1a") === minusA)
+ assert(parse("-1*a") === minusA)
+ assert(parse("-1/1a") === minusA)
+ assert(parse("-1/1*a") === minusA)
+ assert(parse("-1a/1") === minusA)
+ assert(parse("-1*a/1") === minusA)
+ assert(parse("-a/1") === minusA)
+ }
+
+ it should "parse \\pi/2" in {
+ val ohPI = AngleExpression(Rational(1, 2))
+ testReparse(ohPI)
+ assert(parse("1/2") === ohPI)
+ assert(parse("\\pi/2") === ohPI)
+ }
+
+ it should "parse '+-3/4 PI'" in {
+ val tfPI = AngleExpression(Rational(3, 4))
+ testReparse(tfPI)
+ assert(parse("3/4") === tfPI)
+ assert(parse("3/4pi") === tfPI)
+ assert(parse("3/4*pi") === tfPI)
+ assert(parse("3pi/4") === tfPI)
+ assert(parse("3*pi/4") === tfPI)
+
+ val mtfPI = AngleExpression(Rational(-3, 4))
+ testReparse(mtfPI)
+ assert(parse("-3/4") === mtfPI)
+ assert(parse("-3/4pi") === mtfPI)
+ assert(parse("-3/4*pi") === mtfPI)
+ assert(parse("-3pi/4") === mtfPI)
+ assert(parse("-3*pi/4") === mtfPI)
+ }
+
+ it should "parse '+-1/4 PI' and '+-1/4 a'" in {
+ val fPI = AngleExpression(Rational(1, 4))
+ val mfPI = fPI * -1
+ val fA = AngleExpression(Rational(0), Map("a" -> Rational(1, 4)))
+ val mfA = fA * -1
+ assert(parse("pi/4") === fPI)
+ assert(parse("-pi/4") === mfPI)
+ assert(parse("a/4") === fA)
+ assert(parse("-a/4") === mfA)
+ }
+
+ it should "parse addition and subtraction correctly" in {
+ val a = AngleExpression(Rational(0), Map("a" -> Rational(1)))
+ val b = AngleExpression(Rational(0), Map("b" -> Rational(1)))
+ testReparse(a + b)
+ testReparse(a - b)
+ testReparse((a * -1) - b)
+ assert(parse("a + b") === a + b)
+ assert(parse("a - b") === a - b)
+ assert(parse("-a + b") === b - a)
+ assert(parse("- a - b") === (a * -1) - b)
+ assert(parse("-(a + b)") === (a * -1) - b)
+ assert(parse("-(a - b)") === b - a)
+ }
+
+
+ it should "do substitutions correctly" in {
+ val e1 = parse("x - 2 y")
+ val e2 = parse("a + b - c")
+ assert(e1.subst("x", e2) === parse("a + b - c - 2y"))
+ assert(e1.subst("y", e2) === parse("x - 2a - 2b + 2c"))
+ }
+
+ it should "evaluate a polynomial" in {
+ val e1 = parse("2x + 3/4")
+ assert(Math.abs(e1.evaluate(Map("x" -> 1.0/8.0)) - 1) < 1e-15)
+ }
+
+ it should "evaluate a string" in {
+ val s1 = PhaseExpression.parse("Hi", ValueType.String)
+ assert(s1.toString == "Hi")
+ }
+
+ it should "evaluate an empty" in {
+ val e1 = PhaseExpression.parse("Hi", ValueType.Empty)
+ assert(e1.toString == "")
+ }
+
+ it should "print a rational" in {
+ val r1 = PhaseExpression.parse("3/27", ValueType.Rational)
+ assert(r1.toString == "1/9")
+ val r2 = PhaseExpression.parse("3/27 + a/4", ValueType.Rational)
+ assert(r2.toString == "1/9 + 1/4 a")
+ }
+
+ it should "parse alpha'" in {
+ val e = parse("alpha' + alpha")
+ assert(e.coefficients.keys.size == 2)
+ }
+
+ it should "substitute beta" in {
+ val e = parse("-beta")
+ val f = e.subst("beta", parse("0"))
+ assert(f == zero)
+ }
+
+ it should "substitute composite beta" in {
+ val e = parse("-beta")
+ val ee = CompositeExpression.wrap(e)
+ val f = ee.substSubVariables(Map((ValueType.AngleExpr, "beta") -> "0"))
+ assert(f.values.head == zero)
+ }
+}
diff --git a/scala/src/test/scala/quanto/data/test/TheorySpec.scala b/scala/src/test/scala/quanto/data/test/TheorySpec.scala
index 9650ab5b..c0c6b528 100644
--- a/scala/src/test/scala/quanto/data/test/TheorySpec.scala
+++ b/scala/src/test/scala/quanto/data/test/TheorySpec.scala
@@ -9,13 +9,13 @@ class TheorySpec extends FlatSpec {
behavior of "A theory"
val rgValueDesc = Theory.ValueDesc(
- typ = Theory.ValueType.String,
+ typ = Vector(Theory.ValueType.String),
latexConstants = true,
validateWithCore = true
)
val hValueDesc = Theory.ValueDesc(
- typ = Theory.ValueType.Empty,
+ typ = Vector(Theory.ValueType.Empty),
latexConstants = false,
validateWithCore = false
)
@@ -68,9 +68,9 @@ class TheorySpec extends FlatSpec {
| "vertex_types" : {
| "red" : {
| "value" : {
- | "validate_with_core" : true,
+ | "type" : "string",
| "latex_constants" : true,
- | "type" : "string"
+ | "validate_with_core" : true
| },
| "style" : {
| "label" : {
@@ -79,7 +79,8 @@ class TheorySpec extends FlatSpec {
| },
| "stroke_color" : [ 0.0, 0.0, 0.0 ],
| "fill_color" : [ 1.0, 0.0, 0.0 ],
- | "shape" : "circle"
+ | "shape" : "circle",
+ | "stroke_width" : 1
| },
| "default_data" : {
| "type" : "red",
@@ -91,9 +92,9 @@ class TheorySpec extends FlatSpec {
| },
| "green" : {
| "value" : {
- | "validate_with_core" : true,
+ | "type" : "string",
| "latex_constants" : true,
- | "type" : "string"
+ | "validate_with_core" : true
| },
| "style" : {
| "label" : {
@@ -102,7 +103,8 @@ class TheorySpec extends FlatSpec {
| },
| "stroke_color" : [ 0.0, 0.0, 0.0 ],
| "fill_color" : [ 0.0, 1.0, 0.0 ],
- | "shape" : "circle"
+ | "shape" : "circle",
+ | "stroke_width" : 1
| },
| "default_data" : {
| "type" : "green",
@@ -114,9 +116,9 @@ class TheorySpec extends FlatSpec {
| },
| "hadamard" : {
| "value" : {
- | "validate_with_core" : false,
+ | "type" : "empty",
| "latex_constants" : false,
- | "type" : "empty"
+ | "validate_with_core" : false
| },
| "style" : {
| "label" : {
@@ -125,19 +127,22 @@ class TheorySpec extends FlatSpec {
| },
| "stroke_color" : [ 0.0, 0.0, 0.0 ],
| "fill_color" : [ 1.0, 1.0, 0.0 ],
- | "shape" : "rectangle"
+ | "shape" : "rectangle",
+ | "stroke_width" : 1
| },
| "default_data" : {
| "type" : "hadamard"
| }
| }
| },
+ | "default_vertex_type" : "red",
+ | "default_edge_type" : "plain",
| "edge_types" : {
| "plain" : {
| "value" : {
- | "validate_with_core" : false,
+ | "type" : "empty",
| "latex_constants" : false,
- | "type" : "empty"
+ | "validate_with_core" : false
| },
| "style" : {
| "stroke_color" : [ 0.0, 0.0, 0.0 ],
@@ -151,9 +156,7 @@ class TheorySpec extends FlatSpec {
| "type" : "plain"
| }
| }
- | },
- | "default_vertex_type" : "red",
- | "default_edge_type" : "plain"
+ | }
|}
""".stripMargin)
@@ -164,6 +167,27 @@ class TheorySpec extends FlatSpec {
}
it should "load from JSON" in {
- assert(Theory.fromJson(thyJson) === thy)
+ var loaded : Theory = Theory.fromJson(thyJson)
+ assert(loaded.vertexTypes === thy.vertexTypes)
+ print(loaded.vertexTypes)
+ assert(loaded === thy)
+ }
+
+ behavior of "mixing theories"
+
+ it should "mix with itself" in {
+ assert(thy.mixin(thy, None) == thy)
+ }
+
+ val plain = Theory.fromFile("plain")
+ val rg = Theory.fromFile("red_green")
+
+ it should "mix with others" in {
+ val mixedPlainRG = plain.mixin(rg, Some("plain with red_green"))
+ assert(mixedPlainRG.vertexTypes.keySet == Set("var", "hadamard", "Z", "X"))
+ }
+
+ it should "mix with fragments" in {
+ assert(plain.mixin(newVertexTypes = rg.vertexTypes.filter(_._1 == "Z")).vertexTypes.keySet == Set("var", "Z"))
}
}
diff --git a/scala/src/test/scala/quanto/rewrite/test/AngleExpressionMatcherSpec.scala b/scala/src/test/scala/quanto/rewrite/test/AngleExpressionMatcherSpec.scala
index 64372e4f..fb222fe7 100644
--- a/scala/src/test/scala/quanto/rewrite/test/AngleExpressionMatcherSpec.scala
+++ b/scala/src/test/scala/quanto/rewrite/test/AngleExpressionMatcherSpec.scala
@@ -1,11 +1,30 @@
package quanto.rewrite.test
import org.scalatest._
+import quanto.data.Theory.ValueType
import quanto.rewrite._
import quanto.data._
+import quanto.util.Rational
class AngleExpressionMatcherSpec extends FlatSpec {
- import AngleExpression.parse
+
+ def AngleExpressionMatcher(pVars: Vector[String], tVars: Vector[String]) = PhaseExpressionMatcher(pVars, tVars, Some(2))
+
+ def AngleExpression(constant: Rational) = PhaseExpression(constant, Map(), ValueType.AngleExpr)
+
+ def AngleExpression(constant: Rational, coefficients: Map[String, Rational]) =
+ PhaseExpression(constant, coefficients, ValueType.AngleExpr)
+
+ def zero = PhaseExpression.zero(ValueType.AngleExpr)
+
+ def one = PhaseExpression.one(ValueType.AngleExpr)
+
+ def testReparse(e: PhaseExpression) {
+ assert(e === parse(e.toString))
+ }
+
+ def parse(s: String): PhaseExpression = PhaseExpression.parse(s, ValueType.AngleExpr)
+
behavior of "An angle expression matcher"
it should "handle single-variable matches" in {
@@ -15,7 +34,7 @@ class AngleExpressionMatcherSpec extends FlatSpec {
val mp = m.toMap
// check we got correct map
- assert(m.toMap === Map("a" -> parse("x + 2 y"), "b" -> parse("z + pi")))
+ assert(m.toMap.mapValues(_.as(ValueType.AngleExpr)) === Map("a" -> parse("x + 2 y"), "b" -> parse("z + pi")))
// check substitutions into pattern yield target
assert(parse("a").subst(mp) === parse("x + 2 y"))
@@ -27,7 +46,7 @@ class AngleExpressionMatcherSpec extends FlatSpec {
m = m.addMatch(parse("a + 2 b"), parse("x + 2 y")).get
m = m.addMatch(parse("b + c"), parse("z + pi")).get
m = m.addMatch(parse("a - c"), parse("4 x")).get
- val mp = m.toMap
+ val mp = m.toMap(ValueType.AngleExpr)
// check we got correct map
assert(mp === Map(
diff --git a/scala/src/test/scala/quanto/rewrite/test/BBoxMatcherSpec.scala b/scala/src/test/scala/quanto/rewrite/test/BBoxMatcherSpec.scala
index 3a0542b7..111594f7 100644
--- a/scala/src/test/scala/quanto/rewrite/test/BBoxMatcherSpec.scala
+++ b/scala/src/test/scala/quanto/rewrite/test/BBoxMatcherSpec.scala
@@ -842,4 +842,43 @@ class BBoxMatcherSpec extends FlatSpec {
// should match once as empty graph (killing both bboxes), and once using two copies of each bbox
assert(matches.size === 2)
}
+
+ it should "match spider-law with multiple connecting edges (and no external ones)" in {
+ val g1 = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "wire_vertices": ["w1"],
+ | "node_vertices": {
+ | "v0": {"data": {"type": "Z"}},
+ | "v1": {"data": {"type": "Z"}}
+ | },
+ | "undir_edges": {
+ | "e3": {"src": "v0", "tgt": "w1"},
+ | "e4": {"src": "v1", "tgt": "w1"}
+ | },
+ | "bang_boxes": {
+ | "bb2": {"contents": ["w1"]}
+ | }
+ |}
+ """.stripMargin), thy = rg)
+
+
+ val g2 = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "node_vertices": {
+ | "v0": {"data": {"type": "Z"}},
+ | "v1": {"data": {"type": "Z"}}
+ | },
+ | "undir_edges": {
+ | "e0": {"src": "v0", "tgt": "v1"},
+ | "e1": {"src": "v0", "tgt": "v1"}
+ | }
+ |}
+ """.stripMargin), thy = rg)
+
+ val matches = Matcher.findMatches(g1, g1)
+
+ assert(matches.size === 2)
+ }
}
diff --git a/scala/src/test/scala/quanto/rewrite/test/CompositeExpressionMatcherSpec.scala b/scala/src/test/scala/quanto/rewrite/test/CompositeExpressionMatcherSpec.scala
new file mode 100644
index 00000000..05995aeb
--- /dev/null
+++ b/scala/src/test/scala/quanto/rewrite/test/CompositeExpressionMatcherSpec.scala
@@ -0,0 +1,74 @@
+package quanto.rewrite.test
+
+import org.scalatest._
+import quanto.data.Theory.ValueType
+import quanto.data._
+import quanto.rewrite._
+import quanto.util.Rational
+
+
+class CompositeExpressionMatcherSpec extends FlatSpec {
+
+ def AngleExpression(constant: Rational) = PhaseExpression(constant, Map(), ValueType.AngleExpr)
+
+ def AngleExpression(constant: Rational, coefficients: Map[String, Rational]) =
+ PhaseExpression(constant, coefficients, ValueType.AngleExpr)
+
+ def parse(types: String, values: String) = CompositeExpression.parse(types, values)
+
+ def composite(p: PhaseExpression): CompositeExpression = CompositeExpression.wrap(p)
+
+ def zero = PhaseExpression.zero(ValueType.AngleExpr)
+
+ def one = PhaseExpression.one(ValueType.AngleExpr)
+
+ def testReparse(e: PhaseExpression) {
+ assert(e === parse(e.toString))
+ }
+
+ def parse(s: String): PhaseExpression = PhaseExpression.parse(s, ValueType.AngleExpr)
+
+ behavior of "An angle expression matcher"
+
+ it should "handle single-variable matches" in {
+
+ var m = CompositeExpressionMatcher()
+ m = m.addMatch(parse("angle", "a"), parse("angle", "x + 2 y")).get
+ m = m.addMatch(parse("angle", "b"), parse("angle", "z + pi")).get
+ val mp = m.toMap
+
+ // check we got correct map
+ assert(m.toMap.values.flatten.toMap ===
+ Map("a" -> parse("angle", "x + 2 y").firstOrError(ValueType.AngleExpr), "b" -> parse("angle", "z + pi").firstOrError(ValueType.AngleExpr)))
+
+ // check substitutions into pattern yield target
+ assert(parse("angle", "a").substSubValues(mp.values.flatten.toMap) === parse("angle", "x + 2 y"))
+ assert(parse("angle", "b").substSubValues(mp.values.flatten.toMap) === parse("angle", "z + pi"))
+ }
+
+ it should "handle expression matches" in {
+ var m = CompositeExpressionMatcher()
+ m = m.addMatch(parse("angle, boolean", "a + 2 b, d"), parse("angle, boolean", "x + 2 y, true")).get
+ m = m.addMatch(parse("angle, boolean", "b + c, false"), parse("angle, boolean", "z + pi, false")).get
+ m = m.addMatch(parse("angle, boolean", "a - c,false"), parse("angle, boolean", "4 x, false")).get
+ val mp = m.toMap
+
+ // check we got correct map
+ assert(mp(ValueType.AngleExpr) === Map(
+ "a" -> parse("angle", "7 x - 2 y + 2 z").firstOrError(ValueType.AngleExpr),
+ "b" -> parse("angle", "-pi - 3 x + 2 y - z").firstOrError(ValueType.AngleExpr),
+ "c" -> parse("angle", "3 x - 2 y + 2 z").firstOrError(ValueType.AngleExpr)
+ ))
+ assert(mp(ValueType.Boolean) === Map(
+ "d" -> parse("boolean", "true").firstOrError(ValueType.Boolean)
+ ))
+
+ // check substitutions into pattern yield target
+ assert(parse("angle, boolean", "a + 2 b, d").substSubValues(mp.values.flatten.toMap) ===
+ parse("angle,boolean","x + 2 y, true"))
+ assert(parse("angle, boolean", "b + c").substSubValues(mp.values.flatten.toMap) ===
+ parse("angle, boolean","z + pi"))
+ assert(parse("angle, boolean", "a - c").substSubValues(mp.values.flatten.toMap) ===
+ parse("angle, boolean","4 x"))
+ }
+}
diff --git a/scala/src/test/scala/quanto/rewrite/test/MatcherSpec.scala b/scala/src/test/scala/quanto/rewrite/test/MatcherSpec.scala
index b0e5a842..8137c296 100644
--- a/scala/src/test/scala/quanto/rewrite/test/MatcherSpec.scala
+++ b/scala/src/test/scala/quanto/rewrite/test/MatcherSpec.scala
@@ -3,6 +3,7 @@ package quanto.rewrite.test
import quanto.rewrite._
import quanto.data._
import org.scalatest._
+import quanto.data.Theory.{EdgeDesc, ValueType}
import quanto.util.json.Json
class MatcherSpec extends FlatSpec {
@@ -189,7 +190,8 @@ class MatcherSpec extends FlatSpec {
""".stripMargin), thy = rg)
val matches = Matcher.findMatches(g1, g2)
assert(matches.size === 1)
- assert(matches.head.subst === Map("x" -> AngleExpression.parse("(1/2) \\pi")))
+ assert(matches.head.subst(ValueType.AngleExpr).mapValues(_.as(ValueType.AngleExpr)) ===
+ Map("x" -> PhaseExpression.parse("(1/2) \\pi", ValueType.AngleExpr)))
}
it should "match a graph with one wire on itself" in {
@@ -635,6 +637,81 @@ class MatcherSpec extends FlatSpec {
assert(matches.forall { _.isHomomorphism })
}
+
+ val stringWire = rg.edgeTypes("string")
+ val twoWire : Theory = rg.mixin(newEdgeTypes = Map("q" -> stringWire.copy()))
+
+ it should "fail to match edges of different types" in {
+ val g = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "node_vertices": {
+ | "v0": {"data": {"type": "Z", "value": ""}},
+ | "v1": {"data": {"type": "X", "value": ""}}
+ | },
+ | "undir_edges": {
+ | "e0": {"data": {"type": "string"},"src": "v0", "tgt": "v1"},
+ | "e1": {"data": {"type": "q"},"src": "v0", "tgt": "v1"}
+ | }
+ |}
+ """.stripMargin), thy = twoWire)
+ val matches = Matcher.findMatches(g, g)
+ assert(matches.size === 1)
+ }
+
+
+ it should "fail to match bare edges of different types" ignore {
+ val g = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "node_vertices": {
+ | "v0": {"data": {"type": "Z", "value": ""}},
+ | "v1": {"data": {"type": "Z", "value": ""}}
+ | },
+ | "wire_vertices": ["i0"],
+ | "undir_edges": {
+ | "e0": {"data": {"type": "string"},"src": "v0", "tgt": "v1"},
+ | "e1": {"data": {"type": "string"},"src": "v0", "tgt": "i0"}
+ | }
+ |}
+ """.stripMargin), thy = twoWire)
+
+ val g2 = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "wire_vertices": ["i0", "o0"],
+ | "undir_edges": {
+ | "e0": {"data": {"type": "q"},"src": "i0", "tgt": "o0"}
+ | }
+ |}
+ """.stripMargin), thy = twoWire)
+ val matches = Matcher.findMatches(g2, g)
+ assert(matches.size === 0)
+ }
+
+
+ it should "find 2x2 matches from different wire types" in {
+ val g = Graph.fromJson(Json.parse(
+ """
+ |{
+ | "node_vertices": {
+ | "v0": {"data": {"type": "Z", "value": ""}},
+ | "v1": {"data": {"type": "X", "value": ""}}
+ | },
+ | "undir_edges": {
+ | "e0": {"data": {"type": "string"},"src": "v0", "tgt": "v1"},
+ | "e1": {"data": {"type": "q"},"src": "v0", "tgt": "v1"},
+ | "e2": {"data": {"type": "string"},"src": "v0", "tgt": "v1"},
+ | "e3": {"data": {"type": "q"},"src": "v0", "tgt": "v1"}
+ | }
+ |}
+ """.stripMargin), thy = rg)
+ val matches = Matcher.findMatches(g, g)
+ assert(matches.size === 4)
+ }
+
+
+
it should "match a graph with 2 components on itself" in {
val g1 = Graph.fromJson(Json.parse(
"""
diff --git a/scala/src/test/scala/quanto/util/test/RationalMatrixSpec.scala b/scala/src/test/scala/quanto/util/test/RationalMatrixSpec.scala
index 9a39deaf..34075e80 100644
--- a/scala/src/test/scala/quanto/util/test/RationalMatrixSpec.scala
+++ b/scala/src/test/scala/quanto/util/test/RationalMatrixSpec.scala
@@ -9,13 +9,13 @@ class RationalMatrixSpec extends FlatSpec {
behavior of "A rational matrix"
it should "be constructable" in {
- val m = new RationalMatrix(Vector(Vector(1,2,3), Vector(4,5,6), Vector(7,8,9)), 3)
+ val m = new RationalMatrix(Vector(Vector(1,2,3), Vector(4,5,6), Vector(7,8,9)), 3, Some(2))
}
it should "perform gaussian elimination" in {
- val m1 = new RationalMatrix(Vector(Vector(1,2,3,4), Vector(2,2,2,2), Vector(2,1,3,1)), 3)
- val m2 = new RationalMatrix(Vector(Vector(1,2), Vector(2,2)), 2)
- val m3 = new RationalMatrix(Vector(Vector(Rational(1,5),2,3,4), Vector(2,2,2,2)), 3)
+ val m1 = new RationalMatrix(Vector(Vector(1,2,3,4), Vector(2,2,2,2), Vector(2,1,3,1)), 3, Some(2))
+ val m2 = new RationalMatrix(Vector(Vector(1,2), Vector(2,2)), 2, Some(2))
+ val m3 = new RationalMatrix(Vector(Vector(Rational(1,5),2,3,4), Vector(2,2,2,2)), 3, Some(2))
assert(!m1.isReduced)
assert(m1.gauss.get.isReduced)