diff --git a/.travis.yml b/.travis.yml index 618f2f51..7287cbb3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,6 @@ language: scala scala: -- 2.11.8 +- 2.12.4 script: "./build.sh" deploy: - provider: bintray diff --git a/bintray-release.json b/bintray-release.json index 2aca41ad..f7db64a3 100644 --- a/bintray-release.json +++ b/bintray-release.json @@ -9,7 +9,7 @@ "released": "2017-11-23" }, "files": [{ - "includePattern": "scala/target/scala-2.11/Quantomatic.jar", + "includePattern": "scala/target/scala-(.*)/Quantomatic.jar", "uploadPattern": "release-0.6.1/Quantomatic.jar", "matrixParams": {"override": 1} }], diff --git a/bintray.json b/bintray.json index 8473b0ad..112b7a43 100644 --- a/bintray.json +++ b/bintray.json @@ -1,6 +1,6 @@ { "package": { - "name": "Quantomatic", + "name": "quantomatic", "repo": "quantomatic", "subject": "quantomatic" }, @@ -8,7 +8,7 @@ "name": "bleeding-edge" }, "files": [{ - "includePattern": "scala/target/scala-2.11/Quantomatic.jar", + "includePattern": "scala/target/scala-2.12/Quantomatic.jar", "uploadPattern": "bleeding-edge/Quantomatic.jar", "matrixParams": {"override": 1} }], diff --git a/docs/json_formats.txt b/docs/json_formats.txt index 46427410..52d08fe7 100644 --- a/docs/json_formats.txt +++ b/docs/json_formats.txt @@ -134,6 +134,7 @@ VTYPE_DESC ::= "style": { "shape": "circle" | "rectangle" | "custom", "custom_shape_path": JSON_STRING, + "stroke_width": JSON_INT, "stroke_color": COLOR, "fill_color": COLOR, "label" : { @@ -209,3 +210,40 @@ STEP ::= } + +======================== +Simproc Batch Run +======================== + +File extension: .qsbr + +A simproc batch run takes a collection of simprocs and applies them to a list of graphs. The resulting derivations are then recorded and timestamped. +Notes are optional notes added to the file by the job creator. +"python" contains all the python source code of the simprocs in memory. +"selected_simprocs" contains the names of just those simprocs used in this job. + +SIMPROC_BATCH_RUN ::= +{ + "python": {SIMPROC_NAME: PYTHON_STRING (, SIMPROC_NAME: PYTHON_STRING)*}, + "selected_simprocs": Array[String], + "results": {DERIVATION_WITH_TIMINGS (, DERIVATION_WITH_TIMINGS)*}, + "notes": String +} + +DERIVATION_WITH_TIMINGS ::= +{ + "simproc" : STRING, + "derivation": DERIVATION, + "timings": {STEP_TIME(, STEP_TIME)*} +} + +STEP_TIME ::= +{ + "step": STEP_NAME, + "time": TIMESTAMP +} + +TIMESTAMP : NUMBER = time in seconds since process began + + + diff --git a/scala/build.sbt b/scala/build.sbt index 35c2e624..293c0356 100644 --- a/scala/build.sbt +++ b/scala/build.sbt @@ -2,7 +2,7 @@ name := "quanto" version := "1.0" -scalaVersion := "2.11.8" +scalaVersion := "2.12.6" scalacOptions ++= Seq("-feature", "-language:implicitConversions") @@ -14,19 +14,23 @@ fork := true resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" -libraryDependencies += "com.typesafe.akka" %% "akka-actor" % "2.4.10" withSources() withJavadoc() +libraryDependencies += "com.typesafe.akka" %% "akka-actor" % "2.5.12" withSources() withJavadoc() -libraryDependencies += "com.fasterxml.jackson.core" % "jackson-core" % "2.1.2" +libraryDependencies += "com.fasterxml.jackson.core" % "jackson-core" % "2.9.5" -libraryDependencies += "com.fasterxml.jackson.module" % "jackson-module-scala" % "2.1.2" +libraryDependencies += "com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.9.5" -libraryDependencies += "org.scalatest" % "scalatest_2.11" % "2.2.1" % "test" +//libraryDependencies += "org.scalatest" % "scalatest_2.11" % "2.2.1" % "test" -libraryDependencies += "org.scala-lang.modules" % "scala-swing_2.11" % "1.0.1" +//libraryDependencies += "org.scalactic" %% "scalactic" % "3.0.4" + +libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.5" % "test" + +libraryDependencies += "org.scala-lang.modules" %% "scala-swing" % "2.0.3" //libraryDependencies += "org.scala-lang" % "scala-compiler" % scalaVersion.value -libraryDependencies += "org.scala-lang.modules" %% "scala-parser-combinators" % "1.0.5" +libraryDependencies += "org.scala-lang.modules" %% "scala-parser-combinators" % "1.1.0" //EclipseKeys.withSource := true @@ -36,7 +40,7 @@ seq(appbundle.settings: _*) appbundle.mainClass := Some("quanto.gui.QuantoDerive") -appbundle.javaVersion := "1.7+" +appbundle.javaVersion := "1.8+" appbundle.screenMenu := true @@ -46,7 +50,7 @@ appbundle.normalizedName := "quantoderiveapp" appbundle.organization := "org.quantomatic" -appbundle.version := "0.2.0" +appbundle.version := "0.3.0" appbundle.icon := Some(file("../docs/graphics/quantoderive.icns")) diff --git a/scala/dist/mk-linux-generic.sh b/scala/dist/mk-linux-generic.sh index 41730b26..a29fcfa8 100755 --- a/scala/dist/mk-linux-generic.sh +++ b/scala/dist/mk-linux-generic.sh @@ -5,8 +5,14 @@ BUNDLE=target/QuantoDerive +# Pre-build cleanup, so we know we're building in a consistent environment. +sbt clean +rm -r ../core/heaps/* +rm -r $BUNDLE/* + # Rebuild the core heap echo Rebuilding the core heap... +mkdir -p ../core/heaps (cd ../core; ../scala/dist/linux-dist/poly --use build_heap.ML) @@ -22,7 +28,8 @@ sbt package echo Including binaries... cp -f dist/linux-dist/quanto-derive.sh $BUNDLE/ -cp -f dist/linux-dist/polybin dist/linux-dist/poly $BUNDLE/bin +cp -f dist/linux-dist/polybin $BUNDLE/bin +cp -f dist/linux-dist/poly $BUNDLE/bin cp -f dist/linux-dist/libpolyml.so.4 $BUNDLE/bin echo Including heap... @@ -36,7 +43,8 @@ echo Including jars... cp -f lib_managed/jars/*/*/akka-actor*.jar $BUNDLE/jars cp -f lib_managed/jars/*/*/scala-library*.jar $BUNDLE/jars cp -f lib_managed/jars/*/*/scala-swing*.jar $BUNDLE/jars -cp -f lib_managed/jars/*/*/jackson-core*.jar $BUNDLE/jars +cp -f lib_managed/bundles/*/*/scala-parser-combinators*.jar $BUNDLE/jars +cp -f lib_managed/bundles/*/*/jackson-core*.jar $BUNDLE/jars cp -f lib_managed/bundles/*/*/config*.jar $BUNDLE/jars # grab local dependences diff --git a/scala/lib/jedit-textArea.jar b/scala/lib/jedit-textArea.jar index d3c087a1..0bcb9541 100644 Binary files a/scala/lib/jedit-textArea.jar and b/scala/lib/jedit-textArea.jar differ diff --git a/scala/src/main/java/quanto/core/data/TexConstants.java b/scala/src/main/java/quanto/core/data/TexConstants.java index da5a92c5..6f63d6f9 100644 --- a/scala/src/main/java/quanto/core/data/TexConstants.java +++ b/scala/src/main/java/quanto/core/data/TexConstants.java @@ -48,6 +48,8 @@ private static void initialize() { c.put("Phi", "\u03a6"); c.put("Psi", "\u03a8"); c.put("Omega", "\u03a9"); + c.put("True", "\u22A4"); + c.put("False", "\u22A5"); constants = c; } diff --git a/scala/src/main/java/quanto/util/json/Json.scala b/scala/src/main/java/quanto/util/json/Json.scala index 5d2306d9..a17836a2 100644 --- a/scala/src/main/java/quanto/util/json/Json.scala +++ b/scala/src/main/java/quanto/util/json/Json.scala @@ -277,7 +277,7 @@ object Json { def this(out: java.io.OutputStream) = this(factory.createJsonGenerator(out, JsonEncoding.UTF8)) def this(out: java.io.Writer) = this(factory.createJsonGenerator(out)) - def this(f: java.io.File) = this(factory.createJsonGenerator(f, JsonEncoding.UTF8)) + def this(f: java.io.File) = this(factory.createJsonGenerator(new java.io.FileOutputStream(f), JsonEncoding.UTF8)) private var _pp = false def prettyPrint_=(b: Boolean) { diff --git a/scala/src/main/resources/quanto/data/ZW.qtheory b/scala/src/main/resources/quanto/data/ZW.qtheory new file mode 100644 index 00000000..3d6ced8d --- /dev/null +++ b/scala/src/main/resources/quanto/data/ZW.qtheory @@ -0,0 +1,74 @@ +{ + "name" : "ZW", + "core_name" : "zw", + "vertex_types" : { + "Z" : { + "value" : { + "type" : "string", + "latex_constants" : true, + "validate_with_core" : true + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 0.0, 0.0, 0.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 1.0, 1.0, 1.0 ], + "shape" : "circle" + }, + "default_data" : { + "type" : "Z", + "label" : "", + "value" : { + "pretty" : "" + } + } + }, + "W" : { + "value" : { + "type" : "string", + "latex_constants" : true, + "validate_with_core" : true + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 0.0, 0.0, 0.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 0.0, 0.0, 0.0 ], + "shape" : "circle" + }, + "default_data" : { + "type" : "W", + "label" : "", + "value" : { + "pretty" : "" + } + } + } + }, + "default_vertex_type" : "Z", + "default_edge_type" : "plain", + "edge_types" : { + "plain" : { + "value" : { + "type" : "empty", + "latex_constants" : false, + "validate_with_core" : false + }, + "style" : { + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "stroke_width" : 1, + "label" : { + "position" : "auto", + "fg_color" : [ 0.0, 0.0, 0.0 ] + } + }, + "default_data" : { + "type" : "plain" + } + } + } +} diff --git a/scala/src/main/resources/quanto/data/ZX.qtheory b/scala/src/main/resources/quanto/data/ZX.qtheory new file mode 100644 index 00000000..eb92d16f --- /dev/null +++ b/scala/src/main/resources/quanto/data/ZX.qtheory @@ -0,0 +1,110 @@ +{ + "name" : "Red/green theory", + "core_name" : "red_green", + "vertex_types" : { + "X" : { + "value" : { + "path" : "$.value", + "latex_constants" : true, + "type" : "angle_expr" + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 1.0, 1.0, 1.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 1.0, 0.0, 0.0 ], + "shape" : "circle" + }, + "default_data" : { + "type" : "X", + "value" : "" + } + }, + "Z" : { + "value" : { + "path" : "$.value", + "latex_constants" : true, + "type" : "angle_expr" + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 0.0, 0.0, 0.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 0.0, 0.8, 0.0 ], + "shape" : "circle" + }, + "default_data" : { + "type" : "Z", + "value" : "" + } + }, + "hadamard" : { + "value" : { + "path" : "$.value", + "latex_constants" : false, + "type" : "string" + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 0.0, 0.2, 0.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 1.0, 1.0, 0.0 ], + "shape" : "rectangle" + }, + "default_data" : { + "type" : "hadamard", + "value" : "" + } + }, + "var" : { + "value" : { + "path" : "$.value", + "latex_constants" : false, + "type" : "string" + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 0.0, 0.0, 0.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 0.6, 1.0, 0.8 ], + "shape" : "rectangle" + }, + "default_data" : { + "type" : "var", + "value" : "" + } + } + }, + "edge_types" : { + "string" : { + "value" : { + "path" : "$.value", + "latex_constants" : false, + "type" : "string" + }, + "style" : { + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "stroke_width" : 1, + "label" : { + "position" : "center", + "fg_color" : [ 0.0, 0.0, 1.0 ], + "bg_color" : [ 0.8, 0.8, 1.0, 0.7 ] + } + }, + "default_data" : { + "type" : "string", + "value" : "" + } + } + }, + "default_vertex_type" : "Z", + "default_edge_type" : "string" +} \ No newline at end of file diff --git a/scala/src/main/resources/quanto/data/ZXRails.qtheory b/scala/src/main/resources/quanto/data/ZXRails.qtheory new file mode 100644 index 00000000..4e17c08e --- /dev/null +++ b/scala/src/main/resources/quanto/data/ZXRails.qtheory @@ -0,0 +1,130 @@ +{ + "name" : "Red/green with rails theory", + "core_name" : "red_green_rails", + "vertex_types" : { + "X" : { + "value" : { + "path" : "$.value", + "latex_constants" : true, + "type" : "angle_expr" + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 1.0, 1.0, 1.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 1.0, 0.0, 0.0 ], + "shape" : "circle" + }, + "default_data" : { + "type" : "X", + "value" : "" + } + }, + "Z" : { + "value" : { + "path" : "$.value", + "latex_constants" : true, + "type" : "angle_expr" + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 0.0, 0.0, 0.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 0.0, 0.8, 0.0 ], + "shape" : "circle" + }, + "default_data" : { + "type" : "Z", + "value" : "" + } + }, + "hadamard" : { + "value" : { + "path" : "$.value", + "latex_constants" : false, + "type" : "string" + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 0.0, 0.2, 0.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 1.0, 1.0, 0.0 ], + "shape" : "rectangle" + }, + "default_data" : { + "type" : "hadamard", + "value" : "" + } + }, + "var" : { + "value" : { + "path" : "$.value", + "latex_constants" : false, + "type" : "string" + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 0.0, 0.0, 0.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 0.6, 1.0, 0.8 ], + "shape" : "rectangle" + }, + "default_data" : { + "type" : "var", + "value" : "" + } + } + }, + "edge_types" : { + "string" : { + "value" : { + "path" : "$.value", + "latex_constants" : false, + "type" : "string" + }, + "style" : { + "stroke_color" : [ 0.5, 0.5, 0.5 ], + "stroke_width" : 1, + "label" : { + "position" : "center", + "fg_color" : [ 0.0, 0.0, 1.0 ], + "bg_color" : [ 0.8, 0.8, 1.0, 0.7 ] + } + }, + "default_data" : { + "type" : "string", + "value" : "" + } + }, + "rail" : { + "value" : { + "path" : "$.value", + "latex_constants" : false, + "type" : "string" + }, + "style" : { + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "stroke_width" : 1, + "label" : { + "position" : "center", + "fg_color" : [ 0.0, 0.0, 1.0 ], + "bg_color" : [ 0.8, 0.8, 1.0, 0.7 ] + } + }, + "default_data" : { + "type" : "rail", + "value" : "" + } + } + }, + "default_vertex_type" : "Z", + "default_edge_type" : "rail" +} \ No newline at end of file diff --git a/scala/src/main/resources/quanto/data/composite.qtheory b/scala/src/main/resources/quanto/data/composite.qtheory new file mode 100644 index 00000000..1f5ae40e --- /dev/null +++ b/scala/src/main/resources/quanto/data/composite.qtheory @@ -0,0 +1,110 @@ +{ + "name" : "ZX with boolean data", + "core_name" : "red_green", + "vertex_types" : { + "X" : { + "value" : { + "path" : "$.value", + "latex_constants" : true, + "type" : "angle_expr, boolean" + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 1.0, 1.0, 1.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 1.0, 0.0, 0.0 ], + "shape" : "circle" + }, + "default_data" : { + "type" : "X", + "value" : "" + } + }, + "Z" : { + "value" : { + "path" : "$.value", + "latex_constants" : true, + "type" : "angle_expr, boolean" + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 0.0, 0.0, 0.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 0.0, 0.8, 0.0 ], + "shape" : "circle" + }, + "default_data" : { + "type" : "Z", + "value" : "" + } + }, + "hadamard" : { + "value" : { + "path" : "$.value", + "latex_constants" : false, + "type" : "string" + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 0.0, 0.2, 0.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 1.0, 1.0, 0.0 ], + "shape" : "rectangle" + }, + "default_data" : { + "type" : "hadamard", + "value" : "" + } + }, + "var" : { + "value" : { + "path" : "$.value", + "latex_constants" : false, + "type" : "string" + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 0.0, 0.0, 0.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 0.6, 1.0, 0.8 ], + "shape" : "rectangle" + }, + "default_data" : { + "type" : "var", + "value" : "" + } + } + }, + "edge_types" : { + "string" : { + "value" : { + "path" : "$.value", + "latex_constants" : false, + "type" : "string" + }, + "style" : { + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "stroke_width" : 1, + "label" : { + "position" : "center", + "fg_color" : [ 0.0, 0.0, 1.0 ], + "bg_color" : [ 0.8, 0.8, 1.0, 0.7 ] + } + }, + "default_data" : { + "type" : "string", + "value" : "" + } + } + }, + "default_vertex_type" : "Z", + "default_edge_type" : "string" +} diff --git a/scala/src/main/resources/quanto/data/plain.qtheory b/scala/src/main/resources/quanto/data/plain.qtheory new file mode 100644 index 00000000..b699a4b3 --- /dev/null +++ b/scala/src/main/resources/quanto/data/plain.qtheory @@ -0,0 +1,51 @@ +{ + "name" : "Plain", + "core_name" : "plain", + "vertex_types" : { + "var" : { + "value" : { + "type" : "string", + "latex_constants" : true, + "validate_with_core" : true + }, + "style" : { + "label" : { + "position" : "inside", + "fg_color" : [ 0.0, 0.0, 0.0 ] + }, + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "fill_color" : [ 1.0, 1.0, 1.0 ], + "shape" : "rectangle" + }, + "default_data" : { + "type" : "var", + "label" : "", + "value" : { + "pretty" : "" + } + } + } + }, + "default_vertex_type" : "var", + "default_edge_type" : "plain", + "edge_types" : { + "plain" : { + "value" : { + "type" : "empty", + "latex_constants" : false, + "validate_with_core" : false + }, + "style" : { + "stroke_color" : [ 0.0, 0.0, 0.0 ], + "stroke_width" : 1, + "label" : { + "position" : "auto", + "fg_color" : [ 0.0, 0.0, 0.0 ] + } + }, + "default_data" : { + "type" : "plain" + } + } + } +} diff --git a/scala/src/main/resources/quanto/gui/Mac_OS_X_keys.props b/scala/src/main/resources/quanto/gui/Mac_OS_X_keys.props new file mode 100644 index 00000000..2eeb3908 --- /dev/null +++ b/scala/src/main/resources/quanto/gui/Mac_OS_X_keys.props @@ -0,0 +1,194 @@ +action-bar.shortcut=C+ENTER +add-explicit-fold.shortcut=MC+a +add-marker-shortcut.shortcut=C+t +add-marker.shortcut=C+m +backspace-word-std.shortcut2=A+k +backspace-word.shortcut=C+BACK_SPACE +backspace.shortcut=BACK_SPACE +bottom-docking-area.shortcut=MC+DOWN +center-caret.shortcut=AC+n +clear-register.shortcut=C+r C+l +close-all.shortcut=AMC+w +close-buffer.shortcut=C+w +close-docking-area.shortcut=AC+BACK_QUOTE +closeall-bufferset.shortcut=AC+w +closeall-except-active.shortcut=CS+w +collapse-all-folds.shortcut=AS+BACK_SPACE +collapse-fold.shortcut=A+BACK_SPACE +complete-word.shortcut=C+b +copy-append-string-register.shortcut=MC+c +copy-append.shortcut=CS+c +copy-string-register.shortcut=AC+c +copy.shortcut2=C+INSERT +copy.shortcut=C+c +cut-append-string-register.shortcut=MC+x +cut-append.shortcut=CS+x +cut-string-register.shortcut=AC+x +cut.shortcut2=S+DELETE +cut.shortcut=C+x +delete-end-line.shortcut=CS+DELETE +delete-line.shortcut=C+d +delete-paragraph.shortcut2=A+h +delete-start-line.shortcut=CS+BACK_SPACE +delete-word.shortcut=C+DELETE +delete.shortcut2=A+d +delete.shortcut=DELETE +document-end.shortcut2=C+DOWN +document-end.shortcut=C+END +document-home.shortcut2=C+UP +document-home.shortcut=C+HOME +end.shortcut2=C+RIGHT +end.shortcut=END +exit.shortcut=C+q +expand-abbrev.shortcut=C+SEMICOLON +expand-fold.shortcut=AS+ENTER +expand-folds.shortcut=AC+ENTER +expand-one-level.shortcut=A+ENTER +find-next.shortcut=C+g +find-prev.shortcut=CS+g +find-previous.shortcut=C+e g +find.shortcut=C+f +focus-buffer-switcher.shortcut=A+BACK_QUOTE +global-close-buffer.shortcut=MC+w +global-options.shortcut=C+F12 +goto-line.shortcut=A+g +goto-marker.shortcut=C+y +help.shortcut=F1 +home.shortcut2=C+LEFT +home.shortcut=HOME +hypersearch-word.shortcut=A+PERIOD +hypersearch.shortcut2=MC+f +hypersearch.shortcut=C+PERIOD +ignore-case.shortcut=AC+i +indent-lines.shortcut=A+i +insert-literal.shortcut=MC+v +insert-newline-indent.shortcut=ENTER +insert-tab-indent.shortcut=TAB +invert-selection.shortcut=MC+i +last-action.shortcut=AC+SPACE +last-macro.shortcut=C+m C+l +latextools-compile.shortcut=F6 +left-docking-area.shortcut=MC+LEFT +line-comment.shortcut=AC+SLASH +line-end.shortcut=A+e +line-home.shortcut=A+a +new-file-in-mode.shortcut=CS+n +new-file.shortcut=C+n +next-bracket.shortcut=AC+CLOSE_BRACKET +next-buffer.shortcut2=A+f +next-buffer.shortcut=C+PAGE_DOWN +next-char.shortcut=RIGHT +next-fold.shortcut=A+DOWN +next-line.shortcut=DOWN +next-marker.shortcut=AC+PERIOD +next-page.shortcut2=A+v +next-page.shortcut=PAGE_DOWN +next-paragraph.shortcut=M+DOWN +next-textarea.shortcut=A+PAGE_DOWN +next-word.shortcut=M+RIGHT +open-file.shortcut=C+o +open-path.shortcut=C+e C+o +overwrite.shortcut=INSERT +page-setup.shortcut=CS+p +parent-fold.shortcut=MC+u +paste-deleted.shortcut=AC+y +paste-previous.shortcut=CS+v +paste-string-register.shortcut=AC+v +paste.shortcut2=S+INSERT +paste.shortcut=C+v +prev-bracket.shortcut=AC+OPEN_BRACKET +prev-buffer.shortcut=C+PAGE_UP +prev-char.shortcut2=A+b +prev-char.shortcut=LEFT +prev-fold.shortcut=A+UP +prev-line.shortcut=UP +prev-marker.shortcut=AC+COMMA +prev-page.shortcut=PAGE_UP +prev-paragraph.shortcut=M+UP +prev-textarea.shortcut=A+PAGE_UP +prev-word.shortcut=M+LEFT +print.shortcut=C+p +quick-search-word.shortcut=A+COMMA +quick-search.shortcut=C+COMMA +range-comment.shortcut=C+SLASH +recent-buffer.shortcut=C+BACK_QUOTE +record-macro.shortcut=AM+r +record-temp-macro.shortcut=AM+m +redo.shortcut=CS+z +reload.shortcut=F5 +remove-trailing-ws.shortcut=MC+r +replace-and-find-next.shortcut=CS+f +replace-in-selection.shortcut=CS+r +resplit.shortcut=C+4 +right-docking-area.shortcut=MC+RIGHT +run-temp-macro.shortcut=AM+p +save-all.shortcut=AC+s +save-as.shortcut=CS+s +save.shortcut=C+s +scroll-and-center.shortcut=C+l +scroll-down-page.shortcut=A+SLASH +scroll-to-current-line.shortcut2=A+l +scroll-to-current-line.shortcut=C+j +scroll-up-line.shortcut=C+QUOTE +scroll-up-page.shortcut=A+QUOTE +search-in-directory.shortcut=CS+d +search-in-open-buffers.shortcut=CS+b +select-all.shortcut=C+a +select-document-end.shortcut=CS+END +select-document-home.shortcut=CS+HOME +select-end.shortcut2=CS+RIGHT +select-end.shortcut=S+END +select-fold.shortcut=MC+s +select-home.shortcut2=CS+LEFT +select-home.shortcut=S+HOME +select-line-range.shortcut=AC+l +select-line.shortcut=MC+l +select-marker.shortcut=C+u +select-next-char.shortcut2=AS+l +select-next-char.shortcut=S+RIGHT +select-next-line.shortcut2=AS+k +select-next-line.shortcut=S+DOWN +select-next-page.shortcut2=AS+a +select-next-page.shortcut=S+PAGE_DOWN +select-next-paragraph.shortcut=MS+DOWN +select-next-word.shortcut=MS+RIGHT +select-none.shortcut=CS+a +select-paragraph.shortcut=MC+p +select-prev-char.shortcut2=AS+j +select-prev-char.shortcut=S+LEFT +select-prev-line.shortcut2=AS+i +select-prev-line.shortcut=S+UP +select-prev-page.shortcut2=AS+q +select-prev-page.shortcut=S+PAGE_UP +select-prev-paragraph.shortcut=MS+UP +select-prev-word.shortcut=MS+LEFT +shift-left.shortcut2=S+TAB +shift-left.shortcut=C+OPEN_BRACKET +shift-right.shortcut=C+CLOSE_BRACKET +show-context-menu.shortcut=CONTEXT_MENU +split-horizontal.shortcut=C+2 +split-vertical.shortcut=C+3 +stop-recording.shortcut=AM+s +swap-marker.shortcut=C+k +toggle-dock-areas.shortcut=F12 +toggle-full-screen.shortcut=F11 +toggle-line-numbers.shortcut=AC+t +toggle-multi-select.shortcut=C+BACK_SLASH +toggle-rect-select.shortcut=A+BACK_SLASH +top-docking-area.shortcut=MC+UP +undo.shortcut=C+z +unsplit-current.shortcut=C+0 +unsplit.shortcut=C+1 +vertical-paste-string-register.shortcut=AM+v +vertical-paste.shortcut=AC+p +vfs.browser.delete.shortcut=DELETE +vfs.browser.home.shortcut=~ +vfs.browser.new-directory.shortcut=INSERT +vfs.browser.new-file.shortcut=C+n +vfs.browser.next.shortcut=A+Right +vfs.browser.previous.shortcut=A+Left +vfs.browser.reload.shortcut=F5 +vfs.browser.rename.shortcut=F2 +vfs.browser.roots.shortcut=/ +vfs.browser.synchronize.shortcut=- +vfs.browser.up.shortcut=A+Up diff --git a/scala/src/main/resources/quanto/gui/add-edge.png b/scala/src/main/resources/quanto/gui/add-edge.png new file mode 100644 index 00000000..886818b1 Binary files /dev/null and b/scala/src/main/resources/quanto/gui/add-edge.png differ diff --git a/scala/src/main/resources/quanto/gui/blank.xml b/scala/src/main/resources/quanto/gui/blank.xml new file mode 100644 index 00000000..ca2c8ee1 --- /dev/null +++ b/scala/src/main/resources/quanto/gui/blank.xml @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/scala/src/main/resources/quanto/gui/expand.png b/scala/src/main/resources/quanto/gui/expand.png new file mode 100644 index 00000000..15c6bbe0 Binary files /dev/null and b/scala/src/main/resources/quanto/gui/expand.png differ diff --git a/scala/src/main/resources/quanto/gui/focus.png b/scala/src/main/resources/quanto/gui/focus.png new file mode 100644 index 00000000..42ae843f Binary files /dev/null and b/scala/src/main/resources/quanto/gui/focus.png differ diff --git a/scala/src/main/resources/quanto/gui/generic-change.svg b/scala/src/main/resources/quanto/gui/generic-change.svg new file mode 100644 index 00000000..1a49be36 --- /dev/null +++ b/scala/src/main/resources/quanto/gui/generic-change.svg @@ -0,0 +1,113 @@ + + + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + diff --git a/scala/src/main/resources/quanto/gui/jEdit_keys.props b/scala/src/main/resources/quanto/gui/jEdit_keys.props new file mode 100644 index 00000000..7b736f29 --- /dev/null +++ b/scala/src/main/resources/quanto/gui/jEdit_keys.props @@ -0,0 +1,231 @@ +#{{{ Function keys +help.shortcut=F1 +show-context-menu.shortcut=CONTEXT_MENU +#}}} + +#{{{ C+X +select-all.shortcut=C+a +complete-word.shortcut=C+b +copy.shortcut=C+c +delete-line.shortcut=C+d +# C+e is a prefix +find.shortcut=C+f +find-next.shortcut=F3 +# C+h is not usable on MacOS X +indent-lines.shortcut=C+i +join-lines.shortcut=C+j +swap-marker.shortcut=C+k +goto-line.shortcut=C+g +# C+m is a prefix +new-file.shortcut=C+n +new-file-in-mode.shortcut=CS+n +open-file.shortcut=C+o +reload.shortcut=F5 +print.shortcut=C+p +exit.shortcut=C+q +# C+r is a prefix +save.shortcut=C+s +add-marker-shortcut.shortcut=C+t +select-marker.shortcut=C+u +paste.shortcut=C+v +close-buffer.shortcut=C+w +cut.shortcut=C+x +goto-marker.shortcut=C+y +undo.shortcut=C+z +unsplit-current.shortcut=C+0 +unsplit.shortcut=C+1 +split-horizontal.shortcut=C+2 +split-vertical.shortcut=C+3 +resplit.shortcut=C+4 +#}}} + +#{{{ C+non-alpha +delete-start-line.shortcut=CS+BACK_SPACE +delete-end-line.shortcut=CS+DELETE +prev-paragraph.shortcut=C+UP +next-paragraph.shortcut=C+DOWN +select-prev-paragraph.shortcut=CS+UP +select-next-paragraph.shortcut=CS+DOWN +backspace-word.shortcut=C+BACK_SPACE +delete-word.shortcut=C+DELETE +document-home.shortcut=C+HOME +document-end.shortcut=C+END +select-document-home.shortcut=CS+HOME +select-document-end.shortcut=CS+END +prev-word.shortcut=C+LEFT +select-prev-word.shortcut=CS+LEFT +next-word.shortcut=C+RIGHT +select-next-word.shortcut=CS+RIGHT +action-bar.shortcut=C+ENTER +prev-buffer.shortcut=C+PAGE_UP +next-buffer.shortcut=C+PAGE_DOWN +last-action.shortcut=C+SPACE + +recent-buffer.shortcut=C+BACK_QUOTE +select-block.shortcut=C+OPEN_BRACKET +match-bracket.shortcut=C+CLOSE_BRACKET +expand-abbrev.shortcut=C+SEMICOLON +quick-search.shortcut=C+COMMA +hypersearch.shortcut=C+PERIOD +scroll-up-line.shortcut=C+QUOTE +scroll-down-line.shortcut=C+SLASH +toggle-multi-select.shortcut=C+BACK_SLASH +#}}} + +#{{{ C+e C+X +# Unused: f, h, q, y +copy-append.shortcut=C+e C+a +search-in-open-buffers.shortcut=C+e C+b +range-comment.shortcut=C+e C+c +search-in-directory.shortcut=C+e C+d +replace-and-find-next.shortcut=C+e C+g +ignore-case.shortcut=C+e C+i +scroll-to-current-line.shortcut=C+e C+j +line-comment.shortcut=C+e C+k +select-line-range.shortcut=C+e C+l +add-marker.shortcut=C+e C+m +center-caret.shortcut=C+e C+n +scroll-and-center.shortcut=C+l +open-path.shortcut=C+e C+o +vertical-paste.shortcut=C+e C+p +replace-in-selection.shortcut=C+e C+r +save-all.shortcut=C+e C+s +toggle-line-numbers.shortcut=C+e C+t +cut-append.shortcut=C+e C+u +paste-previous.shortcut=C+e C+v +close-all.shortcut=C+e C+w +regexp.shortcut=C+e C+x +paste-deleted.shortcut=C+e C+y +redo.shortcut=C+e C+z +#}}} + +#{{{ C+e C+non-alpha +left-docking-area.shortcut=C+e C+LEFT +top-docking-area.shortcut=C+e C+UP +right-docking-area.shortcut=C+e C+RIGHT +bottom-docking-area.shortcut=C+e C+DOWN +toggle-full-screen.shortcut=F11 +toggle-dock-areas.shortcut=F12 +combined-options.shortcut=C+F12 +prev-marker.shortcut=C+e C+COMMA +next-marker.shortcut=C+e C+PERIOD +prev-bracket.shortcut=C+e C+OPEN_BRACKET +next-bracket.shortcut=C+e C+CLOSE_BRACKET +close-docking-area.shortcut=C+e C+BACK_QUOTE +#}}} + +#{{{ C+e X +# Unused: b e g h j k m o q t y z +add-explicit-fold.shortcut=C+e a +collapse-all-folds.shortcut=C+e c +delete-paragraph.shortcut=C+e d +format-paragraph.shortcut=C+e f +find-previous.shortcut=C+e g +invert-selection.shortcut=C+e i +select-line.shortcut=C+e l +narrow-to-fold.shortcut=C+e n n +narrow-to-selection.shortcut=C+e n s +select-paragraph.shortcut=C+e p +remove-trailing-ws.shortcut=C+e r +select-fold.shortcut=C+e s +insert-literal.shortcut=C+e v +select-word.shortcut=C+e w +parent-fold.shortcut=C+e u +expand-all-folds.shortcut=C+e x +#}}} + +#{{{ C+e non-alpha +expand-folds.shortcut=C+e ENTER +#}}} + +#{{{ C+m C+X +record-temp-macro.shortcut=C+m C+m +run-temp-macro.shortcut=C+m C+p +record-macro.shortcut=C+m C+r +stop-recording.shortcut=C+m C+s +last-macro.shortcut=C+m C+l +#}}} + +#{{{ C+r C+X +copy-append-string-register.shortcut=C+r C+a +copy-string-register.shortcut=C+r C+c +clear-register.shortcut=C+r C+l +vertical-paste-string-register.shortcut=C+r C+p +cut-append-string-register.shortcut=C+r C+u +paste-string-register.shortcut=C+r C+v +cut-string-register.shortcut=C+r C+x +#}}} + +#{{{ A+non-alpha +prev-fold.shortcut=A+UP +next-fold.shortcut=A+DOWN +shift-left.shortcut=A+LEFT +shift-right.shortcut=A+RIGHT +collapse-fold.shortcut=A+BACK_SPACE +expand-fold.shortcut=AS+ENTER +expand-one-level.shortcut=A+ENTER +quick-search-word.shortcut=A+COMMA +hypersearch-word.shortcut=A+PERIOD +scroll-up-page.shortcut=A+QUOTE +scroll-down-page.shortcut=A+SLASH +prev-textarea.shortcut=A+PAGE_UP +next-textarea.shortcut=A+PAGE_DOWN + +focus-buffer-switcher.shortcut=A+BACK_QUOTE +toggle-rect-select.shortcut=A+BACK_SLASH +#}}} + +#{{{ Other keys +shift-left.shortcut2=S+TAB +select-none.shortcut=ESCAPE +backspace.shortcut=BACK_SPACE +delete.shortcut=DELETE +overwrite.shortcut=INSERT +home.shortcut=HOME +end.shortcut=END +select-home.shortcut=S+HOME +select-end.shortcut=S+END +prev-page.shortcut=PAGE_UP +next-page.shortcut=PAGE_DOWN +select-prev-page.shortcut=S+PAGE_UP +select-next-page.shortcut=S+PAGE_DOWN +prev-char.shortcut=LEFT +select-prev-char.shortcut=S+LEFT +next-char.shortcut=RIGHT +select-next-char.shortcut=S+RIGHT +prev-line.shortcut=UP +select-prev-line.shortcut=S+UP +next-line.shortcut=DOWN +select-next-line.shortcut=S+DOWN +insert-newline-indent.shortcut=ENTER +insert-newline.shortcut=S+ENTER +insert-tab-indent.shortcut=TAB +#}}} + +#{{{ Alternative shortcuts for frequently-used commands +select-next-page.shortcut2=AS+a +select-prev-line.shortcut2=AS+i +select-prev-char.shortcut2=AS+j +select-next-line.shortcut2=AS+k +select-next-char.shortcut2=AS+l +select-prev-page.shortcut2=AS+q +select-end.shortcut2=AS+x +select-home.shortcut2=AS+z +copy.shortcut2=C+INSERT +paste.shortcut2=S+INSERT +cut.shortcut2=S+DELETE +#}}} + +#{{{ VFS browser +vfs.browser.delete.shortcut=DELETE +vfs.browser.home.shortcut=~ +vfs.browser.new-directory.shortcut=INSERT +vfs.browser.new-file.shortcut=C+n +vfs.browser.reload.shortcut=F5 +vfs.browser.rename.shortcut=F2 +vfs.browser.roots.shortcut=/ +vfs.browser.synchronize.shortcut=- +vfs.browser.up.shortcut=A+Up +vfs.browser.previous.shortcut=A+Left +vfs.browser.next.shortcut=A+Right +#}}} diff --git a/scala/src/main/resources/quanto/gui/jedit.props b/scala/src/main/resources/quanto/gui/jedit.props new file mode 100644 index 00000000..2df4b86d --- /dev/null +++ b/scala/src/main/resources/quanto/gui/jedit.props @@ -0,0 +1,536 @@ +### +### jEdit global properties +### :tabSize=4:indentSize=4:noTabs=false: +### :folding=explicit:collapseFolds=1: +### :encoding=UTF-8: +### +### Copyright (C) 1998, 2003 Slava Pestov +### + +#{{{ Global settings + +# Swing look and feel +#lookAndFeel=javax.swing.plaf.metal.MetalLookAndFeel + +view.antiAlias=subpixel + +#{{{ Metal control and menu font +metal.primary.font=Dialog +metal.primary.fontsize=12 +metal.primary.fontstyle=0 +#}}} + +#{{{ Metal system and user text font +metal.secondary.font=Dialog +metal.secondary.fontsize=12 +metal.secondary.fontstyle=0 +#}}} + +#{{{ HelpViewer and Tip of the Day font +helpviewer.font=Dialog +helpviewer.fontsize=12 +helpviewer.fontstyle=0 +#}}} + + +# Decorate frames and dialogs using Swing L&F? +decorate.frames=false +decorate.dialogs=false + +# Draw multi-key shortcuts on screen menu bar? (OS X only) +menu.multiShortcut=false + +# If true, welcome screen will be displayed. +# Set to false after initial startup. +firstTime=false + +# Autosave interval in seconds, 0=off +autosave=30 + +# Autosave untitled buffers +autosaveUntitled=true + + +# Maximum number of elements in a history list +history=20 + +# Maximum size (in characters) of a history list +historyMaxSize=5000000 + +# Number of recent files +recentFiles=40 + +# Restore open files on startup? +restore=true + +# Restore even if file names specified on command line? +restore.cli=true + +# Persistent markers +persistentMarkers=true + +# Two-stage save (save to #filename#save# first, then filename) +twoStageSave=true + +# Strip trailing EOL +stripTrailingEOL=false + +# Complete words from all open buffers +completeFromAllBuffers=false + +# Insert a word completion when the corresponding digit is pressed +insertCompletionWithDigit=true + +# Need this so that new files have a trailing EOL +buffer.trailingEOL=true + +# Caret saving toggle +saveCaret=true + +# Backup on every save +backupEverySave=false + +# Number of backups to make, 0=no backups +backups=1 + +# Backup directory +backup.directory= + +# Backup filename prefix and suffix +backup.prefix= +backup.suffix=~ + +# Min time between backups (0 = always backup) +backup.minTime=0 + +# Sort buffer list +sortBuffers=true +sortByName=true + +# Sort recent file list +sortRecent=false + +# Default firewall properties +firewall.enabled=false + +# Keep dialog on by default +search.keepDialog.toggle=true + +# When this limit is reached a dialog appears to cancel the search +hypersearch.maxWarningResults=1000 + +# If the hypersearch query is longer than this value it will be truncated +# on display in the results +hypersearch.displayQueryLength=100 + +# Style for highlighting matches in hypersearch results +hypersearch.results.highlight=bgColor:#ccccff + +# Confirm dialogs +confirmSaveAll=true + +# Check modification status on focus? + +# if false and autoReload == false, do nothing +# if false and autoReload == true, autoreload silently +# If true and autoReload == false, prompt for reloading +# if true and autoReload == true, reload and notify user (message box) +autoReloadDialog=true + +# If this is true, auto reload; if false, use 'reload' button in dialog +autoReload=true + +# When to check file status on disk: 1=view focus. See GeneralOptionPane class for meanings of values. +checkFileStatus=1 + +# Encoding detectors +encodingDetectors=BOM XML-PI html python buffer-local-property + +#}}} + +# The critical size, over it a buffer will prompt when setting the edit mode +largeBufferSize=4000000 +longLineLimit=4000 +largefilemode=ask + +#{{{ Buffer settings +# These can also be specified as buffer-local properties + +# Line separator +# The OS default will be used if this is not specified +# Unix=\n +# Windows=\r\n +# MacOS=\r +#buffer.lineSeparator=\n + +# Encoding +# The OS default will be used if this is not specified +#buffer.encoding=ISO8859_1 + +# Auto-detect GZIP, UTF16, UTF8 and XML encodings? +buffer.encodingAutodetect=true + +# Tab width +buffer.tabSize=4 + +# Indent width +buffer.indentSize=4 + +# Automatic indentation +buffer.autoIndent=full + +# Soft tabs +buffer.noTabs=false + +#elastic tabstops +buffer.elasticTabstops=false + +# Default edit mode +buffer.defaultMode=text + +# Undo queue size +buffer.undoCount=100 + +# Wrap mode (none, soft, hard) +buffer.wrap=none + +# Wrap column +buffer.maxLineLen=80 + +# Word break characters +buffer.wordBreakChars= + +# Non-alphanumeric word characters +buffer.noWordSep=_ + +# Whether to separate "CamelCased" words +buffer.camelCasedWords=false + +# Fold mode (explicit, indent, or none) +buffer.folding=none + +# Folds with a level equal to or higher than this will be collapsed +buffer.collapseFolds=0 +#}}} + +#{{{ View settings + +# Apply jEdit colors to Swing text fields, text areas, lists, tables and trees? +textColors=false + +# Show toolbar? +view.showToolbar=true + +# Show searchbar? +view.showSearchbar=false + +# Show buffer switcher? +view.showBufferSwitcher=true + +# Show full path in title bar? +view.showFullPath=false + +# Font +view.font=Lucida Sans Typewriter +view.fontsize=11 +# 0=plain, 1=bold, 2=italic, 3=boldItalic +view.fontstyle=0 + +# Font Substitution +view.enableFontSubst=false +view.enableFontSubstSystemFonts=true + +# Background and foreground colors (for the text area) +view.bgColor=#ffffff +view.fgColor=#000000 + +# Line highlighting +view.lineHighlight=true +view.lineHighlightColor=#ffffe0 + +# Bracket highlighting +view.structureHighlight=true +view.structureHighlightColor=#000000 + +# EOL markers +view.eolMarkers=false +view.eolMarkerColor=#ff6633 +view.eolMarkerChar=↩ + +# Wrap guide +view.wrapGuide=true +view.wrapGuideColor=#8080ff + +# page breaks +view.pageBreaks=false +view.pageBreaksColor=#8080ff + +# Caret color +view.caretColor=#ff0000 + +# Selection color +view.selectionColor=#ccccff +view.multipleSelectionColor=#ccffcc + +# Caret blinking +view.caretBlink=true + +# Block caret +view.blockCaret=false + +# Electric borders +view.electricBorders=3 + +# Drag and drop of text +view.dragAndDrop=true + +# Abbreviate paths using environment variables +view.abbreviatePaths=true + +# Treat consecutive non-alphanumeric characters as one word +view.joinNonWordChars=true + +# Middle mouse button pastes % register +view.middleMousePaste=false + +# Pressing Ctrl while mouse actions makes them +# as if selection mode were rectangular mode +view.ctrlForRectangularSelection=true + +# Minimal view size that is considered "valid" when loading perspective +view.minStartupWidth=200 +view.minStartupHeight=200 + +Combined\ Options.width=1024 +Combined\ Options.height=768 + +# The default bufferSet scope +editpane.bufferset.default=view + +#{{{ Gutter + +# Gutter background color +view.gutter.bgColor=#dbdbdb + +# Gutter foreground color +view.gutter.fgColor=#000000 + +# Gutter highlight color +view.gutter.highlightColor=#990066 + +# Gutter current line highlight color +view.gutter.currentLineColor=#ff0033 + +# Gutter bracket highlighting +view.gutter.structureHighlight=true +view.gutter.structureHighlightColor=#666699 + +# Gutter marker highlight color +view.gutter.markerColor=#ccffcc + +# Gutter border colors +view.gutter.focusBorderColor=#990099 +view.gutter.noFocusBorderColor=#ffffff + +# Fold triangle color +view.gutter.foldColor=#838383 + +# Gutter font name +view.gutter.font=Monospaced + +# Gutter font style +# 0=plain, 1=bold, 2=italic, 3=boldItalic +view.gutter.fontstyle=0 + +# Gutter font size +view.gutter.fontsize=10 + +# Gutter border width +view.gutter.borderWidth=2 + +# Gutter displays line numbers +view.gutter.lineNumbers=true + +# Line numbers are drawn with this alignment (left, center, right) +view.gutter.numberAlignment=right + +# Gutter line numbers are highlighted at this interval +view.gutter.highlightInterval=5 + +# Gutter current line highlighting +view.gutter.highlightCurrentLine=true + +# Marker highlight +view.gutter.markerHighlight=true + +# Click behavior +view.gutter.foldClick=toggle-fold +view.gutter.SfoldClick=toggle-fold-fully +view.gutter.CfoldClick=select-fold +view.gutter.AfoldClick=narrow-fold +view.gutter.structClick=match-struct +view.gutter.CstructClick=select-struct +view.gutter.AstructClick=narrow-struct + +# Show gutter? +view.gutter.enabled=true + +# Gutter minimal number of digits to reserve for line numbers +view.gutter.minDigitCount=2 + +# Show selection area in gutter? +view.gutter.selectionAreaEnabled=true +# Gutter selection area background color +# - leave out so by default it is the same as the gutter's +#view.gutter.selectionAreaBgColor=#dbdbdb +#}}} + +# Expand abbrevs when space bar pressed +view.expandOnInput=false + +#{{{ Syntax styles +view.style.comment1=color:#cc0000 +view.style.comment2=color:#ff8400 +view.style.comment3=color:#6600cc +view.style.comment4=color:#cc6600 +view.style.digit=color:#ff0000 +view.style.foldLine.0=color:#000000 bgColor:#dafeda style:b +view.style.foldLine.1=color:#000000 bgColor:#fff0cc style:b +view.style.foldLine.2=color:#000000 bgColor:#e7e7ff style:b +view.style.foldLine.3=color:#000000 bgColor:#ffe0f0 style:b +view.style.function=color:#9966ff +view.style.invalid=color:#ff0066 bgColor:#ffffcc +view.style.keyword1=color:#006699 style:b +view.style.keyword2=color:#009966 style:b +view.style.keyword3=color:#0099ff style:b +view.style.keyword4=color:#66ccff style:b +view.style.label=color:#02b902 +view.style.literal1=color:#ff00cc +view.style.literal2=color:#cc00cc +view.style.literal3=color:#9900cc +view.style.literal4=color:#6600cc +view.style.markup=color:#0000ff +view.style.operator=color:#000000 style:b +#}}} + +# Docking and tool bar positioning +view.docking.alternateLayout=false +# "alternate" is actually closer to standard location IMHO... +view.toolbar.alternateLayout=true + +#{{{ Status bar +view.status.foreground=black +view.status.background=white +view.status.visible=true +view.status.plainview.visible=false +view.status.show-caret-status=true +view.status.memory.foreground=#cccccc +view.status.memory.background=#666699a +#}}} + +#}}} + +#{{{ Printing settings + +# Font +print.font=Monospaced +print.fontsize=9 +print.fontstyle=0 + +# Print header? +print.header=true + +# Print footer? +print.footer=true + +# Print line numbers? +print.lineNumbers=true + +# Print in color, or black and white? +print.color=false + +# Print tab size +print.tabSize=2 + +# Use old (JDK 1.3) printing API +print.force13=false + +# Force use of glyph vectors to work around spacing problems +print.glyphVector=false +#}}} + +#{{{ File system browser settings +vfs.browser.showMenubar=true +vfs.browser.showToolbar=true +vfs.browser.showIcons=true +vfs.browser.showHiddenFiles=false +vfs.browser.sortMixFilesAndDirs=false +vfs.browser.sortIgnoreCase=true +vfs.browser.doubleClickClose=false +vfs.browser.currentBufferFilter=false +vfs.browser.useDefaultIcons=true + +# Can be one of: buffer, home, favorites, last +vfs.browser.defaultPath=buffer + +# File list coloring +vfs.browser.colorize=true + +vfs.browser.colors.0.glob={CVS,#*,*~,\\.*} +vfs.browser.colors.0.color=#a0a0a0 +vfs.browser.colors.1.glob=*.class +vfs.browser.colors.1.color=#660066 +vfs.browser.colors.2.glob={build.xml,makefile*,*.hxml} +vfs.browser.colors.2.color=#666600 +vfs.browser.colors.3.glob=*.{gif,jpg,jpeg,png,bmp,xpm,svg} +vfs.browser.colors.3.color=#009933 +vfs.browser.colors.4.glob=*.{gz,jar,zip,tgz,z,war,ear} +vfs.browser.colors.4.color=#990000 +vfs.browser.colors.5.glob=tags +vfs.browser.colors.5.color=#003366 +vfs.browser.colors.6.glob=*.{sh,bat,cmd} +vfs.browser.colors.6.color=#006666 +vfs.browser.colors.7.glob={CHANGELOG,CHANGES,INSTALL,LICENSE,NEWS,README,TODO}{,.txt} +vfs.browser.colors.7.color=#330066 +vfs.browser.colors.8.glob=*.{props,properties} +vfs.browser.colors.8.color=#666666 + +browser.custom.context= +#}}} + +#{{{ Plugin manager settings +plugin-manager.downloadSource=false +plugin-manager.installUser=true +plugin-manager.showAll=true +plugin-manager.mirror.id=NONE +plugin-manager.deleteDownloads=true +plugin-manager.hide-libraries.toggle=true +plugin-blacklist.MacOS.jar=true +#}}} + +#{{{ Hidden settings +menu.spillover=20 +bufferSwitcher.maxRowCount=10 +showTooltips=true +ioThreadCount=4 +server.brokenToFront=false +search.dontSyncFilter=false +#}}} + +#{{{ Miscellaneous settings +# restore remote VFS files by default (new option) +options.general.restore.remote=true +optional.title-template={0} {1} +mime2mode.text/html=html +debug.beepOnOutput=false +#}}} + +#{{{ Keymaps +# the current keymap name +keymap.current=jEdit +# the default keymap name +keymap.default=jEdit +#}}} + +# if lang.usedefaultlocale is true, the lang.current is not used +lang.usedefaultlocale=true diff --git a/scala/src/main/resources/quanto/gui/normalise.png b/scala/src/main/resources/quanto/gui/normalise.png new file mode 100644 index 00000000..6e6c6547 Binary files /dev/null and b/scala/src/main/resources/quanto/gui/normalise.png differ diff --git a/scala/src/main/resources/quanto/gui/python.xml b/scala/src/main/resources/quanto/gui/python.xml index 1e97faa3..7d36a94b 100644 --- a/scala/src/main/resources/quanto/gui/python.xml +++ b/scala/src/main/resources/quanto/gui/python.xml @@ -356,6 +356,26 @@ __truediv__ __version__ __xor__ + + + new_graph_from_json + vertex_angle_is + EMPTY + JSON_REWRITE + REWRITE + REWRITE_METRIC + REWRITE_METRIC_TO + REWRITE_WEAK_METRIC + REWRITE_TARGETED + ANNEAL + LOG + REWRITE_TARGET_LIST + REPEAT + REDUCE + REDUCE_METRIC + REDUCE_METRIC_TO + REDUCE_WEAK_METRIC + REDUCE_TARGETED diff --git a/scala/src/main/resources/quanto/gui/quantoderive.ico b/scala/src/main/resources/quanto/gui/quantoderive.ico new file mode 100644 index 00000000..4dd67f50 Binary files /dev/null and b/scala/src/main/resources/quanto/gui/quantoderive.ico differ diff --git a/scala/src/main/scala/quanto/cosy/BlockEnumeration.scala b/scala/src/main/scala/quanto/cosy/BlockEnumeration.scala index 3517212a..26d1edcc 100644 --- a/scala/src/main/scala/quanto/cosy/BlockEnumeration.scala +++ b/scala/src/main/scala/quanto/cosy/BlockEnumeration.scala @@ -1,7 +1,9 @@ package quanto.cosy +import quanto.cosy.BlockGenerators.QuickGraph +import quanto.data.Names._ +import quanto.data.{VName, _} import quanto.util.json._ -import quanto.data._ /** * Enumerates diagrams by composing simple building blocks in a 2D fashion @@ -10,12 +12,13 @@ class BlockEnumeration { } -case class Block(inputs: List[Int], outputs: List[Int], name: String, tensor: Tensor) { +case class Block(inputs: List[Int], outputs: List[Int], name: String, tensor: Tensor, graph: Graph = new Graph()) { lazy val toJson = JsonObject( "inputs" -> inputs, "outputs" -> outputs, "name" -> name, - "tensor" -> tensor.toJson + "tensor" -> tensor.toJson, + "graph" -> graph.toJson() ) override def toString: String = this.name @@ -26,11 +29,17 @@ object Block { new Block((js / "inputs").asArray.map(x => x.intValue).toList, (js / "outputs").asArray.map(x => x.intValue).toList, (js / "name").stringValue, - Tensor.fromJson((js / "tensor").asObject)) + Tensor.fromJson((js / "tensor").asObject), + Graph.fromJson((js / "graph").asObject)) } + + implicit def toTensor(b: Block) : Tensor = b.tensor } -class BlockRow(val blocks: List[Block], suggestTensor: Option[Tensor] = None) { + +case class BlockRow(blocks: List[Block], suggestTensor: Option[Tensor] = None, suggestGraph: Option[Graph] = None) { + + implicit def optionTensor(t: Tensor): Option[Tensor] = { Option(t) } @@ -42,6 +51,13 @@ class BlockRow(val blocks: List[Block], suggestTensor: Option[Tensor] = None) { case None => blocks.foldLeft(Tensor.id(1))((a, b) => a x b.tensor) } + lazy val graph: Graph = suggestGraph match { + case Some(g) => g + case None => blocks.foldLeft(new Graph()) { (g, b) => + BlockRow.graphsSideBySide(g, b.graph) + } + } + lazy val toJson = JsonObject( "blocks" -> JsonArray(blocks.map(b => b.toJson)), "inputs" -> inputs, @@ -56,13 +72,46 @@ object BlockRow { def fromJson(js: JsonObject): BlockRow = { new BlockRow((js / "blocks").asArray.map(j => Block.fromJson(j.asObject)).toList) } + + + def graphsSideBySide(fixed: Graph, added: Graph): Graph = { + + + val startingInputs = fixed.verts.filter(vn => vn.prefix == "i-") + val startingOutputs = fixed.verts.filter(vn => vn.prefix == "o-") + val shift = math.max(startingInputs.size, startingOutputs.size) + + val aShifted: Graph = added.verts.foldLeft(added)((g, vn) => g.updateVData(vn)(vd => { + val currentCoord = added.vdata(vn).coord + added.vdata(vn).withCoord(currentCoord._1 + shift, currentCoord._2) + } + )) + + val renameMap = aShifted.verts.flatMap(vn => (vn.prefix, vn.suffix) match { + case ("i-", n) => Some(vn -> VName("i-" + (n + startingInputs.size))) + case ("o-", n) => Some(vn -> VName("o-" + (n + startingOutputs.size))) + case (a, b) => Some(vn -> VName("bl-" + shift + "-" + a + b)) + case _ => None + }).toMap + + val aShiftedRename = aShifted.rename(vrn = renameMap, ern = Map(), brn = Map()) + fixed.appendGraph(aShiftedRename.renameAvoiding(fixed), false) + } } -class BlockStack(val rows: List[BlockRow]) extends Ordered[BlockStack] { +case class BlockStack(rows: List[BlockRow], + suggestedTensor: Option[Tensor] = None, + suggestedGraph: Option[Graph] = None) extends Ordered[BlockStack] { + + require(rows.flatMap(_.blocks).nonEmpty) // require at least one block + lazy val tensor: Tensor = if (rows.isEmpty) { Tensor.id(1) } else { - rows.foldRight(Tensor.id(rows.last.tensor.width))((a, b) => a.tensor o b) + suggestedTensor match { + case Some(t) => t + case None => rows.foldRight(Tensor.id(rows.last.tensor.width))((a, b) => a.tensor o b) + } } lazy val toJson = JsonObject( "rows" -> JsonArray(rows.map(b => b.toJson)), @@ -79,197 +128,90 @@ class BlockStack(val rows: List[BlockRow]) extends Ordered[BlockStack] { this.rows.length - that.asInstanceOf[BlockStack].rows.length } -} - -object BlockStack { - def fromJson(js: JsonObject): BlockStack = { - new BlockStack((js / "rows").asArray.map(j => BlockRow.fromJson(j.asObject)).toList) - } -} - -object BlockRowMaker { - implicit def quickList(n: Int): List[Int] = { - n match { - case 0 => List() - case 1 => List(0) - case m => quickList(m - 1) ::: List(0) + //left-most row is on top! + lazy val graph: Graph = { + suggestedGraph match { + case Some(g) => g + case None => + val blocks = rows.flatMap(row => row.blocks) + if (blocks.isEmpty) { + new Graph() + } else { + val gdata = blocks.head.graph.data + BlockStack.joinRowsInStack( + rows.reverse.zipWithIndex.foldRight(new Graph(data = gdata))((ri, g) => + BlockStack.graphStackUnjoined(g, ri._1.graph, ri._2))) + + } } - } - - val BellTeleportation: List[Block] = List( - Block(List(0), List(0), " A ", Tensor.id(2)), - Block(List(-1), List(-1), " B ", Tensor.id(2)), - Block(List(0, 0), List(1), " m1", Tensor(Array(Array(1, 0, 0, 1)))), - Block(List(0, 0), List(2), " m2", Tensor(Array(Array(1, 0, 0, -1)))), - Block(List(0, 0), List(3), " m3", Tensor(Array(Array(0, 1, 1, 0)))), - Block(List(0, 0), List(4), " m4", Tensor(Array(Array(0, 1, -1, 0)))), - Block(List(-1, 1), List(-1), " c1", Tensor(Array(Array(1, 0), Array(0, 1)))), - Block(List(-1, 2), List(-1), " c2", Tensor(Array(Array(1, 0), Array(0, -1)))), - Block(List(-1, 3), List(-1), " c3", Tensor(Array(Array(0, 1), Array(1, 0)))), - Block(List(-1, 4), List(-1), " c4", Tensor(Array(Array(0, 1), Array(-1, 0)))), - Block(List(), List(0, -1), " p ", Tensor(Array(Array(1, 0, 0, 1))).transpose) - ) ::: - swapQuantumClassical(List(0, -1), Tensor.id(2), List(1, 2, 3, 4)) ::: - makeClassicalIdentites(List(1, 2, 3, 4)) - val ZW: List[Block] = List( - // BOTTOM TO TOP! - Block(1, 1, " 1 ", Tensor.idWires(1)), - Block(2, 2, " s ", Tensor.swap(List(1, 0))), - Block(2, 2, "crs", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 1, 0), Array(0, 1, 0, 0), Array(0, 0, 0, -1)))), - Block(0, 2, "cup", Tensor(Array(Array(1, 0, 0, 1))).transpose), - Block(2, 0, "cap", Tensor(Array(Array(1, 0, 0, 1)))), - Block(1, 1, " w ", Tensor(Array(Array(1, 0), Array(0, -1)))), - Block(1, 2, "1w2", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 0, -1))).transpose), - Block(2, 1, "2w1", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 0, -1)))), - Block(1, 1, " b ", Tensor(Array(Array(0, 1), Array(1, 0)))), - Block(1, 2, "1b2", Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 1))).transpose), - Block(2, 1, "2b1", Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 1)))) - ) - def swapQuantumClassical(listQuantum: List[Int], quantumTensor: Tensor, listClassical: List[Int]): List[Block] = { - (for (w1 <- listQuantum; w2 <- listClassical) yield { - List(Block(List(w1, w2), List(w2, w1), w1 + "s" + w2, quantumTensor), - Block(List(w2, w1), List(w1, w2), w2 + "s" + w1, quantumTensor)) - }).flatten } - def swapQuantumQuantum(listQuantum: List[Int], quantumTensor: Tensor): List[Block] = { - (for (w1 <- listQuantum; w2 <- listQuantum) yield { - List(Block(List(w1, w2), List(w2, w1), w1 + "s" + w2, quantumTensor), - Block(List(w2, w1), List(w1, w2), w2 + "s" + w1, quantumTensor)) - }).flatten + // newly added row is on top! + // And our graphs go bottom to top + // Forces computation of tensor and graph (as that's what this is designed for) + def append(row: BlockRow): BlockStack = { + //TODO: Make this parallel? + val newTensor = if (rows.nonEmpty) { + require(rows.head.outputs == row.inputs) + row.tensor o tensor + } else { + row.tensor + } + val newRows = row :: rows + val newGraph = BlockStack.joinRowsInStack(BlockStack.graphStackUnjoined(graph, row.graph, rows.length)) + new BlockStack(newRows, Some(newTensor), Some(newGraph)) } - def makeClassicalIdentites(listClassical: List[Int]): List[Block] = { - for (w <- listClassical) yield { - Block(List(w), List(w), "w" + w + " ", Tensor.id(1)) - } +} + +object BlockStack { + def fromJson(js: JsonObject): BlockStack = { + new BlockStack((js / "rows").asArray.map(j => BlockRow.fromJson(j.asObject)).toList) } - // Traditionally the number of angles is 3 (Clifford) or 9 (Clifford+T) - def ZXQudit(dimension: Int, numAngles: Int): List[Block] = { - require(dimension > 1) - def swapIndex(i: Int): Int = { - val left: Int = i / dimension - val right = i % dimension - right * dimension + left + def joinRowsInStack(graph: Graph): Graph = { + var g = QuickGraph(graph) + val InputPattern = raw"r-(\d+)-i-(\d+)".r + g.verts.foreach(vName => vName.s match { + case InputPattern(n, m) => + // For 0.toString is coming out as "" not "0", but this shouldn't affect us + if (g.verts.contains(s"r-${Integer.parseInt(n) - 1}-o-$m")) { + g = g.joinIfNotAlready(s"r-${Integer.parseInt(n) - 1}-o-$m", s"r-$n-i-$m", Some("rail")) + } + case _ => g } - - val H: Tensor = Hadamard(dimension) - - val greenFork = Tensor(dimension, dimension * dimension, - (i, j) => if (j == i * (dimension + 1)) Complex.one else Complex.zero) - - // Go through the diagonal entries creating all the different spiders - val greenBlocks = (1 until dimension).foldLeft( - List(Block(1, 1, "g", Tensor(dimension, dimension, (i, j) => if (i == 0 && j == 0) Complex.one else Complex.zero))) - )((lb, i) => lb.flatMap(b => (0 until numAngles).map(x => - Block(1, 1, b.name + "|" + x, b.tensor + Tensor(dimension, dimension, (j, k) => - if (j == i && k == i) ei(x * 2 * math.Pi / numAngles) else Complex.zero) - )))) - - List( - Block(1, 1, " 1 ", Tensor.id(dimension)), - Block(2, 2, " s ", Tensor.permutation((0 until dimension * dimension).toList.map(x => swapIndex(x)))), - Block(1, 1, " H ", H), - Block(1, 1, " H'", H.dagger), - Block(2, 1, "2g1", greenFork), - Block(1, 2, "1g2", greenFork.dagger), - Block(0, 1, "gu ", Tensor(dimension, 1, (i, j) => Complex.one).scaled(1.0 / math.sqrt(dimension))), - Block(1, 2, "1r2", (H.dagger o greenFork o (H x H)).dagger), - Block(2, 1, "2r1", H.dagger o greenFork o (H x H)), - Block(0, 1, "ru ", Tensor(dimension, 1, (i, j) => if (i == 0 && j == 0) Complex.one else Complex.zero)) - ) ::: greenBlocks ::: greenBlocks.map(b => - Block(1, 1, "r" + b.name.tail, H.dagger o b.tensor o H) ) + g } - // Traditionally the number of angles is 3 (Clifford) or 9 (Clifford+T) - def ZXQutrit(numAngles: Int = 9): List[Block] = { - val H3 = Hadamard(3) - List( - Block(1, 1, " 1 ", Tensor.id(3)), - Block(2, 2, " s ", Tensor.permutation(List(0, 3, 6, 1, 4, 7, 2, 5, 8))), - Block(1, 1, " H ", H3), - Block(1, 1, " H'", H3.dagger), - Block(2, 1, "2g1", Tensor(Array( - Array(1, 0, 0, 0, 0, 0, 0, 0, 0), - Array(0, 0, 0, 0, 1, 0, 0, 0, 0), - Array(0, 0, 0, 0, 0, 0, 0, 0, 1) - ))), - Block(1, 2, "1g2", Tensor(Array( - Array(1, 0, 0, 0, 0, 0, 0, 0, 0), - Array(0, 0, 0, 0, 1, 0, 0, 0, 0), - Array(0, 0, 0, 0, 0, 0, 0, 0, 1) - )).transpose), - Block(0, 1, "gu ", Tensor(Array(Array(1, 1, 1))).scaled(1.0 / math.sqrt(3)).transpose), - Block(1, 2, "1r2", (H3.dagger o Tensor(Array( - Array(1, 0, 0, 0, 0, 0, 0, 0, 0), - Array(0, 0, 0, 0, 1, 0, 0, 0, 0), - Array(0, 0, 0, 0, 0, 0, 0, 0, 1) - )) o (H3 x H3)).dagger), - Block(2, 1, "2r1", H3.dagger o Tensor(Array( - Array(1, 0, 0, 0, 0, 0, 0, 0, 0), - Array(0, 0, 0, 0, 1, 0, 0, 0, 0), - Array(0, 0, 0, 0, 0, 0, 0, 0, 1) - )) o (H3 x H3)), - Block(0, 1, "ru ", Tensor(Array(Array(1, 0, 0))).transpose) - ) ::: - (for (i <- 0 until numAngles; j <- 0 until numAngles) yield { - val gs = Tensor(Array( - Array[Complex](1, 0, 0), - Array[Complex](0, ei(i * 2 * math.Pi / numAngles), 0), - Array[Complex](0, 0, ei(j * 2 * math.Pi / numAngles)) - )) - List(Block(1, 1, "g|" + i.toString + "|" + j.toString, gs), - Block(1, 1, "r|" + i.toString + "|" + j.toString, H3.dagger o gs o H3)) - }).flatten.toList + def graphStackUnjoined(fixed: Graph, adding: Graph, depth: Int): Graph = { + val renameMap = adding.verts.map(vn => vn -> VName(s"r-$depth-${vn.s}")).toMap + val aRenamed = adding.rename(vrn = renameMap, ern = Map(), brn = Map()) + val aRenamedShifted = aRenamed.verts.foldLeft(aRenamed)((g, vn) => g.updateVData(vn)(vd => { + val currentCoord = aRenamed.vdata(vn).coord + aRenamed.vdata(vn).withCoord(currentCoord._1, currentCoord._2 + depth) + } + )) + + if(fixed.data.theory != adding.data.theory){ + fixed.copy(data = fixed.data.copy(theory = fixed.data.theory.mixin(adding.data.theory, None))) + .appendGraph(aRenamedShifted.renameAvoiding(fixed), noOverlap = false) + }else{ + fixed + .appendGraph(aRenamedShifted.renameAvoiding(fixed), noOverlap = false) + } } +} - def Hadamard(dimension: Int): Tensor = - Tensor(dimension, dimension, (i, j) => ei(2 * math.Pi * i * j / dimension)).scaled(1 / math.sqrt(dimension)) - - private def ei(angle: Double) = Complex(math.cos(angle), math.sin(angle)) - - def ZX(numAngles: Int = 8): List[Block] = List( - Block(1, 1, " 1 ", Tensor.idWires(1)), - Block(2, 2, " s ", Tensor.swap(List(1, 0))), - Block(0, 2, "cup", Tensor(Array(Array(1, 0, 0, 1))).transpose), - Block(2, 0, "cap", Tensor(Array(Array(1, 0, 0, 1))))) ::: - (for (i <- 0 until numAngles) yield { - Block(1, 1, "gT" + i.toString, Tensor(Array( - Array(Complex.one, Complex.zero), - Array(Complex.zero, ei(2 * i * math.Pi / numAngles))))) - }).toList ::: - (for (i <- 0 until numAngles) yield { - Block(1, 1, "rT" + i.toString, new Tensor(Array( - Array(1 + ei(2 * i * math.Pi / numAngles), 1 - ei(2 * i * math.Pi / numAngles)), - Array(1 - ei(2 * i * math.Pi / numAngles), 1 + ei(2 * i * math.Pi / numAngles))))) - }).toList ::: - List( - Block(1, 1, " H ", Tensor(Array(Array(1, 1), Array(1, -1))).scaled(1.0 / math.sqrt(2))), - Block(2, 1, "2g1", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 0, 1)))), - Block(1, 2, "1g2", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 0, 1))).transpose), - Block(0, 1, "gu ", Tensor(Array(Array(1, 1))).transpose), - Block(1, 2, "1r2", Tensor(Array(Array(1, 0, 0, 1), Array(0, 1, 1, 0))).transpose), - Block(2, 1, "2r1", Tensor(Array(Array(1, 0, 0, 1), Array(0, 1, 1, 0)))), - Block(0, 1, "ru ", Tensor(Array(Array(1, 0))).transpose) - ) +object BlockRowMaker { - def Bian2Qubit: List[Block] = List( - // Block(0, 0, " w ", Tensor.id(1).scaled(ei(math.Pi / 4))), Ignored for now. - Block(1, 1, " 1 ", Tensor.id(2)), - Block(2, 2, " Zc", Tensor.diagonal(Array(Complex.one, Complex.one, Complex.one, Complex.zero.-(Complex.one)))), - Block(1, 1, " T ", Tensor.diagonal(Array(Complex.one, ei(math.Pi / 4)))), - Block(1, 1, " H ", Tensor(Array(Array(1, 1), Array(1, -1))).scaled(1.0 / math.sqrt(2))), - Block(1, 1, " S ", Tensor.diagonal(Array(Complex.one, ei(math.Pi / 2)))) - ) - def stackToGraph(stack: BlockStack, blockToGraph: Block => Graph): Graph = { + def predicateStackToGraph(stack: BlockStack, blockToGraph: Block => Graph): Graph = { var g = (for ((row, index) <- stack.rows.zipWithIndex) yield { - val rowGraph = rowToGraph(row, blockToGraph) + val rowGraph = predicateRowToGraph(row, blockToGraph) val verts = rowGraph.verts.toList val vRename = verts.map(v => v -> VName("r" + index + v.s)).toMap val eRename = rowGraph.edges.map(e => e -> EName("r" + index + e.s)).toMap @@ -278,13 +220,17 @@ object BlockRowMaker { }).foldLeft(new Graph())((g, sg) => g.appendGraph(sg)) for ((row, index) <- stack.rows.init.zipWithIndex) { for (j <- row.outputs.indices) { - g = g.addEdge(EName("r" + index + "e" + j), UndirEdge(), VName("r" + index + "o" + j) -> VName("r" + (index + 1) + "i" + j)) + g = g.addEdge( + EName("r" + index + "e" + j), + UndirEdge(), + VName("r" + index + "o" + j) -> VName("r" + (index + 1) + "i" + j) + ) } } g } - def rowToGraph(row: BlockRow, blockToGraph: Block => Graph): Graph = { + def predicateRowToGraph(row: BlockRow, blockToGraph: Block => Graph): Graph = { var inputsCovered = 0 var outputsCovered = 0 val inputRegex = raw"i(\d+)".r @@ -305,81 +251,23 @@ object BlockRowMaker { }).foldLeft(new Graph())((g, bg) => g.appendGraph(bg)) } - - def Bian2QubitToGraph(block: Block): Graph = { - - // The graph produced must be 0-indexed on inputs and outputs, and of the form /i\d+/ and /o\d+/ - - implicit def vname(str: String): VName = VName(str) - - implicit def vnamepair(p: (String, String)): (VName, VName) = VName(p._1) -> VName(p._2) - - implicit def ename(str: String): EName = EName(str) - - val rg = Theory.fromFile("red_green") - - var g = new Graph() - - var eCount = 0 - - def join(v0: String, v1: String): Unit = { - g = g.addEdge("e" + eCount, UndirEdge(), v0 -> v1) - eCount += 1 - } - - def addVertex(name: String, data: VData): Unit = { - g = g.addVertex(name, data) - } - - for (i <- block.inputs.indices) { - addVertex("i" + i, WireV()) - } - for (i <- block.outputs.indices) { - addVertex("o" + i, WireV()) - } - block.name match { - case " 1 " => - join("i0", "o0") - case " T " => - addVertex("v0", NodeV(data = JsonObject("type" -> "X", "value" -> "pi/4"), theory = rg)) - join("i0", "v0") - join("v0", "o0") - case " H " => - addVertex("v0", NodeV(data = JsonObject("type" -> "hadamard", "value" -> "0"), theory = rg)) - join("i0", "v0") - join("v0", "o0") - case " S " => - addVertex("v0", NodeV(data = JsonObject("type" -> "X", "value" -> "pi/2"), theory = rg)) - join("i0", "v0") - join("v0", "o0") - case " Zc" => - addVertex("v0", NodeV(data = JsonObject("type" -> "X", "value" -> "0"), theory = rg)) - addVertex("v1", NodeV(data = JsonObject("type" -> "Z", "value" -> "0"), theory = rg)) - join("v0", "v1") - join("i0", "v0") - join("v0", "o0") - join("i1", "v1") - join("v1", "o1") - } - - g - } - def apply(maxBlocks: Int, allowedBlocks: List[Block], maxInOut: Option[Int] = None): List[BlockRow] = { require(maxBlocks >= 0) - (for (i <- 0 to maxBlocks) yield { - makeRowsOfSize(i, allowedBlocks, maxInOut) - }).flatten.toList + makeRowsUpToSize(maxBlocks, allowedBlocks, maxInOut) } - def makeRowsOfSize(size: Int, - allowedBlocks: List[Block], - maxInOut: Option[Int] = None): List[BlockRow] = { + def makeRowsUpToSize(size: Int, + allowedBlocks: List[Block], + maxInOut: Option[Int] = None): List[BlockRow] = { val maybeTooLargeRows: List[BlockRow] = size match { case 0 => List[BlockRow]() case 1 => allowedBlocks.map(b => new BlockRow(List(b))) - case n => for (base <- makeRowsOfSize(n - 1, allowedBlocks, maxInOut); block <- allowedBlocks) yield { - new BlockRow(block :: base.blocks) + case n => { + val fewerBlocks = makeRowsUpToSize(n - 1, allowedBlocks, maxInOut) + + (for (base <- fewerBlocks; block <- allowedBlocks) yield { + new BlockRow(block :: base.blocks) + }) ::: fewerBlocks } } maxInOut match { @@ -387,6 +275,8 @@ object BlockRowMaker { case Some(max) => maybeTooLargeRows.filter(r => (r.inputs.length <= max) && (r.outputs.length <= max)) } } + + } object BlockStackMaker { diff --git a/scala/src/main/scala/quanto/cosy/BlockGenerators.scala b/scala/src/main/scala/quanto/cosy/BlockGenerators.scala new file mode 100644 index 00000000..2743d8d3 --- /dev/null +++ b/scala/src/main/scala/quanto/cosy/BlockGenerators.scala @@ -0,0 +1,444 @@ +package quanto.cosy + +import quanto.data._ +import quanto.util.json.JsonObject +import quanto.data.Names._ + +import scala.util.matching.Regex + +object BlockGenerators { + + + implicit def quickList(n: Int): List[Int] = { + n match { + case 0 => List() + case 1 => List(0) + case m => quickList(m - 1) ::: List(0) + } + } + + + class QuickGraph(graph: Graph) { + val _g : Graph = graph + def node(nodeType: String, angle: String = "", xCoord : Double = 0, nodeName : String = "v-0") : QuickGraph = { + val name = _g.verts.freshWithSuggestion(VName(nodeName)) + val data = NodeV(data = JsonObject("type" -> nodeType, "value" -> angle), theory = _g.data.theory).withCoord((xCoord, 0)) + QuickGraph(_g.addVertex(name, data)) + } + + def bbox(name: String, vertices: Set[String]): QuickGraph = { + val bbname = _g.bboxes.freshWithSuggestion(BBName(name)) + val bbdata = BBData(theory = _g.data.theory) + QuickGraph(_g.addBBox(bbname, bbdata, vertices.map(s => VName(s)))) + } + + def addInput(count : Int = 1) : QuickGraph = { + count match { + case 0 => this + case 1 => + val name = _g.verts.freshWithSuggestion(VName("i-0")) + val data = WireV().withCoord(name.suffix,-0.5) + QuickGraph(_g.addVertex(name, data)) + case n => + val name = _g.verts.freshWithSuggestion(VName("i-0")) + val data = WireV().withCoord(name.suffix,-0.5) + QuickGraph(_g.addVertex(name, data)).addInput(count -1) + } + } + def addOutput(count: Int = 1) : QuickGraph = { + count match { + case 0 => this + case 1 => + val name = _g.verts.freshWithSuggestion(VName("o-0")) + val data = WireV().withCoord(name.suffix,0.5) + QuickGraph(_g.addVertex(name, data)) + case n => + val name = _g.verts.freshWithSuggestion(VName("o-0")) + val data = WireV().withCoord(name.suffix,0.5) + QuickGraph(_g.addVertex(name, data)).addOutput(count -1) + } + } + + def join(s1 : String, s2: String, edgeType : Option[String] = None) : QuickGraph = { + val name = _g.edges.freshWithSuggestion("e-0") + val eData = if(edgeType.isEmpty) { + _g.data.theory.edgeTypes(_g.data.theory.defaultEdgeType).defaultData + } else { + _g.data.theory.edgeTypes(edgeType.get).defaultData + } + val data = UndirEdge(eData, theory = _g.data.theory) + val v1 = VName(s1) + val v2 = VName(s2) + QuickGraph(_g.addEdge(name, data, v1 -> v2)) + } + + def join(s1: String, s2s: Set[String], edgeType: Option[String]) : QuickGraph = { + s2s.foldLeft(this)((g,v) => g.join(s1, v, edgeType)) + } + + def joinIfNotAlready(s1: String, s2: String, edgeType : Option[String] = None) : QuickGraph = { + val isJoined = _g.adjacentVerts(s1).contains(s2) + if(!isJoined){ + this.join(s1, s2, edgeType) + }else{ + this + } + } + + def apply() : Graph = _g + } + + object QuickGraph { + def apply(graph: Graph) = new QuickGraph(graph) + def apply(theory: Theory) = new QuickGraph( + new Graph(GData(data = new JsonObject(), annotation = new JsonObject(), theory = theory)) + ) + + implicit def slow(qg: QuickGraph) : Graph = qg() + + val boundaryRegex : Option[Regex] = Some(raw"""(i|o)-(\d+)""".r) + } + + val ZXTheory : Theory = Theory.fromFile("ZX") + val ZXRails : Theory = Theory.fromFile("ZXRails") + + def zxCNOT(size: Int = 2): Block = { + require(size >= 2) + // Size 2 is the standard + val tensorSize = Math.pow(2, size).toInt + val cut = tensorSize - 2 + val swap = Tensor.swap((0 until size).map { + case 0 => 0 + case 1 => size - 1 + case n => n - 1 + }.toList) + val penultimate = size - 1 + val graph = (0 until size).foldLeft(QuickGraph(ZXRails).addInput(size).addOutput(size)) { + (g, i) => + i match { + case 0 => g.node("Z", nodeName = "z", xCoord = i).join("i-" + i, "z").join("z", "o-" + i) + case `penultimate` => + g.node("X", nodeName = "x", xCoord = i).join("i-" + i, "x").join("x", "o-" + i) + case _ => g.join("i-" + i, "o-" + i, Some("rail")) + } + }.join("z", "x", Some("string")) + Block(size, size, "CNOT" + size, + swap o + ( + Tensor(Array(Array(1, 0, 0, 0), Array(0, 1, 0, 0), Array(0, 0, 0, 1), Array(0, 0, 1, 0))) + x Tensor.idWires(size - 2) + ) + o swap.transpose + , + graph + ) + } + def zxTONC(size: Int = 2): Block = { + require(size >= 2) + // Size 2 is the standard + val swapList = (0 until size).map { + i => { + if (i == 0) size - 1 + else if (i == size - 1) 0 + else i + } + }.toList + val swap = Tensor.swap(swapList) + val penultimate = size - 1 + Block(size, size, "TONC" + size, swap.transpose o zxCNOT(size) o swap, + (0 until size).foldLeft(QuickGraph(ZXRails).addInput(size).addOutput(size)) { + (g, i) => + i match { + case `penultimate` => g.node("Z", nodeName = "z", xCoord = i).join("i-" + i, "z").join("z", "o-" + i) + case 0 => g.node("X", nodeName = "x", xCoord = i).join("i-" + i, "x").join("x", "o-" + i) + case _ => g.join("i-" + i, "o-" + i) + } + }.join("z", "x", Some("string")) + ) + } + + + def zxCNOTs(maxWidth: Int = 2): List[Block] = (for (i <- 2 to maxWidth) yield zxCNOT(i)).toList + def zxTONCs(maxWidth: Int = 2): List[Block] = (for (i <- 2 to maxWidth) yield zxCNOT(i)).toList + + val zxQubitHadamard : Block = Block(1, 1, " H ", Hadamard(2), QuickGraph(ZXRails).addInput().addOutput() + .node("hadamard", nodeName = "h").join("i-0", "h").join("h", "o-0")) + + def zxQubitTwists(twoPiDivision: Int = 4): IndexedSeq[Block] = { + def tensor(angle: Int, nodeType: String): Tensor = { + nodeType match { + case "Z" => Tensor(Array( + Array(Complex.one, Complex.zero), + Array(Complex.zero, ei(2 * angle * math.Pi / twoPiDivision)))) + case "X" => + Tensor(Array( + Array(1 + ei(2 * angle * math.Pi / twoPiDivision), 1 - ei(2 * angle * math.Pi / twoPiDivision)), + Array(1 - ei(2 * angle * math.Pi / twoPiDivision), 1 + ei(2 * angle * math.Pi / twoPiDivision)))) + } + } + + def block(angle: Int, nodeType: String): Block = { + Block(1, 1, (2*angle) + nodeType + twoPiDivision, tensor(angle, nodeType), + QuickGraph(ZXRails) + .addInput().addOutput() + .node(nodeType = nodeType, nodeName = nodeType, angle = s"${2*angle} / $twoPiDivision") + .join("i-0", nodeType, Some("rail")).join(nodeType, "o-0", Some("rail"))) + } + + (0 until twoPiDivision).flatMap(i => + List(block(i, "X"), block(i, "Z")) + ) + } + + + val ZXClifford: List[Block] = List( + Block(1, 1, " 1 ", Tensor.idWires(1), QuickGraph(ZXRails).addInput().addOutput().join("i-0", "o-0", Some("rail"))), + zxQubitHadamard, + Block(2, 2, " s ", Tensor.swap(List(1, 0)), + QuickGraph(ZXRails).addInput(2).addOutput(2).join("i-0", "o-1", Some("rail")).join("i-1", "o-0", Some("rail"))), + Block(1, 1, "gpi", Tensor(Array(Array(1, 0), Array(0, -1))), QuickGraph(ZXRails).addInput().addOutput() + .node("Z", nodeName = "zpi", angle = raw"\pi").join("i-0", "zpi", Some("rail")).join("zpi", "o-0", Some("rail"))), + Block(1, 1, "rpi", Tensor(Array(Array(0, 1), Array(1, 0))), QuickGraph(ZXRails).addInput().addOutput() + .node("X", nodeName = "xpi", angle = raw"\pi").join("i-0", "xpi", Some("rail")).join("xpi", "o-0", Some("rail"))), + Block(1, 1, "gp2", Tensor(Array(Array(Complex(1, 0), Complex(0, 0)), Array(Complex(0, 0), Complex(0, 1)))), + QuickGraph(ZXRails).addInput().addOutput() + .node("Z", nodeName = "z", angle = raw"\pi / 2").join("i-0", "z", Some("rail")).join("z", "o-0", Some("rail"))), + Block(1, 1, "rp2", Tensor(Array(Array(Complex(1, 1), Complex(1, -1)), Array(Complex(1, -1), Complex(1, 1)))), + QuickGraph(ZXRails).addInput().addOutput() + .node("X", nodeName = "x", angle = raw"\pi / 2").join("i-0", "x", Some("rail")).join("x", "o-0", Some("rail"))), + Block(2, 2, "CNT", Tensor(Array(Array(1, 0, 0, 0), Array(0, 1, 0, 0), Array(0, 0, 0, 1), Array(0, 0, 1, 0))), + QuickGraph(ZXRails).addInput(2).addOutput(2).node("Z", nodeName = "z").node("X", xCoord = 1, nodeName = "x") + .join("i-0", "z", Some("rail")).join("i-1", "x", Some("rail")) + .join("o-0", "z", Some("rail")).join("o-1", "x", Some("rail")) + .join("z", "x", Some("string"))) + ) + + + val ZXCNOT: List[Block] = List( + Block(1, 1, " 1 ", Tensor.idWires(1), QuickGraph(ZXRails).addInput().addOutput().join("i-0", "o-0", Some("rail"))), + Block(2, 2, " s ", Tensor.swap(List(1, 0)), + QuickGraph(ZXRails).addInput(2).addOutput(2).join("i-0", "o-1", Some("rail")).join("i-1", "o-0", Some("rail"))), + Block(2, 2, "CNT", Tensor(Array(Array(1, 0, 0, 0), Array(0, 1, 0, 0), Array(0, 0, 0, 1), Array(0, 0, 1, 0))), + QuickGraph(ZXRails).addInput(2).addOutput(2).node("Z", nodeName = "z").node("X", xCoord = 1, nodeName = "x") + .join("i-0", "z", Some("rail")).join("i-1", "x", Some("rail")) + .join("o-0", "z", Some("rail")).join("o-1", "x", Some("rail")) + .join("z", "x", Some("string"))) + ) + + val BellTeleportation: List[Block] = List( + Block(List(0), List(0), " A ", Tensor.id(2)), + Block(List(-1), List(-1), " B ", Tensor.id(2)), + Block(List(0, 0), List(1), " m1", Tensor(Array(Array(1, 0, 0, 1)))), + Block(List(0, 0), List(2), " m2", Tensor(Array(Array(1, 0, 0, -1)))), + Block(List(0, 0), List(3), " m3", Tensor(Array(Array(0, 1, 1, 0)))), + Block(List(0, 0), List(4), " m4", Tensor(Array(Array(0, 1, -1, 0)))), + Block(List(-1, 1), List(-1), " c1", Tensor(Array(Array(1, 0), Array(0, 1)))), + Block(List(-1, 2), List(-1), " c2", Tensor(Array(Array(1, 0), Array(0, -1)))), + Block(List(-1, 3), List(-1), " c3", Tensor(Array(Array(0, 1), Array(1, 0)))), + Block(List(-1, 4), List(-1), " c4", Tensor(Array(Array(0, 1), Array(-1, 0)))), + Block(List(), List(0, -1), " p ", Tensor(Array(Array(1, 0, 0, 1))).transpose) + ) ::: + swapQuantumClassical(List(0, -1), Tensor.id(2), List(1, 2, 3, 4)) ::: + makeClassicalIdentites(List(1, 2, 3, 4)) + val ZW: List[Block] = List( + // BOTTOM TO TOP! + Block(1, 1, " 1 ", Tensor.idWires(1)), + Block(2, 2, " s ", Tensor.swap(List(1, 0))), + Block(2, 2, "crs", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 1, 0), Array(0, 1, 0, 0), Array(0, 0, 0, -1)))), + Block(0, 2, "cup", Tensor(Array(Array(1, 0, 0, 1))).transpose), + Block(2, 0, "cap", Tensor(Array(Array(1, 0, 0, 1)))), + Block(1, 1, " w ", Tensor(Array(Array(1, 0), Array(0, -1)))), + Block(1, 2, "1w2", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 0, -1))).transpose), + Block(2, 1, "2w1", Tensor(Array(Array(1, 0, 0, 0), Array(0, 0, 0, -1)))), + Block(1, 1, " b ", Tensor(Array(Array(0, 1), Array(1, 0)))), + Block(1, 2, "1b2", Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 1))).transpose), + Block(2, 1, "2b1", Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 1)))) + ) + + def swapQuantumClassical(listQuantum: List[Int], quantumTensor: Tensor, listClassical: List[Int]): List[Block] = { + (for (w1 <- listQuantum; w2 <- listClassical) yield { + List(Block(List(w1, w2), List(w2, w1), w1 + "s" + w2, quantumTensor), + Block(List(w2, w1), List(w1, w2), w2 + "s" + w1, quantumTensor)) + }).flatten + } + + def swapQuantumQuantum(listQuantum: List[Int], quantumTensor: Tensor): List[Block] = { + (for (w1 <- listQuantum; w2 <- listQuantum) yield { + List(Block(List(w1, w2), List(w2, w1), w1 + "s" + w2, quantumTensor), + Block(List(w2, w1), List(w1, w2), w2 + "s" + w1, quantumTensor)) + }).flatten + } + + def makeClassicalIdentites(listClassical: List[Int]): List[Block] = { + for (w <- listClassical) yield { + Block(List(w), List(w), "w" + w + " ", Tensor.id(1)) + } + } + + // Traditionally the number of angles is 3 (Clifford) or 9 (Clifford+T) + def ZXQudit(dimension: Int, numAngles: Int): List[Block] = { + require(dimension > 1) + + def swapIndex(i: Int): Int = { + val left: Int = i / dimension + val right = i % dimension + right * dimension + left + } + + val H: Tensor = Hadamard(dimension) + + val greenFork = Tensor(dimension, dimension * dimension, + (i, j) => if (j == i * (dimension + 1)) Complex.one else Complex.zero) + + // Go through the diagonal entries creating all the different spiders + val greenBlocks = (1 until dimension).foldLeft( + List( + Block(1, 1, "g", Tensor(dimension, dimension, (i, j) => if (i == 0 && j == 0) Complex.one else Complex.zero)) + ) + )((lb, i) => lb.flatMap(b => (0 until numAngles).map(x => + Block(1, 1, b.name + "|" + x, b.tensor + Tensor(dimension, dimension, (j, k) => + if (j == i && k == i) ei(x * 2 * math.Pi / numAngles) else Complex.zero) + )))) + + List( + Block(1, 1, " 1 ", Tensor.id(dimension)), + Block(2, 2, " s ", Tensor.permutation((0 until dimension * dimension).toList.map(x => swapIndex(x)))), + Block(1, 1, " H ", H), + Block(1, 1, " H'", H.dagger), + Block(2, 1, "2g1", greenFork), + Block(1, 2, "1g2", greenFork.dagger), + Block(0, 1, "gu ", Tensor(dimension, 1, (_, _) => Complex.one).scaled(1.0 / math.sqrt(dimension))), + Block(1, 2, "1r2", (H.dagger o greenFork o (H x H)).dagger), + Block(2, 1, "2r1", H.dagger o greenFork o (H x H)), + Block(0, 1, "ru ", Tensor(dimension, 1, (i, j) => if (i == 0 && j == 0) Complex.one else Complex.zero)) + ) ::: greenBlocks ::: greenBlocks.map(b => + Block(1, 1, "r" + b.name.tail, H.dagger o b.tensor o H) + ) + } + + def Hadamard(dimension: Int): Tensor = + Tensor(dimension, dimension, (i, j) => ei(2 * math.Pi * i * j / dimension)).scaled(1 / math.sqrt(dimension)) + + // Traditionally the number of angles is 3 (Clifford) or 9 (Clifford+T) + def ZXQutrit(numAngles: Int = 9): List[Block] = { + val H3 = Hadamard(3) + List( + Block(1, 1, " 1 ", Tensor.id(3)), + Block(2, 2, " s ", Tensor.permutation(List(0, 3, 6, 1, 4, 7, 2, 5, 8))), + Block(1, 1, " H ", H3), + Block(1, 1, " H'", H3.dagger), + Block(2, 1, "2g1", Tensor(Array( + Array(1, 0, 0, 0, 0, 0, 0, 0, 0), + Array(0, 0, 0, 0, 1, 0, 0, 0, 0), + Array(0, 0, 0, 0, 0, 0, 0, 0, 1) + ))), + Block(1, 2, "1g2", Tensor(Array( + Array(1, 0, 0, 0, 0, 0, 0, 0, 0), + Array(0, 0, 0, 0, 1, 0, 0, 0, 0), + Array(0, 0, 0, 0, 0, 0, 0, 0, 1) + )).transpose), + Block(0, 1, "gu ", Tensor(Array(Array(1, 1, 1))).scaled(1.0 / math.sqrt(3)).transpose), + Block(1, 2, "1r2", (H3.dagger o Tensor(Array( + Array(1, 0, 0, 0, 0, 0, 0, 0, 0), + Array(0, 0, 0, 0, 1, 0, 0, 0, 0), + Array(0, 0, 0, 0, 0, 0, 0, 0, 1) + )) o (H3 x H3)).dagger), + Block(2, 1, "2r1", H3.dagger o Tensor(Array( + Array(1, 0, 0, 0, 0, 0, 0, 0, 0), + Array(0, 0, 0, 0, 1, 0, 0, 0, 0), + Array(0, 0, 0, 0, 0, 0, 0, 0, 1) + )) o (H3 x H3)), + Block(0, 1, "ru ", Tensor(Array(Array(1, 0, 0))).transpose) + ) ::: + (for (i <- 0 until numAngles; j <- 0 until numAngles) yield { + val gs = Tensor(Array( + Array[Complex](1, 0, 0), + Array[Complex](0, ei(i * 2 * math.Pi / numAngles), 0), + Array[Complex](0, 0, ei(j * 2 * math.Pi / numAngles)) + )) + List(Block(1, 1, "g|" + i.toString + "|" + j.toString, gs), + Block(1, 1, "r|" + i.toString + "|" + j.toString, H3.dagger o gs o H3)) + }).flatten.toList + } + + def ZXGates(numAngles: Int = 8, CNOTWidth: Int = 2): List[Block] = List( + Block(1, 1, " 1 ", Tensor.idWires(1), QuickGraph(ZXRails).addInput().addOutput().join("i-0", "o-0")), + Block(2, 2, " s ", Tensor.swap(List(1, 0))) + //, zxQubitHadamard + ) ::: + zxQubitTwists(numAngles).toList ::: + zxCNOTs(CNOTWidth) ::: + zxTONCs(CNOTWidth) + + private def ei(angle: Double) = Complex(math.cos(angle), math.sin(angle)) + + def Bian2Qubit: List[Block] = List( + // Block(0, 0, " w ", Tensor.id(1).scaled(ei(math.Pi / 4))), Ignored for now. + Block(1, 1, " 1 ", Tensor.id(2)), + Block(2, 2, " Zc", Tensor.diagonal(Array(Complex.one, Complex.one, Complex.one, Complex.zero.-(Complex.one)))), + Block(1, 1, " T ", Tensor.diagonal(Array(Complex.one, ei(math.Pi / 4)))), + Block(1, 1, " H ", Tensor(Array(Array(1, 1), Array(1, -1))).scaled(1.0 / math.sqrt(2))), + Block(1, 1, " S ", Tensor.diagonal(Array(Complex.one, ei(math.Pi / 2)))) + ) + + + def Bian2QubitToGraph(block: Block): Graph = { + + // The graph produced must be 0-indexed on inputs and outputs, and of the form /i\d+/ and /o\d+/ + + implicit def vname(str: String): VName = VName(str) + + implicit def vnamepair(p: (String, String)): (VName, VName) = VName(p._1) -> VName(p._2) + + implicit def ename(str: String): EName = EName(str) + + val rg = Theory.fromFile("red_green") + + var g = new Graph() + + var eCount = 0 + + def join(v0: String, v1: String): Unit = { + g = g.addEdge(g.edges.fresh, UndirEdge(), vnamepair(v0,v1)) + eCount += 1 + } + + def addVertex(name: String, data: VData): Unit = { + g = g.addVertex(vname(name), data) + } + + for (i <- block.inputs.indices) { + addVertex("i" + i, WireV()) + } + for (i <- block.outputs.indices) { + addVertex("o" + i, WireV()) + } + block.name match { + case " 1 " => + join("i0", "o0") + case " T " => + addVertex("v0", NodeV(data = JsonObject("type" -> "X", "value" -> "pi/4"), theory = rg)) + join("i0", "v0") + join("v0", "o0") + case " H " => + addVertex("v0", NodeV(data = JsonObject("type" -> "hadamard", "value" -> "0"), theory = rg)) + join("i0", "v0") + join("v0", "o0") + case " S " => + addVertex("v0", NodeV(data = JsonObject("type" -> "X", "value" -> "pi/2"), theory = rg)) + join("i0", "v0") + join("v0", "o0") + case " Zc" => + addVertex("v0", NodeV(data = JsonObject("type" -> "X", "value" -> "0"), theory = rg)) + addVertex("v1", NodeV(data = JsonObject("type" -> "Z", "value" -> "0"), theory = rg)) + join("v0", "v1") + join("i0", "v0") + join("v0", "o0") + join("i1", "v1") + join("v1", "o1") + } + + g + } + +} diff --git a/scala/src/main/scala/quanto/cosy/CoSyRun.scala b/scala/src/main/scala/quanto/cosy/CoSyRun.scala new file mode 100644 index 00000000..5518a29c --- /dev/null +++ b/scala/src/main/scala/quanto/cosy/CoSyRun.scala @@ -0,0 +1,426 @@ +package quanto.cosy + +import java.io.File +import java.util.Calendar + +import quanto.cosy.Interpreter.{ZXAngleData, interpretZXSpider} +import quanto.data.Theory.{ValueType, VertexDesc} +import quanto.data._ +import quanto.rewrite.{Match, Matcher} +import quanto.util.FileHelper._ +import quanto.util.json.{Json, JsonObject} +import quanto.util.{FileHelper, Rational, UserAlerts} + +import scala.concurrent.duration.Duration +import scala.util.matching.Regex + +/** + * This class performs the actual batch conjecture synthesis + */ +abstract class CoSyRun[S, T]( + rulesDir: File, + theory: Theory, + duration: Duration, + outputDir: Option[File] + ) { + + val Generator: Iterator[S] + var reductionRules: List[Rule] = List() + var equivClasses: Map[T, Graph] = Map() + + // Turn your generator into a graph + def makeGraph(gen: S): Graph + + // Turn your generator into a tensor + def makeTensor(gen: S): T + + // If you want to check for isomorphisms specify a regex to match the boundaries here + val matchBorders: Option[Regex] + + def checkIsomorphic(graph1: Graph, graph2: Graph): Boolean = + GraphAnalysis.checkIsomorphic(theory, Some(matchBorders.getOrElse("".r)))(graph1,graph2) + + // Positive if left bigger than right: + def compareGraph(left: Graph, right: Graph): Int + + // See if two tensors should be considered equivalent + // e.g. isRoughly or isRoughlyUpToScalar + def compareTensor(a: T, b: T): Boolean + + def findClassesCloseTo(tensor: T): Map[T, Graph] = equivClasses.filter(tv => compareTensor(tv._1, tensor)) + + // How to store the values in values.txt + def makeString(a: S, b: T): String + + // what to do with graphs that aren't hit by reduction rules + // e.g. for circuits put that circuit into the pile to be added to for next iteration + def doWithUnmatched(a: S): Unit + + // Core loop. + // Comes with option of time restriction. + def begin(): List[Rule] = { + def now(): Long = Calendar.getInstance().getTimeInMillis + + val timeStart = now() + while (Duration(now() - timeStart, "millis") < duration && Generator.hasNext) { + // Get a graph + val next: S = Generator.next() + val nextGraph = makeGraph(next) + + + /* +// Print out each graph made + if (outputDir.nonEmpty) { + FileHelper.printToFile( + outputDir.get.toURI.resolve("./" + nextGraph.hashCode + ".qgraph"), + Graph.toJson(nextGraph).toString, + append = true) + } +*/ + + val matchesReductionRule = reductionRules.exists(rule => Matcher.findMatches(rule.lhs, nextGraph).nonEmpty) + + // Check if it can be reduced by known rules + if (!matchesReductionRule) { + val interpretation = makeTensor(next) + val nearbyClasses = findClassesCloseTo(interpretation) + + nearbyClasses.size match { + case 0 => + // doesn't fit into any existing class + doWithUnmatched(next) + equivClasses = equivClasses + (interpretation -> nextGraph) + updateValuesFile(next, interpretation) + case 1 => + // Something with that tensor exists + val equivClass = nearbyClasses.head + val existing: Graph = equivClass._2 + + + // Don't create a rule between isomorphic (constrained at the boundary) graphs + val isomorphic = if (matchBorders.nonEmpty) { + checkIsomorphic(nextGraph, existing) + } else { + false + } + + + if (!isomorphic) { + doWithUnmatched(next) + if (compareGraph(existing, nextGraph) > 0) { + // update class with smaller graph + equivClasses = equivClasses + (equivClass._1 -> nextGraph) + createRule(existing, nextGraph) + } else { + // new graph is at least as big as current + // keep status quo, create rule + createRule(nextGraph, existing) + } + } + + case n => + // Somehow in the approximation radius of two classes + // For now, add to both. + for (equivClass <- nearbyClasses) { + // Something with that tensor exists + val existing: Graph = equivClass._2 + + + val isomorphic = if (matchBorders.nonEmpty) { + checkIsomorphic(nextGraph, existing) + } else { + false + } + + + if (!isomorphic) { + doWithUnmatched(next) + createRule(nextGraph, existing) + // update class with smaller graph + if (compareGraph(existing, nextGraph) > 0) { + equivClasses = equivClasses + (equivClass._1 -> nextGraph) + } + } else { + // Don't create a rule between isomorphic (constrained) graphs + } + } + } + } + } + reductionRules + } + + private def updateValuesFile(next: S, interpretation: T): Unit = { + if (outputDir.nonEmpty) { + FileHelper.printToFile( + outputDir.get.toURI.resolve("./values.txt"), + makeString(next, interpretation), + append = true) + } + } + + def createRule(lhs: Graph, rhs: Graph): Rule = { + val name = s"${lhs.hashCode}_${rhs.hashCode}.qrule" + val r = new Rule(lhs, rhs, derivation = Some("CoSy"), description = RuleDesc(name)) + if (outputDir.nonEmpty) { + printJson(outputDir.get.toURI.resolve("./" + name).getPath, Rule.toJson(r, lhs.data.theory)) + } + loadRule(r) + r + } + + def loadRule(rule: Rule): Unit = { + def reduceRules(rules: List[Rule]) : List[Rule] = { + RuleSynthesis.greedyReduceRules(compareGraph, Some((theory, matchBorders)))(rules).filter( + rule => !checkIsomorphic(rule.lhs, rule.rhs) + ) + } + // Please don't put bbox rules into here unless you really mean them to be here and they reduce left->right + if (rule.lhs.bboxes.nonEmpty) { + reductionRules = rule :: reductionRules + } else { + // No bboxes, act normally + if (compareGraph(rule.lhs, rule.rhs) > 0) { + reductionRules = rule :: reductionRules + reductionRules = reduceRules(reductionRules) + } else if (compareGraph(rule.rhs, rule.lhs) > 0) { + reductionRules = rule.inverse :: reductionRules + reductionRules = reduceRules(reductionRules) + } else { + // Not a reduction rule, so leave it out + } + } + } + + FileHelper.readAllOfType(rulesDir.getAbsolutePath, ".*qrule", Rule.fromJson(_, theory)).foreach(loadRule) +} + +object CoSyRuns { + + + class CoSyCircuit(rulesDir: File, + theory: Theory, + duration: Duration, + outputDir: Option[File], + numBoundaries: Int + ) extends CoSyRun[BlockStack, Tensor](rulesDir, theory, duration, outputDir) { + + + // Include the empty diagram + // equivClasses = equivClasses + (Tensor(Array(Array(1))) -> new Graph()) + + override val Generator: Iterator[BlockStack] = new Iterator[BlockStack] { + + override def hasNext: Boolean = unusedStacks.hasNext || unusedRows.hasNext || nextRoundOfStacks.nonEmpty + + override def next(): BlockStack = { + if (!unusedRows.hasNext) { + if (!unusedStacks.hasNext) { + unusedStacks = nextRoundOfStacks.toIterator + nextRoundOfStacks = List() + UserAlerts.alert("Finished next row") + } + unusedRows = rows.toIterator + currentStack = unusedStacks.next() + } + currentStack.append(unusedRows.next()) + } + } + override val matchBorders = Some(raw"""(i|o)-(\d+)""".r) + val blocks: List[Block] = BlockGenerators.ZXGates(4, numBoundaries) + UserAlerts.alert(s"Created ${blocks.length} blocks") + val rows: List[BlockRow] = BlockRowMaker.makeRowsUpToSize(numBoundaries, blocks, Some(numBoundaries)) + .filter(br => br.inputs.size == numBoundaries && br.outputs.size == numBoundaries) + UserAlerts.alert(s"Created ${rows.length} rows") + var unusedRows: Iterator[BlockRow] = rows.toIterator + var unusedStacks: Iterator[BlockStack] = rows.map(r => BlockStack(List(r))).toIterator + var nextRoundOfStacks: List[BlockStack] = List() + var currentStack: BlockStack = unusedStacks.next() + + override def doWithUnmatched(a: BlockStack): Unit = { + nextRoundOfStacks = a :: nextRoundOfStacks + } + + + override def compareTensor(a: Tensor, b: Tensor): Boolean = a.isRoughlyUpToScalar(b) + + override def compareGraph(left: Graph, right: Graph) : Int = GraphAnalysis.zxCircuitCompare(left, right) + + override def makeTensor(gen: BlockStack): Tensor = gen.tensor + + override def makeString(a: BlockStack, b: Tensor): String = s"$a: ${b.toJson}," + + override def makeGraph(gen: BlockStack): Graph = { + val g = gen.graph.minimise + val IOPattern = raw"r-\d+-(i|o)-(\d+)".r + val renameMap: Map[VName, VName] = g.verts.map(vn => vn -> (vn.s match { + case IOPattern(io, n) => VName(io + "-" + n) + case _ => vn + })).toMap + g.rename(renameMap) + } + + } + + class CoSyZX(rulesDir: File, + theory: Theory, + duration: Duration, + outputDir: Option[File], + numAngles: Int, + numBoundaries: List[Int], + numVertices: Int, + scalars: Boolean + ) extends CoSyRun[AdjMat, Tensor](rulesDir, theory, duration, outputDir) { + + + override val Generator: Iterator[AdjMat] = { + val identitiesFirst = ColbournReadEnum.enumerate(1, 1, numBoundaries.max, 0). + filter(a => numBoundaries.contains(a.numBoundaries)) + + val CR = ColbournReadEnum.enumerate(numAngles, numAngles, numBoundaries.max, numVertices) + + UserAlerts.alert(s"CoSy: Finished Colbourn-Read (${CR.size})") + + val CRScalars = if(scalars) CR else CR.filter(adj => !GraphAnalysis.containsScalars(adj)) + + UserAlerts.alert(s"CoSy: Filtered out scalars (${CRScalars.size})") + + val CRScalarsSorted = CRScalars.sortBy(_.size) + + UserAlerts.alert("CoSy: Sorted AdjMats") + + val combined = identitiesFirst.iterator ++ + CRScalarsSorted.iterator.filter(a => numBoundaries.contains(a.numBoundaries)) + + combined + } + + private val gdata = (for (i <- 0 until numAngles) yield { + NodeV(data = JsonObject("type" -> "Z", "value" -> angleMap(i).toString), theory = theory) + }).toVector + private val rdata = (for (i <- 0 until numAngles) yield { + NodeV(data = JsonObject("type" -> "X", "value" -> angleMap(i).toString), theory = theory) + }).toVector + + override def compareTensor(a: Tensor, b: Tensor): Boolean = if (!scalars) { + a.isRoughlyUpToScalar(b) + } else { + a.isRoughly(b) + } + + override val matchBorders = None + + override def doWithUnmatched(a: AdjMat): Unit = { + // Don't need to do anything, since Colbourn-Read handles generating adj-mats + } + + override def makeTensor(gen: AdjMat): Tensor = { + val asGraph = makeGraph(gen) + Interpreter.interpretZXGraph(asGraph, asGraph.verts.filter(asGraph.isTerminalWire).toList.sortBy(_.s), List()) + } + + override def makeGraph(gen: AdjMat): Graph = Graph.fromAdjMat(gen, rdata, gdata) + + override def makeString(a: AdjMat, b: Tensor): String = { + s"adj${a.hash}: ${b.toJson}," + } + + override def compareGraph(left: Graph, right: Graph) : Int = GraphAnalysis.zxGraphCompare(left, right) + + private def angleMap = (x: Int) => PhaseExpression(new Rational(2 * x, numAngles), ValueType.AngleExpr) + + + } + + class CoSyZXBool(rulesDir: File, + theory: Theory, + duration: Duration, + outputDir: Option[File], + numAngles: Int, + numBoundaries: List[Int], + numVertices: Int, + scalars: Boolean + ) extends CoSyRun[AdjMat, Tensor](rulesDir, theory, duration, outputDir) { + + + override val Generator: Iterator[AdjMat] = + (ColbournReadEnum.enumerate(1, 1, numBoundaries.max, 0).iterator ++ + ColbournReadEnum.enumerate(2*numAngles, 2*numAngles, numBoundaries.max, numVertices).iterator). + filter(a => numBoundaries.contains(a.numBoundaries)) + + + + private val gdata = (for (i <- 0 until 2*numAngles) yield { + NodeV(data = JsonObject("type" -> "Z", "value" -> angleMap(i).toString), theory = theory) + }).toVector + private val rdata = (for (i <- 0 until 2*numAngles) yield { + NodeV(data = JsonObject("type" -> "X", "value" -> angleMap(i).toString), theory = theory) + }).toVector + + override def compareTensor(a: Tensor, b: Tensor): Boolean = if (!scalars) { + a.isRoughlyUpToScalar(b) + } else { + a.isRoughly(b) + } + + override val matchBorders = None + + private implicit def stringToPhase(s: String): PhaseExpression = { + CompositeExpression.parseKnowingTypes(s, Vector(ValueType.AngleExpr, ValueType.Boolean))(1) + } + + override def doWithUnmatched(a: AdjMat): Unit = { + // Don't need to do anything, since Colbourn-Read handles generating adj-mats + } + + override def makeTensor(gen: AdjMat): Tensor = { + val asGraph = makeGraph(gen) + + def spiderInterpreter(vdata: NodeV, inputs: Int, outputs: Int): Tensor = { + + val zxData: ZXAngleData = { + val isGreen = vdata.typ == "Z" + val angle = CompositeExpression.parseKnowingTypes(vdata.value, Vector(ValueType.AngleExpr, ValueType.Boolean)) + if (angle(1).constant == Rational(1, 1)) { + ZXAngleData(isGreen, angle(0)) + } else { + ZXAngleData(isGreen, PhaseExpression(new Rational(0, 1), ValueType.AngleExpr)) + } + } + + interpretZXSpider(zxData, inputs, outputs) + } + + + Interpreter.interpretSpiderGraph(spiderInterpreter)(asGraph, asGraph.verts.filter(asGraph.isTerminalWire).toList.sortBy(_.s), List()) + } + + override def makeGraph(gen: AdjMat): Graph = Graph.fromAdjMat(gen, rdata, gdata) + + override def makeString(a: AdjMat, b: Tensor): String = { + s"adj${a.hash}: ${b.toJson}," + } + + + //TODO: Graph comparison for ZXBool + override def compareGraph(left: Graph, right: Graph) : Int = RuleSynthesis.basicGraphComparison(left, right) + + private def angleMap(x: Int): CompositeExpression = + if (x < numAngles) { + CompositeExpression(Vector(ValueType.AngleExpr, ValueType.Boolean), + + Vector(PhaseExpression(new Rational(2 * x, numAngles), ValueType.AngleExpr), + PhaseExpression(new Rational(1, 1), ValueType.Boolean))) + + } else { + CompositeExpression(Vector(ValueType.AngleExpr, ValueType.Boolean), + + Vector(PhaseExpression(new Rational(2 * (x-numAngles), numAngles), ValueType.AngleExpr), + PhaseExpression(new Rational(0, 1), ValueType.Boolean))) + + } + + + } +} + diff --git a/scala/src/main/scala/quanto/cosy/ColbournReadEnum.scala b/scala/src/main/scala/quanto/cosy/ColbournReadEnum.scala index dd25db3d..520c9bc2 100644 --- a/scala/src/main/scala/quanto/cosy/ColbournReadEnum.scala +++ b/scala/src/main/scala/quanto/cosy/ColbournReadEnum.scala @@ -34,21 +34,20 @@ case class AdjMat(numRedTypes: Int, mat: Vector[Vector[Boolean]] = Vector()) extends Ordered[AdjMat] { lazy val size: Int = mat.length - lazy val numRed : Int = red.sum - lazy val numGreen : Int = green.sum - lazy val hash : String = makeHash() - def toJson : JsonObject = JsonObject("hash" -> hash) + lazy val numRed: Int = red.sum + lazy val numGreen: Int = green.sum + lazy val hash: String = makeHash() lazy val vertexColoursAndTypes: List[(VertexColour.EnumVal, Int)] = { var _vertexColoursAndTypes: List[(VertexColour.EnumVal, Int)] = List() var colCount = 0 var angleTypeCount = 0 - for (i <- 0 until numBoundaries) { + for (_ <- 0 until numBoundaries) { _vertexColoursAndTypes = (VertexColour.Boundary, 0) :: _vertexColoursAndTypes colCount += 1 } for (j <- red) { - for (i <- 0 until j) { + for (_ <- 0 until j) { _vertexColoursAndTypes = (VertexColour.Red, angleTypeCount) :: _vertexColoursAndTypes colCount += 1 } @@ -57,7 +56,7 @@ case class AdjMat(numRedTypes: Int, angleTypeCount = 0 for (j <- green) { - for (i <- 0 until j) { + for (_ <- 0 until j) { _vertexColoursAndTypes = (VertexColour.Green, angleTypeCount) :: _vertexColoursAndTypes colCount += 1 } @@ -66,6 +65,8 @@ case class AdjMat(numRedTypes: Int, _vertexColoursAndTypes.reverse } + def toJson: JsonObject = JsonObject("hash" -> hash) + // advance to the next type of vertex added by the addVertex method. The order is boundaries, // then each red type, then each green type. def nextType: Option[AdjMat] = { @@ -76,7 +77,7 @@ case class AdjMat(numRedTypes: Int, // This method grows the adjacency matrix by adding a new boundary, red node, or green node, with the given // vector of edges. - def addVertex(connection: Vector[Boolean]) : AdjMat = { + def addVertex(connection: Vector[Boolean]): AdjMat = { if (red.isEmpty && green.isEmpty) { // new vertex is a boundary copy(numBoundaries = numBoundaries + 1, mat = growMatrix(connection)) } else if (red.nonEmpty && green.isEmpty) { // new vertex is a red node @@ -110,15 +111,6 @@ case class AdjMat(numRedTypes: Int, // a matrix is canonical if it is lexicographically smaller than any vertex permutation def isCanonical(permuteBoundary: Boolean = false): Boolean = validPerms(permuteBoundary).forall { p => compareWithPerm(p) <= 0 } - // compare this matrix with itself, but with the rows and columns permuted according to "perm" - def compareWithPerm(perm: Vector[Int]): Int = { - for (i <- 0 until size) - for (j <- 0 to i) - if (mat(i)(j) < mat(perm(i))(perm(j))) return -1 - else if (mat(i)(j) > mat(perm(i))(perm(j))) return 1 - 0 - } - // return all the vertex-permutations which respect type and keep boundary fixed def validPerms(permuteBoundary: Boolean): Vector[Vector[Int]] = { var idx = numBoundaries @@ -186,6 +178,15 @@ case class AdjMat(numRedTypes: Int, false } + // compare this matrix with itself, but with the rows and columns permuted according to "perm" + def compareWithPerm(perm: Vector[Int]): Int = { + for (i <- 0 until size) + for (j <- 0 to i) + if (mat(i)(j) < mat(perm(i))(perm(j))) return -1 + else if (mat(i)(j) > mat(perm(i))(perm(j))) return 1 + 0 + } + // returns true if all boundaries are connected to something def isComplete: Boolean = (0 until numBoundaries).forall(i => mat(i).contains(true)) @@ -248,7 +249,7 @@ object AdjMat { case _ => Vector(Vector()) } - def fromHash(hash: String): AdjMat = { + def fromHash(hash: String): AdjMat = { // "boundaries.red1-red2.green1-green2.matBase36" val dotChunk = hash.split("\\.") val numBoundaries = dotChunk(0).toInt @@ -256,10 +257,10 @@ object AdjMat { val green: Vector[Int] = dotChunk(2).split("-").map(a => a.toInt).toVector val size = numBoundaries + red.sum + green.sum val longMatStringUnpadded = java.lang.Long.toString(java.lang.Long.parseLong(dotChunk(3), 36), 2) - val longMatString = (1 to size * size - longMatStringUnpadded.length).foldLeft("") {(a,b) => "0" + a} + + val longMatString = (1 to size * size - longMatStringUnpadded.length).foldLeft("") { (a, _) => "0" + a } + longMatStringUnpadded val longMatVec = longMatString.map(x => x == '1').toVector - val mat = if(size > 0) longMatVec.grouped(size).toVector else List().toVector + val mat = if (size > 0) longMatVec.grouped(size).toVector else List().toVector new AdjMat(red.length, green.length, numBoundaries, red, green, mat) } } diff --git a/scala/src/main/scala/quanto/cosy/EQCAnalysis.scala b/scala/src/main/scala/quanto/cosy/EQCAnalysis.scala index 24028e90..5804cf9f 100644 --- a/scala/src/main/scala/quanto/cosy/EQCAnalysis.scala +++ b/scala/src/main/scala/quanto/cosy/EQCAnalysis.scala @@ -1,7 +1,5 @@ package quanto.cosy -import quanto.data._ -import quanto.util.Rational /** * Analyse the properties of equivalence classes diff --git a/scala/src/main/scala/quanto/cosy/EquivClassBatchRunner.scala b/scala/src/main/scala/quanto/cosy/EquivClassBatchRunner.scala index e8e4c9ce..2f83a2e3 100644 --- a/scala/src/main/scala/quanto/cosy/EquivClassBatchRunner.scala +++ b/scala/src/main/scala/quanto/cosy/EquivClassBatchRunner.scala @@ -8,6 +8,7 @@ import quanto.util.json.{JsonArray, JsonObject} /** * Created by hector on 20/06/2017. * A wrapper object for the most common qsynth run + * Superseded by the CoSyRun system */ object EquivClassBatchRunner { @@ -15,7 +16,7 @@ object EquivClassBatchRunner { def apply(numAngles: Int = 8, boundaries: Int = 3, vertices: Int, outputFileName: String = "default.qrun"): Unit = { val rg = Theory.fromFile("red_green") - var results = EquivClassRunAdjMat( + val results = EquivClassRunAdjMat( numAngles = numAngles, tolerance = EquivClassRunAdjMat.defaultTolerance, rulesList = List(), @@ -26,7 +27,7 @@ object EquivClassBatchRunner { s"ColbournRead $numAngles $numAngles $boundaries $vertices") new File(outputPath).mkdirs() - var testFile = new File(outputPath + "/" + outputFileName) + val testFile = new File(outputPath + "/" + outputFileName) quanto.util.FileHelper.printToFile(testFile, append = false)( p => p.println(results.toJSON.toString()) ) @@ -43,7 +44,7 @@ object TensorBatchRunner { val rg = Theory.fromFile("red_green") val diagramStream = ColbournReadEnum.enumerate(numAngles, numAngles, boundaries, vertices) - var results = EquivClassRunAdjMat( + val results = EquivClassRunAdjMat( numAngles = numAngles, tolerance = EquivClassRunAdjMat.defaultTolerance, rulesList = List(), diff --git a/scala/src/main/scala/quanto/cosy/EquivalenceClasses.scala b/scala/src/main/scala/quanto/cosy/EquivalenceClasses.scala index 212b93ee..78e2edf3 100644 --- a/scala/src/main/scala/quanto/cosy/EquivalenceClasses.scala +++ b/scala/src/main/scala/quanto/cosy/EquivalenceClasses.scala @@ -2,14 +2,13 @@ package quanto.cosy import quanto.cosy.Interpreter._ import quanto.data._ +import quanto.util.Rational import quanto.util.json.{Json, JsonAccessException, JsonArray, JsonObject} -import scala.concurrent.{Await, Future} -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.duration.Duration - /** * Synthesises diagrams, holds the data and generates equivalence classes + * + * Everything here has essentially be superseded by the new CoSyRun system */ @@ -167,6 +166,11 @@ abstract class EquivClassRun[T](val tolerance: Double = 1e-14) { this } + def add(that: T, tensor: Tensor): EquivalenceClass[T] = { + compareAndAddToClass(that, tensor) + compareAndAddToClass(that, tensor, normalised = true) + } + // Finds the closest class and adds it, or creates a new class if outside tolerance def compareAndAddToClass(candidate: T, tensor: Tensor, normalised: Boolean = false): EquivalenceClass[T] = { @@ -190,6 +194,7 @@ abstract class EquivClassRun[T](val tolerance: Double = 1e-14) { } } + implicit def interpret(that: T): Tensor // Returns (new Class, -1) if no EquivalenceClasses here def closestClassTo(that: Tensor, normalised: Boolean = false): Option[(EquivalenceClass[T], Double)] = { @@ -204,18 +209,11 @@ abstract class EquivClassRun[T](val tolerance: Double = 1e-14) { } - implicit def interpret(that: T): Tensor - def add(that: T): EquivalenceClass[T] = { val tensor = interpret(that) add(that, tensor) } - def add(that: T, tensor: Tensor): EquivalenceClass[T] = { - compareAndAddToClass(that, tensor) - compareAndAddToClass(that, tensor, normalised = true) - } - def newClass(t: T, tensor: Tensor, normalised: Boolean): EquivalenceClass[T] } @@ -299,7 +297,7 @@ object EquivClassRunBlockStack { Tensor.fromJson((js / "tensor").asObject)) ) val ecrr = new EquivClassRunBlockStack(tolerance) - for ((adj, ten) <- results) { + for ((adj, _) <- results) { ecrr.add(adj) } ecrr @@ -365,7 +363,9 @@ object EquivClassRunAdjMat { tolerance: Double, theory: Theory, rulesList: List[Rule]): EquivClassRunAdjMat = { - def angleMap = (x: Int) => x * math.Pi * 2.0 / numAngles + def angleMap(x: Int): Rational = { + new Rational(x, numAngles) + } val gdata = (for (i <- 0 until numAngles) yield { NodeV(data = JsonObject("type" -> "Z", "value" -> angleMap(i).toString), theory = theory) diff --git a/scala/src/main/scala/quanto/cosy/GraphAnalysis.scala b/scala/src/main/scala/quanto/cosy/GraphAnalysis.scala index d62eea80..9a1f3108 100644 --- a/scala/src/main/scala/quanto/cosy/GraphAnalysis.scala +++ b/scala/src/main/scala/quanto/cosy/GraphAnalysis.scala @@ -1,9 +1,230 @@ package quanto.cosy +import quanto.data.Theory.{ValueType, VertexDesc} import quanto.data._ +import quanto.rewrite.Matcher import quanto.util.Rational +import quanto.util.json.{Json, JsonObject} + +import scala.util.matching.Regex +import scala.util.parsing.combinator.RegexParsers object GraphAnalysis { + + def zxCircuitCompare(left: Graph, right: Graph): Int = { + + // returns x where + // x < 0 iff this < that + // x == 0 iff this == that + // x > 0 iff this > that + + // Circuit comparison of graphs + // Cares about T-count + + val Angle = ValueType.AngleExpr + + def phase(vdata: VData): PhaseExpression = + vdata match { + case NodeV(d, a, t) => vdata.asInstanceOf[NodeV].phaseData.first[PhaseExpression](ValueType.AngleExpr) match { + case Some(p) => p + case None => PhaseExpression.zero(ValueType.AngleExpr) + } + case _ => PhaseExpression.zero(ValueType.AngleExpr) + } + + + // Number of T-gates + def countT(graph: Graph): Int = graph.vdata.count(nd => { + val const = phase(nd._2).constant + // T-gate if an odd multiple of \pi/4 + ((const.n * (const.d / 4)) % 2) == 1 + }) + + val tDiff = countT(left) - countT(right) + if (tDiff != 0) return tDiff + + // Number of nodes + def nodes(graph: Graph): Int = graph.vdata.size + + val nodeDiff = nodes(left) - nodes(right) + if (nodeDiff != 0) return nodeDiff + + // Number of edges + def edges(graph: Graph): Int = graph.edata.size + + val edgeDiff = edges(left) - edges(right) + if (edgeDiff != 0) return edgeDiff + + // Number of "Z" nodes + // We favour these! + // Purely for aesthetic reasons + def countZ(graph: Graph): Int = graph.vdata.count(nd => nd._2.typ == "Z") + + val zDiff = countZ(left) - countZ(right) + if (zDiff != 0) return zDiff + + // sum of the phases + def phaseSum(graph: Graph): Rational = graph.vdata.map(nd => phase(nd._2)). + foldLeft(Rational(0, 1)) { (s, a) => s + a.constant } + + val phaseDiff = phaseSum(left) - phaseSum(right) + if (phaseDiff > 0) return 1 + if (phaseDiff < 0) return -1 + + // Weighting by row + // example node name: r-2-bl-1-h-1 + def positionWeighting(graph: Graph): Double = { + def nodeWeighting(node: NodeV): Double = { + (node.typ match { + case "X" => + 2 * (1 + node.phaseData.values.head.constant) + case "Z" => + 1 + node.phaseData.values.head.constant + case "hadamard" => + 1 + }) / 3 // Scale so any given node has weight at most 1, bust still > 0 + } + + graph.vdata.toList.map(nameData => { + nameData._2 match { + case v: NodeV => + // Non-wire nodes are weighted biased towards the bottom left + val placement = CircuitPlacementParser.p(nameData._1.toString) + placement._1 + placement._2 + nodeWeighting(v) + case _ => + // wires have no weight + 0 + } + }).sum + } + + val circuitWeightLeft = positionWeighting(left) + val circuitWeightRight = positionWeighting(right) + if (circuitWeightLeft > circuitWeightRight) { + return 1 + } + if (circuitWeightLeft < circuitWeightRight) { + return -1 + } + + 0 + } + + case class CircuitPlacementParseException(input: String) extends Error + + object CircuitPlacementParser extends RegexParsers { + + override def skipWhitespace = true + + def INT: Parser[Int] = + """[0-9]+""".r ^^ { + _.toInt + } + + def IDENT: Parser[String] = + """[\\a-zA-Z_][a-zA-Z0-9_]*""".r ^^ { + _.toString + } + + // example node name: r-2-bl-1-h-1 + def expr: Parser[(Int, Int, String)] = + "r-" ~ INT ~ "-bl-" ~ INT ~ "-" ~ IDENT ~ "-" ~ INT ^^ { case _ ~ r ~ _ ~ bl ~ _ ~ s ~ _ ~ _ => (r, bl, s) } + + def p(s: String): (Int, Int, String) = parseAll(expr, s) match { + case Success(e, _) => e + case Failure(msg, _) => throw CircuitPlacementParseException(msg) + case Error(msg, _) => throw CircuitPlacementParseException(msg) + } + } + + + def zxGraphCompare(left: Graph, right: Graph): Int = { + + // returns x where + // x < 0 iff this < that + // x == 0 iff this == that + // x > 0 iff this > that + + + // Graph comparison for ZX diagrams + // Cares about node count, then phases + + implicit def stringToPhase(s: String): PhaseExpression = { + PhaseExpression.parse(s, ValueType.AngleExpr) + } + + + // First count number of nodes + def nodes(graph: Graph): Int = graph.vdata.size + + val node = nodes(left) - nodes(right) + if (node != 0) return node + + // Number of edges + def edges(graph: Graph): Int = graph.edata.size + + val edge = edges(left) - edges(right) + if (edge != 0) return edge + + // Number of "Z" nodes + def countZ(graph: Graph): Int = graph.vdata.count(nd => nd._2.typ == "Z") + + val zDiff = countZ(left) - countZ(right) + if (zDiff != 0) return zDiff + + // Sum of Z angles + val Pi = math.Pi + + def sumAngles(graph: Graph, filterType: String): Rational = graph.vdata. + filter(nd => nd._2.typ == filterType). + foldLeft(Rational(0, 1)) { + (angle, nd) => angle + stringToPhase(nd._2.asInstanceOf[NodeV].value).constant + } + + // sumAngles returns a rational that is probably bigger than 2 (remember that the pi is left out) + + val ZAngles: Rational = sumAngles(left, "Z") - sumAngles(right, "Z") + if (ZAngles > 0) return 1 + if (ZAngles < 0) return -1 + + // Sum of X angles + + val XAngles: Rational = sumAngles(left, "X") - sumAngles(right, "X") + if (XAngles > 0) return 1 + if (XAngles < 0) return -1 + + + 0 + } + + def connectionClasses(adjMat: AdjMat): Vector[Int] = { + val initialVector = (0 until adjMat.size).toVector + + def join(v: Vector[Int], i: Int, j: Int): Vector[Int] = { + v.map { a => if (a == v(i) || a == v(j)) { + math.min(v(i), v(j)) + } else a + } + } + + adjMat.mat.zipWithIndex.flatMap(rowWithIndex => { + + val rci = rowWithIndex._1.zipWithIndex + rci.map(bi => (bi._1, bi._2, rowWithIndex._2)) + }).filter(_._1).map(bii => (bii._2, bii._3)).foldLeft(initialVector) { + (v, ii) => join(v, ii._1, ii._2) + } + } + + def containsScalars(adjMat: AdjMat): Boolean = { + // Boundaries are at the front of the adjmat, and are given labels first + // So if any class still has labels higher than the number of boundaries + // it can't be connected to any of the boundaries + val cc = connectionClasses(adjMat) + !cc.forall(_ < adjMat.numBoundaries) + } + + type BMatrix = Vector[Vector[Boolean]] def tensorToBooleanMatrix(tensor: Tensor): BMatrix = { @@ -20,39 +241,40 @@ object GraphAnalysis { def namesToIndex(name: VName) = vertexList.indexOf(name) val targets = ends.map(namesToIndex).toSet - val errorNames = detectErrors(graph) + val errorNames = detectPiNodes(graph) val errors = errorNames.map(namesToIndex) val rawAdjacencyMatrix = adjacencyMatrix(graph) - val bypassedAdjacencyMatrix = bypassSpecial(detectErrors)(graph, rawAdjacencyMatrix) + val bypassedAdjacencyMatrix = bypassSpecial(detectPiNodes)(graph, rawAdjacencyMatrix) pathDistanceSet(bypassedAdjacencyMatrix, errors, targets) match { case None => None - case Some(d) => Some(d-1) //account for errors being over-counted in standard distance + case Some(d) => Some(d - 1) //account for errors being over-counted in standard distance } } - def distanceOfSingleErrorFromEnd(ends: Set[VName])(graph : Graph, vNames : Set[VName]) : Option[Double] = { + def distanceOfSingleErrorFromEnd(ends: Set[VName])(graph: Graph, vNames: Set[VName]): Option[Double] = { val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(graph) val vertexList = graph.verts.toList def namesToIndex(name: VName) = vertexList.indexOf(name) - val targets = ends.map(namesToIndex).toSet - val errorNames = detectErrors(graph) + val targets = ends.map(namesToIndex) + // val errorNames = detectPiNodes(graph) - val errors = errorNames.map(namesToIndex) - val bypassedAdjacencyMatrix = bypassSpecial(detectErrors)(graph, adjacencyMatrix) + // val errors = errorNames.map(namesToIndex) + // val bypassedAdjacencyMatrix = bypassSpecial(detectPiNodes)(graph, adjacencyMatrix) + implicit def nameToInt(name: VName): Int = { adjacencyMatrix._1.indexOf(name) } GraphAnalysis.pathDistanceSet(adjacencyMatrix, vNames.map(nameToInt), targets) match { case None => None - case Some(d) => Some(d-1) + case Some(d) => Some(d - 1) } } @@ -81,48 +303,12 @@ object GraphAnalysis { (matrixWithNames._1, bypassedMatrix) } - def detectErrors(graph: Graph): Set[VName] = { + def detectPiNodes(graph: Graph): Set[VName] = { graph.verts. filterNot(name => graph.vdata(name).isBoundary || graph.vdata(name).isWireVertex). - filter(name => graph.vdata(name).asInstanceOf[NodeV].angle.equals(AngleExpression(Rational(1)))) - } - - def distanceSpecialFromEnds(specials: List[VName])(ends: List[VName])(graph: Graph): Option[Double] = { - - val vertexList = graph.verts.toList - - def namesToIndex(name: VName) = vertexList.indexOf(name) - - val targets = ends.map(namesToIndex).toSet - // Called errors for historical reasons - val errors = specials.map(namesToIndex).toSet - val aMatrix = adjacencyMatrix(graph) - pathDistanceSet(aMatrix, errors, targets) - } - - def pathConnectionMatrices(graph: Graph): List[(Int, Tensor)] = { - val matrixWithNames = adjacencyMatrix(graph) - pathConnectionMatrices(matrixWithNames) - } - - def pathConnectionMatrices(matrixWithNames: (List[VName], BMatrix)): List[(Int, Tensor)] = { - val adjTensor = booleanMatrixToTensor(matrixWithNames._2) - - var rollingPower = Tensor.id(adjTensor.width) - (for (i <- matrixWithNames._1.indices) yield { - (i, { - rollingPower = rollingPower.compose(adjTensor) - rollingPower - }) - }).toList - } - - def booleanMatrixToTensor(bMatrix: BMatrix): Tensor = { - def complexify(v: Vector[Boolean]): Array[Complex] = v.map(b => if (b) Complex(1, 0) else Complex(0, 0)).toArray - - def complexify2(v: Vector[Vector[Boolean]]): Array[Array[Complex]] = v.map(b => complexify(b)).toArray - - Tensor(complexify2(bMatrix)) + filter(name => graph.vdata(name).asInstanceOf[NodeV].phaseData. + firstOrError(ValueType.AngleExpr). + equals(PhaseExpression(Rational(1), ValueType.AngleExpr))) // Pull out those angle expressions with value \pi } def adjacencyMatrix(graph: Graph): (List[VName], BMatrix) = { @@ -135,12 +321,24 @@ object GraphAnalysis { (vertexNames, vertexNames.foldLeft(Vector[Vector[Boolean]]())((vs, v) => vs :+ setToVector(graph.adjacentVerts(v)))) } + def pathDistanceSet(matrixWithNames: (List[VName], BMatrix), + distancesToFind: Set[Int], + measuredFrom: Set[Int]): + Option[Double] = { - def neighbours(matrixWithNames: (List[VName], BMatrix), target: VName) : Set[VName] = { val names = matrixWithNames._1 - val matrix = matrixWithNames._2 - val index = names.indexOf(target) - matrix(index).zipWithIndex.filter(_._1).map(_._2).map(names(_)).toSet + // val matrix = matrixWithNames._2 + + val distances = distancesFromInitial(matrixWithNames, Set(), measuredFrom.map(i => names(i))) + + val importantDistances = distances.filter(nd => distancesToFind.contains(names.indexOf(nd._1))) + + if (importantDistances.nonEmpty) { + val d = intsWithCount(importantDistances.values.toList) + if (d < 0) None else Some(d) + } else { + None + } } def distancesFromInitial(matrixWithNames: (List[VName], BMatrix), ignoring: Set[VName], initials: Set[VName]): @@ -177,7 +375,6 @@ object GraphAnalysis { names.zipWithIndex.map(name_index => name_index._1 -> finalDistances(name_index._2)).toMap } - private def intsWithCount(ints: List[Int]): Double = { val max = ints.max val count = max match { @@ -188,23 +385,147 @@ object GraphAnalysis { max + ((count - 1).toDouble / count.toDouble) } - def pathDistanceSet(matrixWithNames : (List[VName], BMatrix), - distancesToFind: Set[Int], - measuredFrom: Set[Int]): - Option[Double] = { + def distanceSpecialFromEnds(specials: List[VName])(ends: List[VName])(graph: Graph): Option[Double] = { + + val vertexList = graph.verts.toList + + def namesToIndex(name: VName) = vertexList.indexOf(name) + + val targets = ends.map(namesToIndex).toSet + // Called errors for historical reasons + val errors = specials.map(namesToIndex).toSet + val aMatrix = adjacencyMatrix(graph) + pathDistanceSet(aMatrix, errors, targets) + } + + def pathConnectionMatrices(graph: Graph): List[(Int, Tensor)] = { + val matrixWithNames = adjacencyMatrix(graph) + pathConnectionMatrices(matrixWithNames) + } + + def pathConnectionMatrices(matrixWithNames: (List[VName], BMatrix)): List[(Int, Tensor)] = { + val adjTensor = booleanMatrixToTensor(matrixWithNames._2) + + var rollingPower = Tensor.id(adjTensor.width) + (for (i <- matrixWithNames._1.indices) yield { + (i, { + rollingPower = rollingPower.compose(adjTensor) + rollingPower + }) + }).toList + } + + def booleanMatrixToTensor(bMatrix: BMatrix): Tensor = { + def complexify(v: Vector[Boolean]): Array[Complex] = v.map(b => if (b) Complex(1, 0) else Complex(0, 0)).toArray + + def complexify2(v: Vector[Vector[Boolean]]): Array[Array[Complex]] = v.map(b => complexify(b)).toArray + Tensor(complexify2(bMatrix)) + } + + def neighbours(matrixWithNames: (List[VName], BMatrix), target: VName): Set[VName] = { val names = matrixWithNames._1 - // val matrix = matrixWithNames._2 + val matrix = matrixWithNames._2 + val index = names.indexOf(target) + matrix(index).zipWithIndex.filter(_._1).map(_._2).map(names(_)).toSet + } - val distances = distancesFromInitial(matrixWithNames, Set(), measuredFrom.map(i => names(i))) + def boundariesFromRegex(graph: Graph, regex: Option[Regex]) : Set[VName] = { + regex match { + case Some(r) => graph.verts.filter(vn => vn.s.matches(r.regex)) + case None => graph.verts.filter(vn => graph.vdata(vn).isBoundary) + } + } - val importantDistances = distances.filter(nd => distancesToFind.contains(names.indexOf(nd._1))) + def checkIsomorphic(theory: Theory = Theory.DefaultTheory, boundaryByRegex: Option[Regex] = None) + (g1: Graph, g2: Graph): Boolean = { + // Check whether graphs are isomorphic, after constraining their boundaries - if (importantDistances.nonEmpty) { - val d = intsWithCount(importantDistances.values.toList) - if(d < 0) None else Some(d) + if(g1.verts.size != g2.verts.size) return false + + // Currently the .isBoundary on wires is set by JSON, not programmatically, and as such I don't trust it. + def borderNodes(g: Graph): Set[VName] = boundariesFromRegex(g, boundaryByRegex) + + val overlappingNodes = borderNodes(g1).intersect(borderNodes(g2)) + + //First check they have the same number of boundaries, and the same names between them. + + if(overlappingNodes != borderNodes(g1) || overlappingNodes != borderNodes(g2)){ + return false + } + + val dummyVertexDesc = VertexDesc.fromJson( + Json.parse("""{ + | "value": { + | "type": "empty", + | "latex_constants": true, + | "validate_with_core": false + | }, + | "style": { + | "label": { + | "position": "center", + | "fg_color": [ + | 0.0, + | 0.0, + | 0.0 + | ] + | }, + | "stroke_color": [ + | 0.0, + | 0.0, + | 0.0 + | ], + | "fill_color": [ + | 0.0, + | 1.0, + | 1.0 + | ], + | "shape": "rectangle" + | }, + | "default_data": { + | "type": "dummyBoundary", + | "value": "" + | } + | } """.stripMargin)) + + def enforceUnique(s: String) : String = if(theory.vertexTypes.keys.exists(_ == s)){ + enforceUnique(s + "1") } else { - None + s + } + val dummyVertexName: String = enforceUnique("dummyBoundary") + val dummyTheory = theory.mixin(newVertexTypes = Map(dummyVertexName -> dummyVertexDesc)) + val boundaryData = NodeV(JsonObject( + "type" -> dummyVertexName, + "value" -> "" + ), + JsonObject(), + dummyTheory) + + def makeSolidBoundaries(graph: Graph): Graph = { + overlappingNodes.foldLeft(graph.copy(data = graph.data.copy(theory = dummyTheory))) { (g, vn) => + g.updateVData(vn) { _ => boundaryData } + } } + + val solid1 = makeSolidBoundaries(g1) + val solid2 = makeSolidBoundaries(g2) + + // The graphs are isomorphic if there are matches in both directions that are the identity on boundaries + + val matches12 = Matcher.findMatches(solid1, solid2).filter( + m => overlappingNodes.forall( + vn => m.map.v.dom.toList.contains(vn) && m.map.v.domf(vn).contains(vn) + ) + ) + + val matches21 = Matcher.findMatches(solid2, solid1).filter( + m => overlappingNodes.forall( + vn => m.map.v.dom.toList.contains(vn) && m.map.v.domf(vn).contains(vn) + ) + ) + + matches21.nonEmpty && matches12.nonEmpty } + } \ No newline at end of file diff --git a/scala/src/main/scala/quanto/cosy/Interpreter.scala b/scala/src/main/scala/quanto/cosy/Interpreter.scala index 7d634e65..49fa2143 100644 --- a/scala/src/main/scala/quanto/cosy/Interpreter.scala +++ b/scala/src/main/scala/quanto/cosy/Interpreter.scala @@ -1,6 +1,8 @@ package quanto.cosy +import quanto.data.Theory.ValueType import quanto.data._ +import quanto.util.json.JsonObject /** * An interpreter is given a diagram (as an adjMat and variable assignment) and returns a tensor @@ -8,8 +10,7 @@ import quanto.data._ object Interpreter { // Converts a given graph or spider into tensor form - type cachedSpiders = collection.mutable.Map[String, Tensor] - type AngleMap = Int => Double + private type cachedSpiders = collection.mutable.Map[String, Tensor] val cached: cachedSpiders = collection.mutable.Map.empty[String, Tensor] def makeHadamards(n: Int, current: Tensor = Tensor.id(1)): Tensor = n match { @@ -18,52 +19,67 @@ object Interpreter { case _ => Tensor.hadamard x makeHadamards(n - 1, current) } - def interpretZXSpider(green: Boolean, angle: Double, inputs: Int, outputs: Int): Tensor = { + def interpretZXSpider(zxAngleData: ZXAngleData, inputs: Int, outputs: Int): Tensor = { // Converts spider to tensor. If green==false then it is a red spider - val toString = "ZX:" + green.toString + ":" + angle + ":" + inputs + ":" + outputs + + val colour = if (zxAngleData.isGreen) { + "green" + } else { + "red" + } + val angle = zxAngleData.angle.constant + val toString = s"ZX:$colour:$angle:$inputs:$outputs" + + if (cached.contains(toString)) cached(toString) else { def gen(i: Int, j: Int): Complex = { Complex.zero + (if (i == 0 && j == 0) Complex.one else Complex.zero) + (if (i == math.pow(2, outputs) - 1 && j == math.pow(2, inputs) - 1) - Complex(math.cos(angle), math.sin(angle)) else Complex.zero) + Complex(math.cos(angle * math.Pi), math.sin(angle * math.Pi)) else Complex.zero) } val mid = Tensor(math.pow(2, outputs).toInt, math.pow(2, inputs).toInt, gen) - val spider = if (green) mid else makeHadamards(outputs) o mid o makeHadamards(inputs) + val spider = if (zxAngleData.isGreen) mid else makeHadamards(outputs) o mid o makeHadamards(inputs) cached += (toString -> spider) spider } } - def interpretZWSpider(black: Boolean, outputs: Int): Tensor = { + + implicit def stringToPhase(s: String): PhaseExpression = { + PhaseExpression.parse(s, ValueType.AngleExpr) + } + + // ASSUME EVERYTHING HAS ANGLE DATA + implicit def pullOutAngleData(composite: CompositeExpression): PhaseExpression = { + composite.firstOrError(ValueType.AngleExpr) + } + + def interpretZWSpiderNoInputs(black: Boolean, outputs: Int): Tensor = { require(outputs >= 0) val toString = "ZW:" + black.toString + ":" + outputs val spider = if (cached.contains(toString)) cached(toString) else { - black match { - case true => - // Black spider - outputs match { - case 0 => Tensor(Array(Array(0))) - case 1 => Tensor(Array(Array(0, 1))).transpose - case 2 => Tensor(Array(Array(0, 1, 1, 0))).transpose - case 3 => Tensor(Array(Array(0, 1, 1, 0, 1, 0, 0, 0))).transpose - case _ => - val bY = Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 0))).transpose - val base = interpretZWSpider(black, outputs - 1) - (Tensor.idWires(outputs - 2) x bY) o base - } - - case false => - // White spider - outputs match { - case 0 => Tensor(Array(Array(0))) - case _ => - Tensor(1, - Math.pow(2, outputs).toInt, - (i, j) => (if (j == 0) Complex.one else Complex.zero) - - (if (j == Math.pow(2, outputs) - 1) Complex.one else Complex.zero)).transpose - } + if (black) { + outputs match { + case 0 => Tensor(Array(Array(0))) + case 1 => Tensor(Array(Array(0, 1))).transpose + case 2 => Tensor(Array(Array(0, 1, 1, 0))).transpose + case 3 => Tensor(Array(Array(0, 1, 1, 0, 1, 0, 0, 0))).transpose + case _ => + val bY = Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 0))).transpose + val base = interpretZWSpiderNoInputs(black, outputs - 1) + (Tensor.idWires(outputs - 2) x bY) o base + } + } else { + outputs match { + case 0 => Tensor(Array(Array(0))) + case _ => + Tensor(1, + Math.pow(2, outputs).toInt, + (_, j) => (if (j == 0) Complex.one else Complex.zero) - + (if (j == Math.pow(2, outputs) - 1) Complex.one else Complex.zero)).transpose + } } } @@ -75,7 +91,240 @@ object Interpreter { interpretZXAdjMatSpidersFirst(adjMat, greenAM, redAM) } - def interpretZWAdjMat(adjMat: AdjMat): Tensor = interpretZWAdjMatSpidersFirst(adjMat) + def zwSpiderInterpreter(nodeV: NodeV, inputs: Int, outputs: Int) : Tensor = { + (inputs, outputs) match { + case (0, 0) => Tensor.id(1) + case (0, n) => interpretZWSpiderNoInputs( nodeV.typ.toLowerCase == "w",n) + case (m, n) => + (Tensor.idWires(n) x Tensor(Array(Array(1, 0, 0, 1)))) o + (zwSpiderInterpreter(nodeV, m-1, n+1) x Tensor.idWires(1)) + } + } + + def interpretZWAdjMat(adjMat: AdjMat, inputs: List[VName], outputs: List[VName]): Tensor = { + val blackNodes = Vector( + NodeV( + JsonObject( + "type" -> "w", + "value" -> "" + ) + ) + ) + val whiteNodes = Vector( + NodeV( + JsonObject( + "type" -> "z", + "value" -> "1" + ) + ) + ) + val graph = Graph.fromAdjMat(adjMat, blackNodes, whiteNodes) + interpretSpiderGraph(zwSpiderInterpreter)(graph, inputs, outputs) + } + + def interpretSpiderGraph(spiderInterpreter: (NodeV, Int, Int) => Tensor)(graph: Graph, inputList: List[VName], outputList: List[VName]): Tensor = { + + // remove any wire vertices etc + val minGraph = graph.minimise + + minGraph.vdata.count(nd => !nd._2.isWireVertex) match { + case 0 => + stringGraph(minGraph, interpretZXSpider( + ZXAngleData(isGreen = true, PhaseExpression.zero(ValueType.AngleExpr)), 2, 0), + inputList, + outputList + ) + case _ => + pullOutVertexGraph(minGraph, inputList, outputList, interpretSpiderGraph(spiderInterpreter), spiderInterpreter) + } + } + + def zxSpiderInterpreter(vdata: NodeV, inputs: Int, outputs: Int): Tensor = { + + val zxData: ZXAngleData = { + val isGreen = vdata.typ == "Z" + val angle = PhaseExpression.parse(vdata.value, ValueType.AngleExpr) + ZXAngleData(isGreen, angle) + } + + interpretZXSpider(zxData, inputs, outputs) + } + + val interpretZXGraph : (Graph, List[VName], List[VName]) => Tensor = interpretSpiderGraph(zxSpiderInterpreter) + + /** + * Given a graph of just boundaries and edges, + * return the tensor treating those edges as caps + * + * @param graph : Graph + * @return + */ + def stringGraph(graph: Graph, cap: Tensor, inputList: List[VName], outputList: List[VName]): Tensor = { + if (graph.verts.size % 2 != 0) { + throw new Error("String graph should have an even number of boundaries") + } + + // GRAPHS ARE READ BOTTOM TO TOP HERE. + + + // Errors here are from node names appearing in the input or output list but not in the graph + val numInternalInputs = inputList.count(vn => { + inputList.contains(graph.adjacentVerts(vn).head) + }) + val numInternalOutputs = outputList.count(vn => { + outputList.contains(graph.adjacentVerts(vn).head) + }) + + val joinLowerList = inputList.filterNot(vn => inputList.contains(graph.adjacentVerts(vn).head)) + val joinUpperList = outputList.filterNot(vn => outputList.contains(graph.adjacentVerts(vn).head)) + + var claimedInternalCaps: List[VName] = List() + var claimedInternalCups: List[VName] = List() + + val numJoin = (inputList.size + outputList.size - numInternalInputs - numInternalOutputs) / 2 + + def toLowerNeighbour(place: Int): Int = { + val name: VName = inputList(place) + val neighbour: VName = graph.adjacentVerts(name).head + if (inputList.contains(neighbour) && inputList.contains(name)) { + val index = claimedInternalCaps.indexOf(name) + if (index < 0) { + claimedInternalCaps = claimedInternalCaps :+ name + claimedInternalCaps = claimedInternalCaps :+ neighbour + toLowerNeighbour(place) + } else { + index + } + } else { + numInternalInputs + joinLowerList.indexOf(name) + } + } + + + def toUpperNeighbour(place: Int): Int = { + val name: VName = outputList(place) + val neighbour: VName = graph.adjacentVerts(name).head + if (outputList.contains(neighbour) && outputList.contains(name)) { + val index = claimedInternalCups.indexOf(name) + if (index < 0) { + claimedInternalCups = claimedInternalCups :+ name + claimedInternalCups = claimedInternalCups :+ neighbour + toUpperNeighbour(place) + } else { + index + } + } else { + numInternalOutputs + joinUpperList.indexOf(name) + } + } + + def toJoin(place: Int) : Int = { + val name: VName = joinLowerList(place) + val neighbour: VName = graph.adjacentVerts(name).head + joinUpperList.indexOf(neighbour) + } + + val caps = cap.power(numInternalInputs / 2) + val cups = cap.transpose.power(numInternalOutputs / 2) + + val sigmaLower = if (inputList.nonEmpty) { + Tensor.swap(inputList.length, toLowerNeighbour) + } else { + Tensor(Array(Array(Complex(1, 0)))) + } + + val sigmaUpper = if (outputList.nonEmpty) { + Tensor.swap(outputList.length, toUpperNeighbour).transpose + } else { + Tensor(Array(Array(Complex(1, 0)))) + } + + val sigmaMid = if (joinLowerList.nonEmpty) { + Tensor.swap(joinLowerList.length, toJoin) + } else { + Tensor(Array(Array(Complex(1, 0)))) + } + + sigmaUpper o (cups x Tensor.idWires(numJoin)) o sigmaMid o (caps x Tensor.idWires(numJoin)) o sigmaLower + } + + private def pullOutVertexGraph(startingGraph: Graph, + inputList: List[VName], + outputList: List[VName], + graphInterpreter: (Graph, List[VName], List[VName]) => Tensor, + spiderInterpreter: (NodeV, Int, Int) => Tensor): Tensor = { + def ratioDanglingWires(name: VName): Double = { + // A measure of how many boundary vs non-boundary neighbours the node has + val numBoundary = startingGraph.adjacentVerts(name).count(vn => inputList.contains(vn)) + (1 + numBoundary).toDouble / (1 + startingGraph.adjacentVerts(name).size - numBoundary) + } + + // def boundaries(g: Graph): Set[VName] = g.verts.filter(g.isTerminalWire) + + // Pick a vertex to shift from graph to tensor + val cutVertex = startingGraph.verts + .filterNot(vn => startingGraph.vdata(vn).isWireVertex) + .maxBy(ratioDanglingWires) + val numCutSpiderInOuts = startingGraph.adjacentVerts(cutVertex).size + + // cut out that vertex, as well as any boundaries it was attached to + val (reducedGraph, freshMadeBoundaries, removedBoundaries) = startingGraph.cutVertex(cutVertex, inputList.toSet) + val uncutBoundariesVector: Vector[VName] = + inputList.filter(b => !removedBoundaries.contains(b)).toVector + val reducedGraphBoundaries = uncutBoundariesVector.toList ++ freshMadeBoundaries.toList + + val spiderTensor = spiderInterpreter( + startingGraph.vdata(cutVertex).asInstanceOf[NodeV], + numCutSpiderInOuts - freshMadeBoundaries.size, + freshMadeBoundaries.size) + + val reducedGraphBoundariesVector: Vector[VName] = reducedGraphBoundaries.toVector + + val bottomSigma: Tensor = { + val swapList: List[Int] = { + var leftCount = 0 + var rightCount = (reducedGraphBoundaries.toSet -- freshMadeBoundaries).size + inputList.indices.map(i => + if (reducedGraphBoundaries.contains(inputList(i))) { + leftCount += 1 + leftCount - 1 + } else { + rightCount += 1 + rightCount - 1 + } + ).toList + } + Tensor.swap(inputList.size, swapList) + } + + /* + val topSigma: Tensor = { + val inputs: Vector[VName] = uncutBoundariesVector ++ freshMadeBoundaries.toVector + Tensor.swap(reducedGraphBoundaries.size, + i => reducedGraphBoundariesVector.indexOf(inputs(i))) + } + */ + + /** + * Starting with the graph G, pulling out the spider S + * You want: + * + * [ G' ] + * [id] x [S] + * [ botSigma ] + * + * Where G' is the reduced graph, with vector of inputs: reducedGraphBoundariesVector + * S is the cut spider, + * the identity is on uncutBoundariesVector + * And the bottom sigma acts on allInputsVector + * (Rather than having topSigma now just have different lis tof inputs for G') + */ + + graphInterpreter(reducedGraph, reducedGraphBoundaries, outputList) o + // topSigma o + (Tensor.idWires(uncutBoundariesVector.size) x spiderTensor) o + bottomSigma + } private def interpretAdjMat(adj: AdjMat, join: Tensor, vertexToTensor: (Int) => Tensor): Tensor = { // Interpret the graph as (caps) o (crossings) o (vertices) @@ -113,7 +362,7 @@ object Interpreter { connectionList = connectionList ++ List(claimLeg(i), claimLeg(j)) } } - connectingCaps.plugAbove(allSpidersTensors, connectionList) + connectingCaps o Tensor.swap(connectionList).transpose o allSpidersTensors } } @@ -124,27 +373,25 @@ object Interpreter { } else { // Tensor representation of a spider def vertexToSpider(v: Int): Tensor = { - def pullOutAngle(nv: NodeV) = if (!nv.value.isEmpty) { - try { - nv.value.toDouble - } catch { - case e: Error => nv.angle.evaluate(Map("pi" -> math.Pi)) * math.Pi - } + def pullOutAngle(nv: NodeV): PhaseExpression = if (!nv.value.isEmpty) { + nv.value } else { - nv.angle.evaluate(Map("pi" -> math.Pi)) * math.Pi + "0" } val (colour, nodeType) = adj.vertexColoursAndTypes(v) val numLegs = adj.mat(v).count(p => p) - val green = true colour match { case VertexColour.Boundary => Tensor.id(2) - case VertexColour.Green => interpretZXSpider(green, pullOutAngle(greenAM(nodeType)), 0, numLegs) - case VertexColour.Red => interpretZXSpider(!green, pullOutAngle(redAM(nodeType)), 0, numLegs) + case VertexColour.Green => interpretZXSpider( + ZXAngleData(isGreen = true, pullOutAngle(greenAM(nodeType))), 0, numLegs) + case VertexColour.Red => interpretZXSpider( + ZXAngleData(isGreen = false, pullOutAngle(redAM(nodeType))), 0, numLegs) } + } - val cap = interpretZXSpider(green = true, 0, 2, 0) + val cap = interpretZXSpider(ZXAngleData(isGreen = true, "0"), 2, 0) interpretAdjMat(adj, cap, vertexToSpider) } @@ -160,11 +407,14 @@ object Interpreter { // Using ZX colours for now colour match { case VertexColour.Boundary => Tensor.id(2) - case VertexColour.Green => interpretZWSpider(!black, numLegs) - case VertexColour.Red => interpretZWSpider(black, numLegs) + case VertexColour.Green => interpretZWSpiderNoInputs(!black, numLegs) + case VertexColour.Red => interpretZWSpiderNoInputs(black, numLegs) } } interpretAdjMat(adj, cup, vertexToSpider) } + + case class ZXAngleData(isGreen: Boolean, angle: PhaseExpression) + } diff --git a/scala/src/main/scala/quanto/cosy/RuleSynthesis.scala b/scala/src/main/scala/quanto/cosy/RuleSynthesis.scala index c8019902..fe46fa63 100644 --- a/scala/src/main/scala/quanto/cosy/RuleSynthesis.scala +++ b/scala/src/main/scala/quanto/cosy/RuleSynthesis.scala @@ -1,18 +1,35 @@ package quanto.cosy -import quanto.data._ +import quanto.cosy.RuleSynthesis.GraphComparison import quanto.data.Derivation.DerivationWithHead +import quanto.data._ import quanto.rewrite._ -import quanto.util.json.Json +import quanto.util.json.{Json, JsonObject} import scala.annotation.tailrec import scala.util.Random +import scala.util.matching.Regex +import quanto.data.Names._ /** * Created by hector on 29/06/17. */ object RuleSynthesis { + type GraphComparison = (Graph, Graph) => Int + // returns x where + // x < 0 iff left < right + // x > 0 iff left > right + // x == 0 otherwise + + def basicGraphComparison(left: Graph, right: Graph): Int = { + if (left < right) -1 else { + if (left > right) 1 else 0 + } + } + + + def loadRuleDirectory(directory: String): List[Rule] = { quanto.util.FileHelper.getListOfFiles(directory, raw".*\.qrule"). map(file => (file.getName.replaceFirst(raw"\.qrule", ""), Json.parse(file))). @@ -21,6 +38,7 @@ object RuleSynthesis { } /** Given an equivalence class creates rules for any irreducible members of the class */ + // Superseded by CoSyRun def graphEquivClassReduction[T](makeGraph: (T => Graph), equivalenceClass: EquivalenceClass[T], knownRules: List[Rule]): List[Rule] = { @@ -50,27 +68,117 @@ object RuleSynthesis { } } - def discardDirectlyReducibleRules(rules: List[Rule], theory: Theory, seed: Random = new Random()): List[Rule] = { - rules.filter(rule => - AutoReduce.genericReduce(graphToDerivation(rule.lhs, theory), rules.filter(r => r != rule), seed) >= rule.lhs - ) + def extendMatchingSpidersWithBBoxes(rule: Rule, boundariesRegex : Option[Regex]) : Rule = { + require(!rule.hasBBoxes) + // This is not safe to do if the rule already has bboxes. + + val boundaries = GraphAnalysis.boundariesFromRegex(rule.lhs, boundariesRegex).toList + def nearestNeighbourType(graph: Graph, vName: VName) : Option[(VName, String)] = { + val neighbours = graph.adjacentNodesAndBoundaries(vName) + if(neighbours.size == 1){ + val t : String = (graph.vdata(neighbours.head).data / "type").toString + Some((neighbours.head, t)) + } else None + } + + def addBBoxIfOkay(rule: Rule, vName: VName) : Rule = { + val lhsT = nearestNeighbourType(rule.lhs, vName) + val rhsT = nearestNeighbourType(rule.rhs, vName) + if(lhsT.nonEmpty && rhsT.nonEmpty && lhsT.get._2 == rhsT.get._2 && rule.lhs.vdata(lhsT.get._1).isInstanceOf[NodeV]){ + val leftNeighbour = lhsT.get._1 + val rightNeighbour = rhsT.get._1 + val bBName = rule.lhs.bboxes.freshWithSuggestion("bb0") + + val lhsB = rule.lhs.addBBox(bBName, BBData(), Set(vName)) + val rhsB = rule.lhs.addBBox(bBName, BBData(), Set(vName)) + Rule(lhsB, rhsB, rule.derivation, rule.description) + } else rule + } + + def removeBoundaryIfSuperfluous(rule: Rule, vName: VName): Rule = { + val lhsT = nearestNeighbourType(rule.lhs, vName) + val rhsT = nearestNeighbourType(rule.rhs, vName) + if(lhsT.nonEmpty && rhsT.nonEmpty && lhsT.get == rhsT.get && rule.lhs.vdata(lhsT.get._1).isInstanceOf[NodeV]) { + val ln = lhsT.get._1 + val rn = rhsT.get._1 + val lNeighbourhood = rule.lhs.adjacentNodesAndBoundaries(ln).intersect(boundaries.toSet) + val rNeighbourhood = rule.rhs.adjacentNodesAndBoundaries(rn).intersect(boundaries.toSet) + if((lNeighbourhood intersect rNeighbourhood).size > 1) { + val lCut = rule.lhs.deleteVertex(vName) + val rCut = rule.rhs.deleteVertex(vName) + Rule(lCut, rCut) + } else rule + } else rule + } + + val withBBoxes = boundaries.foldLeft(rule)(addBBoxIfOkay) + boundaries.foldLeft(withBBoxes)(removeBoundaryIfSuperfluous) } - def graphToDerivation(graph: Graph, theory: Theory): DerivationWithHead = { - (new Derivation(theory, graph), None) + def removeIsomorphisms(theory: Theory, boundaryRegex: Option[Regex], rules: List[Rule]) : List[Rule] = { + def isIso(rule: Rule) : Boolean = GraphAnalysis.checkIsomorphic(theory, boundaryRegex)(rule.lhs, rule.rhs) + rules.filter(!isIso(_)) } - def minimiseRuleset(rules: List[Rule], theory: Theory, seed: Random = new Random()): List[Rule] = { - rules.map(rule => minimiseRuleInPresenceOf(rule, rules.filter(otherRule => otherRule != rule), theory)) + def greedyReduceRules(comparison: GraphComparison, throwOutIsos : Option[(Theory, Option[Regex])] = None) + (rules: List[Rule]): List[Rule] = { + + // Will automatically invert rules that head upwards + + // This does not throw out isomorphisms for you; it returns the list with each entry altered + + // Yes, this is imperative rather than functional. We really do want to update a list as we act on it. + + var rulesAsMap = rules.zipWithIndex.map(ri => (ri._2, ri._1)).toMap // i -> rule_i + var updatedThisRun = true + + def isomorphism(rule: Rule): Boolean = throwOutIsos match { + case Some(tor) => + GraphAnalysis.checkIsomorphic(tor._1, tor._2)(rule.lhs, rule.rhs) + case None => false + } + + while (updatedThisRun) { + updatedThisRun = false + for (i <- rulesAsMap.keys) { + val rule = rulesAsMap(i) + val otherRules = rulesAsMap.filterKeys(_ != i).values.toList.filter(!isomorphism(_)).map( + rule => { + if (comparison(rule.lhs, rule.rhs) < 0) { + rule.inverse + } else rule + } + ) + val newLhs: Graph = AutoReduce.greedyReduce(comparison, graphToDerivation(rule.lhs), otherRules) + val newRhs: Graph = AutoReduce.greedyReduce(comparison, graphToDerivation(rule.rhs), otherRules) + if (newLhs != rule.lhs || newRhs != rule.rhs) { + rulesAsMap = rulesAsMap + (i -> new Rule(newLhs, + newRhs, + derivation = None, + RuleDesc(rule.name + " reduced") + )) + updatedThisRun = true + } + } + } + + rulesAsMap.values.toList } - def minimiseRuleInPresenceOf(rule: Rule, otherRules: List[Rule], theory: Theory, seed: Random = new Random()): Rule = { - val minLhs: Graph = AutoReduce.genericReduce(graphToDerivation(rule.lhs, theory), otherRules, seed) - val minRhs: Graph = AutoReduce.genericReduce(graphToDerivation(rule.rhs, theory), otherRules, seed) - val wasItReduced = (minLhs < rule.lhs) || (minRhs < rule.rhs) - new Rule(minLhs, minRhs, description = RuleDesc( - rule.name + (if (wasItReduced) " reduced" else ""))) + def discardDirectlyReducibleRules(comparison: GraphComparison, + rules: List[Rule], + seed: Random = new Random()): List[Rule] = { + rules.filter(rule => + AutoReduce.greedyReduce(comparison, graphToDerivation(rule.lhs), rules.filter(r => r != rule)) >= rule.lhs + ).filter(rule => + AutoReduce.greedyReduce(comparison, graphToDerivation(rule.rhs), rules.filter(r => r != rule)) >= rule.rhs + ) } + + def graphToDerivation(graph: Graph): DerivationWithHead = { + (new Derivation(graph), None) + } + } /** @@ -82,8 +190,6 @@ object AutoReduce { * Automatically reduce, with no handler or multithreading */ - implicit def inverseToRuleVariant(inverse: Boolean): RuleVariant = if (inverse) RuleInverse else RuleNormal - def smallestStepNameBelow(derivationHeadPair: (Derivation, Option[DSName])): Option[DSName] = { derivationHeadPair._2 match { case Some(head) => @@ -101,19 +207,20 @@ object AutoReduce { } // Tries multiple methods and is sure to return nothing larger than what you started with - def genericReduce(derivationAndHead: DerivationWithHead, + def genericReduce(comparison: GraphComparison) + (derivationAndHead: DerivationWithHead, rules: List[Rule], seed: Random = new Random()): DerivationWithHead = { var latestDerivation: DerivationWithHead = derivationAndHead latestDerivation._2 match { case Some(initialHead) => - latestDerivation = annealingReduce(latestDerivation, rules, seed) - latestDerivation = greedyReduce(latestDerivation, rules) - latestDerivation = greedyReduce((latestDerivation._1.addHead(initialHead), Some(initialHead)), rules) + latestDerivation = annealingReduce(comparison, latestDerivation, rules, seed) + latestDerivation = greedyReduce(comparison, latestDerivation, rules) + latestDerivation = greedyReduce(comparison, (latestDerivation._1.addHead(initialHead), Some(initialHead)), rules) case None => - latestDerivation = annealingReduce(latestDerivation, rules, seed) - latestDerivation = greedyReduce(latestDerivation, rules) + latestDerivation = annealingReduce(comparison, latestDerivation, rules, seed) + latestDerivation = greedyReduce(comparison, latestDerivation, rules) } // Go back to original request, find smallest child @@ -121,17 +228,19 @@ object AutoReduce { } // Simplest entry point - def annealingReduce(derivationHeadPair: DerivationWithHead, + def annealingReduce(comparison: GraphComparison, + derivationHeadPair: DerivationWithHead, rules: List[Rule], seed: Random = new Random(), vertexLimit: Option[Int] = None): DerivationWithHead = { - val maxTime = math.pow(derivationHeadPair.verts.size, 2).toInt // Set as squaring #vertices for now + val maxTime = 100 + math.pow(derivationHeadPair.verts.size, 2).toInt // Set as squaring #vertices for now val timeDilation = 3 // Gives an e^-3 ~ 0.05% chance of a non-reduction rule on the final step - annealingReduce(derivationHeadPair, rules, maxTime, timeDilation, seed, vertexLimit) + annealingReduce(comparison, derivationHeadPair, rules, maxTime, timeDilation, seed, vertexLimit) } // Enter here to have control over e.g. how long it runs for - def annealingReduce(derivationHeadPair: DerivationWithHead, + def annealingReduce(comparison: GraphComparison, + derivationHeadPair: DerivationWithHead, rules: List[Rule], maxTime: Int, timeDilation: Double, @@ -144,41 +253,45 @@ object AutoReduce { val suggestedNextStep = randomSingleApply(d, randRule, seed, None, None, None) val head = Derivation.derivationHeadPairToGraph(d) val smallEnough = vertexLimit.isEmpty || (head.verts.size < vertexLimit.get) - if ((allowIncrease && smallEnough) || suggestedNextStep < head) suggestedNextStep else d + if ((allowIncrease && smallEnough) || comparison(suggestedNextStep, head) < 0) suggestedNextStep else d } else d } } @tailrec - def greedyReduce(derivationHeadPair: DerivationWithHead, + def greedyReduce(comparison: GraphComparison, + derivationHeadPair: DerivationWithHead, rules: List[Rule], remainingRules: List[Rule]): DerivationWithHead = { remainingRules match { case r :: tailRules => Matcher.findMatches(r.lhs, derivationHeadPair) match { - case ruleMatch #:: t => + case ruleMatch #:: _ => val reducedGraph = Rewriter.rewrite(ruleMatch, r.rhs)._1.minimise val stepName = quanto.data.Names.mapToNameMap(derivationHeadPair._1.steps). freshWithSuggestion(DSName(r.description.name)) - greedyReduce((derivationHeadPair._1.addStep( - derivationHeadPair._2, - DStep(stepName, r, reducedGraph) - ), Some(stepName)), + greedyReduce(comparison, + (derivationHeadPair._1.addStep( + derivationHeadPair._2, + DStep(stepName, r, reducedGraph) + ), Some(stepName)), rules, rules) case Stream.Empty => - greedyReduce(derivationHeadPair, rules, tailRules) + greedyReduce(comparison, derivationHeadPair, rules, tailRules) } case Nil => derivationHeadPair } } // Simply apply the first reduction rule it can find until there are none left - def greedyReduce(derivationHeadPair: DerivationWithHead, rules: List[Rule]): DerivationWithHead = { + def greedyReduce(comparison: GraphComparison, + derivationHeadPair: DerivationWithHead, + rules: List[Rule]): DerivationWithHead = { val reducingRules = rules.filter(rule => rule.lhs > rule.rhs) - val reduced = greedyReduce(derivationHeadPair, reducingRules, reducingRules) + val reduced = greedyReduce(comparison, derivationHeadPair, reducingRules, reducingRules) // Go round again until it stops reducing - if (reduced < derivationHeadPair) { - greedyReduce(reduced, rules) + if (comparison(reduced, derivationHeadPair) < 0) { + greedyReduce(comparison, reduced, rules) } else reduced } diff --git a/scala/src/main/scala/quanto/cosy/SimplificationProcedure.scala b/scala/src/main/scala/quanto/cosy/SimplificationProcedure.scala index 7b9ca718..1aabc1b7 100644 --- a/scala/src/main/scala/quanto/cosy/SimplificationProcedure.scala +++ b/scala/src/main/scala/quanto/cosy/SimplificationProcedure.scala @@ -134,9 +134,9 @@ object SimplificationProcedure { (DerivationWithHead, State) = { import state._ val d = derivation - val randRule = rules.toList(seed.nextInt(rules.size)) + val randRule = rules(seed.nextInt(rules.size)) val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(d) - val errors = GraphAnalysis.detectErrors(d) + val errors = GraphAnalysis.detectPiNodes(d) val specialsDistances = errors.map(eName => (eName, state.weightFunction(d, Set(eName)))) val maxDistance = specialsDistances.maxBy[Double](ed => ed._2.getOrElse(0))._2.getOrElse(0) @@ -153,8 +153,9 @@ object SimplificationProcedure { val changed = suggestedNextStep._1.steps.size > derivation._1.steps.size - val shrunkNextStep = AutoReduce.greedyReduce(suggestedNextStep, greedyRules.getOrElse(Set()).toList) - val newErrors = GraphAnalysis.detectErrors(shrunkNextStep) + val shrunkNextStep = AutoReduce.greedyReduce(RuleSynthesis.basicGraphComparison, + suggestedNextStep, greedyRules.getOrElse(Set()).toList) + val newErrors = GraphAnalysis.detectPiNodes(shrunkNextStep) val suggestedNewSize: Double = state.weightFunction(shrunkNextStep, newErrors).getOrElse(0) // Bias towards strict reduction @@ -168,7 +169,7 @@ object SimplificationProcedure { println(randRule.description) (shrunkNextStep, state.next(Some(suggestedNewSize))) } else { - println("rej " + suggestedNewSize) + println("rej " + suggestedNewSize) (d, state.next(currentDistance)) } } diff --git a/scala/src/main/scala/quanto/cosy/SimprocBatch.scala b/scala/src/main/scala/quanto/cosy/SimprocBatch.scala new file mode 100644 index 00000000..adf5e931 --- /dev/null +++ b/scala/src/main/scala/quanto/cosy/SimprocBatch.scala @@ -0,0 +1,199 @@ +package quanto.cosy + +import java.util.Calendar + +import quanto.data.Names._ +import quanto.data._ +import quanto.gui.{BatchDerivationCreatorPanel, QuantoDerive} +import quanto.rewrite.Simproc +import quanto.util.UserAlerts.{Elevation, alert} +import quanto.util.json.{Json, JsonArray, JsonObject} +import quanto.util.{FileHelper, UserOptions} + +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future +import scala.swing.Publisher +import scala.swing.event.Event +import scala.util.{Failure, Success} + + +// Each simproc, graph pair generates a SimprocSingleRun result +// The result holds the name of the simproc, the generated derivation, and the timings for each step +case class SimprocSingleRun(simprocName: String, derivation: Derivation, derivationTimings: List[(String, Double)]) + +object SimprocSingleRun { + def toJson(ssr: SimprocSingleRun): Json = { + JsonObject( + "simproc" -> ssr.simprocName, + "derivation" -> Derivation.toJson(ssr.derivation), + "timings" -> JsonArray( + ssr.derivationTimings.map(t => + JsonObject( + "step" -> t._1, + "time" -> t._2 + ) + ) + ) + ) + } + + def fromJson(json: Json): SimprocSingleRun = { + val name: String = (json / "simproc").stringValue + val derivation: Derivation = Derivation.fromJson(json / "derivation") + val timings: List[(String, Double)] = (json / "timings").asArray.map(j => ( + (j / "step").stringValue, + (j / "time").doubleValue)).toList + SimprocSingleRun(name, derivation, timings) + } +} + +// Collect the single runs and the simproc definitions in one place +case class SimprocBatchResult(selectedSimprocs: List[String], + allSimprocs: Map[String, String], + singleResults: List[SimprocSingleRun], + notes: String) { + + lazy val toJson: JsonObject = { + JsonObject( + "python" -> JsonArray( + allSimprocs.map( + ss => JsonObject( + ss._1 -> ss._2 + ) + ) + ), + "selected_simprocs" -> JsonArray(selectedSimprocs), + "results" -> JsonArray( + singleResults.map(sr => SimprocSingleRun.toJson(sr)) + ), + "notes" -> notes + ) + } +} + +// This class contains just the metadata for the run +// It is used when loading the results file when you don't want everything to be pulled from the json +case class SimprocLazyBatchResult(notes: String, selectedSimprocs: List[String], resultCount: Int) + +object SimprocBatchResult { + def collate(SBResult: SimprocBatchResult): Map[String, List[(Derivation, List[(String, Double)])]] = { + SBResult.singleResults. + groupBy { case SimprocSingleRun(a, _, _) => a }. + mapValues(_.map { case SimprocSingleRun(_, b, c) => (b, c) }) + } + + // Import everything into memory + def fromJson(json: Json): SimprocBatchResult = { + val notes: String = (json / "notes").stringValue + val selectedSimprocs: List[String] = (json / "selected_simprocs").asArray.map(a => a.stringValue).toList + val allSimprocs: Map[String, String] = (json / "python").asArray.flatMap(j => j.asObject.v.map(sj => sj._1 -> sj._2.stringValue)).toMap + val results: List[SimprocSingleRun] = (json / "results").asArray.map(j => SimprocSingleRun.fromJson(j.asObject)).toList + SimprocBatchResult(selectedSimprocs, allSimprocs, results, notes) + } + + // Just import the metadata into memory + def lazyFromJson(json: Json): SimprocLazyBatchResult = { + val notes: String = (json / "notes").stringValue + val selectedSimprocs: List[String] = (json / "selected_simprocs").asArray.map(a => a.stringValue).toList + val resultCount: Int = (json / "results").asArray.size + SimprocLazyBatchResult(notes, selectedSimprocs, resultCount) + } +} + + +// Set up the run as a SimprocBatch +// This contains all the user supplied information +// Then pulls loaded simprocs etc. from the project +case class SimprocBatch(selectedSimprocs: List[String], selectedGraphs: List[Graph], notes: String) { + def run(): Unit = { + alert(s"Beginning simproc batch run on ${selectedGraphs.length} graphs, " + + s"with simprocs\n${selectedSimprocs.mkString("\n")}") + + val simprocGraphPairs = for (simprocName <- selectedSimprocs; graph <- selectedGraphs) yield (simprocName, graph) + val listFutureResults = simprocGraphPairs.map(sg => { + Future { + val derivationData = SimprocBatch.runSimprocGetTimings(sg._1, sg._2) + SimprocSingleRun(sg._1, derivationData._1, derivationData._2) + } + }) + val futureListResults = Future.sequence(listFutureResults) + + futureListResults.onComplete { + case Success(list) => + val result = SimprocBatchResult(selectedSimprocs, + SimprocBatch.loadedSimprocs.map(ss => (ss._1, ss._2.sourceCode)), + list, + notes) + SimprocBatch.publish(SimprocBatchRunComplete(result)) + case Failure(e) => + alert("Simproc batch run failed!", Elevation.ERROR) + e.printStackTrace() + } + } +} + +case class SimprocNotLoaded(simprocName: String) extends Exception + +case class SimprocBatchRunComplete(result: SimprocBatchResult) extends Event + +object SimprocBatch extends Publisher { + + var timeout: Long = 60 * 1000 // timeout time in milliseconds + + // CONCURRENTLY run a simproc on a graph + // Makes its own derivation with timing data + // Runs until completion or timeout + def runSimprocGetTimings(simprocName: String, graph: Graph): (Derivation, List[(String, Double)]) = { + val simproc = simprocFromName(simprocName) + val startTime: Long = now + val jobIDAtStart = BatchDerivationCreatorPanel.jobID + + def timeElapsed = now - startTime + + var timings: List[(String, Double)] = List() + var derivation = new Derivation(graph) + var parentOpt: Option[DSName] = None + for ((graph, rule) <- simproc.simp(graph)) { + // Stop if taking too long + // Stop if the user has incremented the jobID (indicating they want the job to halt) + if (timeElapsed < timeout || BatchDerivationCreatorPanel.jobID > jobIDAtStart) { + val suggest = rule.name.replaceFirst("^.*\\/", "") + "-0" + val step = DStep( + name = derivation.steps.freshWithSuggestion(DSName(suggest)), + rule = rule, + graph = graph.minimise) // layout is already done by simproc now + + derivation = derivation.addStep(parentOpt, step) + timings = (step.name.toString, timeElapsed.toDouble) :: timings + parentOpt = Some(step.name) + } + } + + (derivation, timings) + } + + listenTo(this) + reactions += { + case SimprocBatchRunComplete(result) => + alert("Simproc Batch completed") + val fileName = UserOptions.preferredDateTimeFormat.format(Calendar.getInstance().getTime) + .replace(":","-").replace(".","--") + ".qsbr" + val projectRoot = QuantoDerive.CurrentProject.map(p => p.rootFolder + "/").getOrElse("") + FileHelper.printJson(projectRoot + "batch_results/" + fileName, result.toJson) + } + + implicit def simprocFromName(name: String): Simproc = { + try { + loadedSimprocs(name) + } catch { + case _: Exception => + alert(s"Requested simproc $name was not found. Please load it first.", Elevation.ERROR) + throw SimprocNotLoaded(name) + } + } + + def now: Long = Calendar.getInstance().getTimeInMillis + + def loadedSimprocs: Map[String, Simproc] = QuantoDerive.CurrentProject.map(p => p.simprocs).getOrElse(Map()) + +} \ No newline at end of file diff --git a/scala/src/main/scala/quanto/cosy/Tensor.scala b/scala/src/main/scala/quanto/cosy/Tensor.scala index d845bd59..79febda3 100644 --- a/scala/src/main/scala/quanto/cosy/Tensor.scala +++ b/scala/src/main/scala/quanto/cosy/Tensor.scala @@ -2,8 +2,6 @@ package quanto.cosy import quanto.util.json.{JsonArray, JsonObject} -import scala.runtime.RichInt - /** * A tensor-valued interpretation of a graph * @@ -81,16 +79,16 @@ class Tensor(c: Array[Array[Complex]]) { def t: Tensor = this.transpose - def transpose: Tensor = { - // transpose - Tensor(this.width, this.height, (i, j) => this.c(j)(i)) - } - def dagger: Tensor = { // conjugate transpose this.conjugate.transpose } + def transpose: Tensor = { + // transpose + Tensor(this.width, this.height, (i, j) => this.c(j)(i)) + } + def conjugate: Tensor = { Tensor(this.height, this.width, (i, j) => this.c(i)(j).conjugate) } @@ -164,20 +162,41 @@ class Tensor(c: Array[Array[Complex]]) { def entry(down: Int, across: Int): Complex = contents(down)(across) def isRoughlyUpToScalar(that: Tensor, distance: Double = Tensor.defaultDistance): Boolean = { - // + + if (!this.isSameShapeAs(that)) return false + + val thisIsRoughly0 = this.isRoughly(Tensor.zero(this.height, this.width), distance) + val thatIsRoughly0 = that.isRoughly(Tensor.zero(this.height, this.width), distance) + + if (thisIsRoughly0 && thatIsRoughly0) return true + if (thisIsRoughly0 && !thatIsRoughly0) return false + if (!thisIsRoughly0 && thatIsRoughly0) return false this.distanceAfterScaling(that) < distance } def distanceAfterScaling(that: Tensor): Double = { if (this.isSameShapeAs(that)) { - var maxEntry = (0, 0) - var maxEntryValue = Complex.zero - for (i <- this.c.indices; j <- this.c.head.indices) { + + val (maxEntry, maxEntryValue) = + (for (i <- this.c.indices; j <- this.c.head.indices) yield (i, j)).toList.foldLeft((0, 0), Complex.zero) { + (agg, coord) => { + val value = this.c(coord._1)(coord._2) + if (value.abs > agg._2.abs) { + (coord, value) + } else { + agg + } + } + } + + /* + { if (this.c(i)(j).abs > maxEntryValue.abs) { maxEntry = (i, j) maxEntryValue = this.c(i)(j) } } + */ if (maxEntryValue == Complex.zero) { this.distance(that) } else { @@ -189,7 +208,7 @@ class Tensor(c: Array[Array[Complex]]) { } } } else { - -1 + 1 } } @@ -198,6 +217,41 @@ class Tensor(c: Array[Array[Complex]]) { isRoughly(_, maxDist) } + def isRoughly(that: Tensor, maxDistance: Double = 1e-14): Boolean = + // Compare two tensors up to a given distance + if (this.isSameShapeAs(that)) { + this.distance(that) < maxDistance + } else { + false + } + + /** Returns max abs distance */ + def distance(that: Tensor): Double = { + require(this.isSameShapeAs(that)) + (this - that).contents.flatten.foldLeft(0.0) { (a: Double, b: Complex) => math.max(a, b.abs) } + } + + def isSameShapeAs(that: Tensor): Boolean = { + this.width == that.width && this.height == that.height + } + + def -(that: Tensor): Tensor = { + // this - that + this + that.scaled(Complex(-1, 0)) + } + + def +(that: Tensor): Tensor = { + // this + that + require(this.width == that.width) + require(this.height == that.height) + Tensor(this.height, this.width, (i, j) => this.c(i)(j) + that.contents(i)(j)) + } + + def scaled(factor: Complex): Tensor = { + // scalar multiplication + Tensor(this.height, this.width, (i, j) => this.c(i)(j) * factor) + } + override def equals(other: Any): Boolean = // Compares matrix-entry by matrix-entry other match { @@ -250,37 +304,6 @@ class Tensor(c: Array[Array[Complex]]) { this.scaled(maxAbsEntry.inverse()) } } - - def isRoughly(that: Tensor, maxDistance: Double = 1e-14): Boolean = - // Compare two tensors up to a given distance - this.distance(that) < maxDistance - - /** Returns max abs distance */ - def distance(that: Tensor): Double = { - require(this.isSameShapeAs(that)) - (this - that).contents.flatten.foldLeft(0.0) { (a: Double, b: Complex) => math.max(a, b.abs) } - } - - def isSameShapeAs(that: Tensor): Boolean = { - this.width == that.width && this.height == that.height - } - - def -(that: Tensor): Tensor = { - // this - that - this + that.scaled(Complex(-1, 0)) - } - - def +(that: Tensor): Tensor = { - // this + that - require(this.width == that.width) - require(this.height == that.height) - Tensor(this.height, this.width, (i, j) => this.c(i)(j) + that.contents(i)(j)) - } - - def scaled(factor: Complex): Tensor = { - // scalar multiplication - Tensor(this.height, this.width, (i, j) => this.c(i)(j) * factor) - } } object Tensor { @@ -293,6 +316,8 @@ object Tensor { Tensor(Array(Array(1, 1), Array(1, -1))).scaled(math.pow(2, -0.5)) } val defaultDistance = 1e-14 + private var permutationCache: Map[List[Int], Matrix] = Map() + private var swapCache: Map[List[Int], Tensor] = Map() def idWires(n: Int): Tensor = { // Identity on n wires, i.e. 2^n * 2^n matrix @@ -304,13 +329,6 @@ object Tensor { Tensor(n, n, (a: Int, b: Int) => if (a == b) new Complex(1) else new Complex(0)) } - def apply(c: Array[Array[Complex]]) = new Tensor(c) - - def apply(cInt: Array[Array[Int]]): Tensor = { - // Convert integer matrix to complex matrix - Tensor(cInt.length, cInt(0).length, (i, j) => Complex.doubleToComplex(cInt(i)(j))) - } - def apply(height: Int, width: Int, generator: Tensor.Generator): Tensor = { // create tensor based on size and generating function new Tensor(generatorToMatrix(height, width, generator)) @@ -335,47 +353,72 @@ object Tensor { (for (_ <- 0 until width) yield Complex.zero).toArray).toArray } + def apply(c: Array[Array[Complex]]) = new Tensor(c) + + def apply(cInt: Array[Array[Int]]): Tensor = { + // Convert integer matrix to complex matrix + Tensor(cInt.length, cInt(0).length, (i, j) => Complex.doubleToComplex(cInt(i)(j))) + } + def permutation(asList: List[Int]): Tensor = { // Produce the matrix that sends i -> asList(i) val gen = (x: Int) => asList(x) new Tensor(permutationMatrix(asList.length, gen)) } - def permutationMatrix(size: Int, gen: Int => Int): Matrix = { - val base = emptyMatrix(size, size) - for (i <- 0 until size) { - base(gen(i))(i) = Complex.one - } - base - } - def permutation(asArray: Array[Int]): Tensor = { // Produce the matrix that sends i -> asArray(i) val gen = (x: Int) => asArray(x) new Tensor(permutationMatrix(asArray.length, gen)) } + def permutationMatrix(size: Int, gen: Int => Int): Matrix = { + val genAsList: List[Int] = (0 until size).map(gen(_)).toList + if (permutationCache.contains(genAsList)) { + permutationCache(genAsList) + } else { + val base = emptyMatrix(size, size) + for (i <- 0 until size) { + base(genAsList(i))(i) = Complex.one + } + permutationCache += genAsList -> base + base + } + } + def swap(asList: List[Int]): Tensor = { // Produce the matrix that sends WIRE i to WIRE asList(i) + // READ BOTTOM TO TOP val gen = (x: Int) => asList(x) swap(asList.length, gen) } def swap(size: Int, gen: Int => Int): Tensor = { // Produce the matrix that sends WIRE i to WIRE gen(i) - def padLeft(s: String, n: Int): String = if (s.length < n) padLeft("0" + s, n) else s - - def permGen(i: Int): Int = { - val binaryStringIn = padLeft((i: RichInt).toBinaryString, size) - val permedString = (for (j <- 0 until size) yield binaryStringIn(gen(j))).mkString("") - permedString match { - case "" => 0 - case s => Integer.parseInt(s, 2) + // READ BOTTOM TO TOP + val genAsList: List[Int] = (0 until size).map(gen).toList + + if (swapCache.contains(genAsList)) { + swapCache(genAsList) + } else { + + def padLeft(s: String, n: Int): String = if (s.length < n) padLeft("0" + s, n) else s + + def permGen(i: Int): Int = { + val binaryStringIn = padLeft(i.toBinaryString, size) + val permedString = (for (j <- 0 until size) yield binaryStringIn(gen(j))).mkString("") + permedString match { + case "" => 0 + case s => Integer.parseInt(s, 2) + } + } + val answer = permutation(math.pow(2, size).toInt, permGen).transpose + swapCache += genAsList -> answer + answer } - permutation(math.pow(2, size).toInt, permGen) } def permutation(size: Int, gen: Int => Int): Tensor = { diff --git a/scala/src/main/scala/quanto/cosy/qutrits.scala b/scala/src/main/scala/quanto/cosy/qutrits.scala deleted file mode 100644 index a3ae36b7..00000000 --- a/scala/src/main/scala/quanto/cosy/qutrits.scala +++ /dev/null @@ -1,17 +0,0 @@ -package quanto.cosy - -object qutrits { - - def ei ( arg : Double) = Complex(math.cos(arg),math.sin(arg)) - def conj (T : Tensor) = Tensor(T.height,T.width,(i, j) => T(j,i).conjugate) - val w : Complex = ei(2*math.Pi/3) - val H_unnormalised = Tensor(Array(Array[Complex](1,1,1),Array[Complex](1,w,w*w),Array[Complex](1,w*w,w))) - val H = H_unnormalised.scaled(1.0 / math.sqrt(3)) - val e0 = Tensor(Array(Array(1,0,0))).transpose - val e1 = Tensor(Array(Array(0,1,0))).transpose - val e2 = Tensor(Array(Array(0,0,1))).transpose - val f0 = Tensor(Array(Array(1,1,1))).transpose.scaled(1.0 / math.sqrt(3)) - val f1 = Tensor(Array(Array(Complex.one,w,w*w))).transpose.scaled(1.0 / math.sqrt(3)) - val f2 = Tensor(Array(Array(Complex.one,w*w,w))).transpose.scaled(1.0 / math.sqrt(3)) -} - diff --git a/scala/src/main/scala/quanto/data/AngleExpression.scala b/scala/src/main/scala/quanto/data/AngleExpression.scala deleted file mode 100644 index e252485a..00000000 --- a/scala/src/main/scala/quanto/data/AngleExpression.scala +++ /dev/null @@ -1,139 +0,0 @@ -package quanto.data - -// ported from linrat_angle_expr.ML - -import quanto.util.Rational - -import scala.util.parsing.combinator._ - -class AngleParseException(message: String) - extends Exception(message) - -class AngleEvaluationException(message: String) - extends Exception(message) - -class AngleExpression(val const : Rational, val coeffs : Map[String,Rational]) { - lazy val vars = coeffs.keySet - - def *(r : Rational): AngleExpression = - AngleExpression(const * r, coeffs.mapValues(x => x * r)) - - def *(i : Int): AngleExpression = this * Rational(i) - - def +(e : AngleExpression) = AngleExpression(const + e.const, e.coeffs.foldLeft(coeffs) { - case (m, (k,v)) => m + (k -> (v + m.getOrElse(k, Rational(0)))) - }) - - def -(e: AngleExpression) : AngleExpression = this + (e * -1) - - def subst(v : String, e : AngleExpression) : AngleExpression = { - val c = coeffs.getOrElse(v,Rational(0)) - this - AngleExpression(Rational(0), Map(v -> c)) + (e * c) - } - - def subst(mp : Map[String, AngleExpression]): AngleExpression = - mp.foldLeft(this) { case (e, (v,e1)) => e.subst(v,e1) } - - def evaluate(mp: Map[String, Double]) : Double = { - try { - const + coeffs.foldLeft(0.0) { (a, b) => a + (mp(b._1) * Rational.rationalToDouble(b._2)) } - } catch { - case e : Exception => new AngleEvaluationException("Given arguments do not match those in the coefficient list") - 0 - } - } - - override def equals(that: Any): Boolean = that match { - case e : AngleExpression => - const == e.const && coeffs == e.coeffs - case _ => false - } - - override def toString: String = { - var fst = true - var s = "" - if (!const.isZero) { - fst = false - if (const.isOne) s += "\\pi" - else s += const.toString + " \\pi" - } - - coeffs.foreach { case (x,c) => - if (fst) { - fst = false - s = s + (if (c == Rational(1)) "" else c.toString + " ") + x - } else { - if (c < Rational(0)) { - s = s + " - " + (if (c == Rational(-1)) "" else (c * -1).toString + " ") + x - } else { - s = s + " + " + (if (c == Rational(1)) "" else c.toString + " ") + x - } - } - } - - s - } - - -} - -object AngleExpression { - def apply(const : Rational = Rational(0), - coeffs : Map[String,Rational] = Map()) = - new AngleExpression(const mod 2, coeffs.filter { case (_,c) => !c.isZero }) - - val ZERO = AngleExpression(Rational(0)) - val ONE_PI = AngleExpression(Rational(1)) - - def parse(s : String) = AngleExpressionParser.p(s) - - private object AngleExpressionParser extends RegexParsers { - override def skipWhitespace = true - def INT: Parser[Int] = """[0-9]+""".r ^^ { _.toInt } - def INT_OPT : Parser[Int] = INT.? ^^ { _.getOrElse(1) } - def IDENT : Parser[String] = """[\\a-zA-Z_][a-zA-Z0-9_]*""".r ^^ { _.toString } - def PI : Parser[Unit] = """\\?[pP][iI]""".r ^^ { _ => Unit } - - - def coeff : Parser[Rational] = - INT ~ "/" ~ INT ^^ { case n ~ _ ~ d => Rational(n,d) } | - "(" ~ coeff ~ ")" ^^ { case _ ~ c ~ _ => c } | - INT ^^ { n => Rational(n) } - - - def frac : Parser[AngleExpression] = - INT_OPT ~ "*".? ~ PI ~ "/" ~ INT ^^ { case n ~ _ ~ _ ~ _ ~ d => AngleExpression(Rational(n,d)) } | - INT_OPT ~ "*".? ~ IDENT ~ "/" ~ INT ^^ { - case n ~ _ ~ x ~ _ ~ d => AngleExpression(Rational(0), Map(x -> Rational(n,d))) - } - - def term : Parser[AngleExpression] = - frac | - "-" ~ term ^^ { case _ ~ t => t * -1 } | - coeff ~ "*".? ~ PI ^^ { case c ~ _ ~ _ => AngleExpression(c) } | - PI ^^ { _ => ONE_PI } | - coeff ~ "*".? ~ IDENT ^^ { case c ~ _ ~ x => AngleExpression(Rational(0), Map(x -> c)) } | - IDENT ^^ { case x => AngleExpression(Rational(0), Map(x -> Rational(1))) } | - coeff ^^ { AngleExpression(_) } | - "(" ~ expr ~ ")" ^^ { case _ ~ t ~ _ => t } - - def term1 : Parser[AngleExpression] = - "+" ~ term ^^ { case _ ~ t => t } | - "-" ~ term ^^ { case _ ~ t => t * -1 } - - def terms : Parser[AngleExpression] = - term1 ~ terms ^^ { case s ~ t => s + t } | - term1 - - def expr : Parser[AngleExpression] = - term ~ terms ^^ { case s ~ t => s + t } | - term | - "" ^^ { _ => ZERO } - - def p(s : String) = parseAll(expr, s) match { - case Success(e, _) => e - case Failure(msg, _) => throw new AngleParseException(msg) - case Error(msg, _) => throw new AngleParseException(msg) - } - } -} diff --git a/scala/src/main/scala/quanto/data/BBData.scala b/scala/src/main/scala/quanto/data/BBData.scala index 95549032..b41d1552 100644 --- a/scala/src/main/scala/quanto/data/BBData.scala +++ b/scala/src/main/scala/quanto/data/BBData.scala @@ -3,14 +3,14 @@ package quanto.data import quanto.util.json.JsonObject /** - * A class which represents the data for bang boxes - * - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BBData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - * @author Aleks Kissinger - */ + * A class which represents the data for bang boxes + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BBData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + * @author Aleks Kissinger + */ case class BBData( - data: JsonObject = JsonObject(), - annotation: JsonObject = JsonObject(), - theory: Theory = Theory.DefaultTheory -) extends GraphElementData + data: JsonObject = JsonObject(), + annotation: JsonObject = JsonObject(), + theory: Theory = Theory.DefaultTheory + ) extends GraphElementData diff --git a/scala/src/main/scala/quanto/data/BinRel.scala b/scala/src/main/scala/quanto/data/BinRel.scala index 521c7834..4e7f29b5 100644 --- a/scala/src/main/scala/quanto/data/BinRel.scala +++ b/scala/src/main/scala/quanto/data/BinRel.scala @@ -1,32 +1,31 @@ /** A package which contains the data objects */ package quanto.data -import collection.immutable.TreeSet -import scala.collection.{mutable, IterableLike} +import scala.collection.immutable.TreeSet +import scala.collection.{IterableLike, mutable} /** - * A trait which contains useful methods for binary relations - * - * @tparam A type of the domain - * @tparam B type of the codomain - * - * @author Aleks Kissinger - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BinRel.scala Source code]] - */ -trait BinRel[A,B] extends IterableLike[(A,B), BinRel[A,B]] { + * A trait which contains useful methods for binary relations + * + * @tparam A type of the domain + * @tparam B type of the codomain + * @author Aleks Kissinger + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BinRel.scala Source code]] + */ +trait BinRel[A, B] extends IterableLike[(A, B), BinRel[A, B]] { /** Domain function - assigns a set to each element of the domain */ - def domf : Map[A,Set[B]] + def domf: Map[A, Set[B]] /** Codomain function - assigns a set to each element of the codomain */ - def codf : Map[B,Set[A]] + def codf: Map[B, Set[A]] - /** - * Add an element to relation - * - * @param kv element to be added - * @return New relation containing '''kv''' in addition to the current one - */ - def +(kv: (A,B)): BinRel[A,B] + /** + * Add an element to relation + * + * @param kv element to be added + * @return New relation containing '''kv''' in addition to the current one + */ + def +(kv: (A, B)): BinRel[A, B] /** * Check if an element is in a relation @@ -34,166 +33,173 @@ trait BinRel[A,B] extends IterableLike[(A,B), BinRel[A,B]] { * @param kv element to be checked * @return boolean */ - def contains(kv: (A,B)): Boolean + def contains(kv: (A, B)): Boolean /** - * Remove an element from relation - * - * @param kv element to be removed - * @return New relation containing the same elements as the - * current one except '''kv''' - */ - def unmap(kv: (A,B)): BinRel[A,B] - def -(kv: (A,B)) : BinRel[A,B] = unmap(kv) + * Remove an element from relation + * + * @param kv element to be removed + * @return New relation containing the same elements as the + * current one except '''kv''' + */ + def unmap(kv: (A, B)): BinRel[A, B] + + def -(kv: (A, B)): BinRel[A, B] = unmap(kv) /** - * Remove all relation pairs '''(d,_)''' - * - * @param d Element to be removed from domain - * @return New relation with relation pairs '''(d,_)''' removed - */ - def unmapDom(d: A): BinRel[A,B] + * Remove all relation pairs '''(d,_)''' + * + * @param d Element to be removed from domain + * @return New relation with relation pairs '''(d,_)''' removed + */ + def unmapDom(d: A): BinRel[A, B] /** - * Remove all relation pairs '''(_,c)''' - * - * @param c Element to be removed from codomain - * @return New relation with relation pairs '''(_,c)''' removed - */ - def unmapCod(c: B): BinRel[A,B] + * Remove all relation pairs '''(_,c)''' + * + * @param c Element to be removed from codomain + * @return New relation with relation pairs '''(_,c)''' removed + */ + def unmapCod(c: B): BinRel[A, B] /** The domain set of the relation */ - def dom = domf.keys - def domSet = domf.keySet + def dom: Iterable[A] = domf.keys + + def domSet: Set[A] = domf.keySet /** The codomain set of the relation */ - def cod = codf.keys - def codSet = codf.keySet + def cod: Iterable[B] = codf.keys + + def codSet: Set[B] = codf.keySet /** - * The codomain image of a set of domain elements under this relation - * - * @param set A set containing elements of type '''A''' - * @return The set of codomain elements which are in relation with some - * element from '''set''' - */ - def directImage(set: Set[A]) = set.foldLeft(Set[B]()){ (s,x) => s union domf.getOrElse(x, Set()) } + * The codomain image of a set of domain elements under this relation + * + * @param set A set containing elements of type '''A''' + * @return The set of codomain elements which are in relation with some + * element from '''set''' + */ + def directImage(set: Set[A]): Set[B] = set.foldLeft(Set[B]()) { (s, x) => s union domf.getOrElse(x, Set()) } /** - * The domain image of a set of codomain elements under this relation - * - * @param set A set containing elements of type '''B''' - * @return The set of domain elements which are in relation with some - * element from '''set''' - */ - def inverseImage(set: Set[B]) = set.foldLeft(Set[A]()) { (s,y) => s union codf.getOrElse(y, Set()) } - - + * The domain image of a set of codomain elements under this relation + * + * @param set A set containing elements of type '''B''' + * @return The set of domain elements which are in relation with some + * element from '''set''' + */ + def inverseImage(set: Set[B]): Set[A] = set.foldLeft(Set[A]()) { (s, y) => s union codf.getOrElse(y, Set()) } + + /** - * BinRel inherits equality from '''domf''' - * - * @return The hashcode of '''domf''' - */ - override def hashCode = domf.hashCode() + * BinRel inherits equality from '''domf''' + * + * @return The hashcode of '''domf''' + */ + override def hashCode: Int = domf.hashCode() - /** True iff '''other''' is of type '''BinRel[_,_]''' */ - override def canEqual(other: Any) = other match { - case _: BinRel[_,_] => true + /** BinRel inherits equality from '''domf''' */ + override def equals(other: Any): Boolean = other match { + case that: BinRel[_, _] => (that canEqual this) && (this.domf == that.domf) case _ => false } - /** BinRel inherits equality from '''domf''' */ - override def equals(other: Any) = other match { - case that: BinRel[_,_] => (that canEqual this) && (this.domf == that.domf) + /** True iff '''other''' is of type '''BinRel[_,_]''' */ + override def canEqual(other: Any): Boolean = other match { + case _: BinRel[_, _] => true case _ => false } - override def toString() = { - getClass.getSimpleName + "(" + seq.map{ case (k,v) => k.toString + " -> " + v.toString }.mkString(", ") + ")" + override def toString: String = { + getClass.getSimpleName + "(" + seq.map { case (k, v) => k.toString + " -> " + v.toString }.mkString(", ") + ")" } // TODO: get ++ implemented correctly (i.e. using builders etc) for PFun/BinRel - def ++(r:BinRel[A,B]) = r.foldLeft(this) { case (mp, kv) => mp + kv } + def ++(r: BinRel[A, B]): BinRel[A, B] = r.foldLeft(this) { case (mp, kv) => mp + kv } } /** - * A class which represents a binary relation as a pair of two functions - - * the domain map and the codomain map - * - * @tparam A The type of the elements in the domain of the relation - * @tparam B The type of the elements in the codomain of the relation - * - * @constructor Create an instance of the class from two functions mapping - * elements to sets of elements - * - * @param domMap The domain map - maps an element of type '''A''' to the - * set of elements of type '''B''' which are in relation with it - * - * @param codMap The codomain map - maps an element of type '''B''' to the - * set of elements of type '''A''' which are in relation with it - * - * @author Aleks Kissinger - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BinRel.scala Source code]] - */ -class MapPairBinRel[A,B](domMap: Map[A,TreeSet[B]], codMap: Map[B,TreeSet[A]]) - (implicit domOrd: Ordering[A], codOrd: Ordering[B]) - extends BinRel[A,B] { - - def domf = domMap.withDefaultValue(TreeSet()(codOrd)) - def codf = codMap.withDefaultValue(TreeSet()(domOrd)) - - def +(kv: (A,B)) = { - new MapPairBinRel[A,B]( + * A class which represents a binary relation as a pair of two functions - + * the domain map and the codomain map + * + * @tparam A The type of the elements in the domain of the relation + * @tparam B The type of the elements in the codomain of the relation + * @constructor Create an instance of the class from two functions mapping + * elements to sets of elements + * @param domMap The domain map - maps an element of type '''A''' to the + * set of elements of type '''B''' which are in relation with it + * @param codMap The codomain map - maps an element of type '''B''' to the + * set of elements of type '''A''' which are in relation with it + * @author Aleks Kissinger + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BinRel.scala Source code]] + */ +class MapPairBinRel[A, B](domMap: Map[A, TreeSet[B]], codMap: Map[B, TreeSet[A]]) + (implicit domOrd: Ordering[A], codOrd: Ordering[B]) + extends BinRel[A, B] { + + def +(kv: (A, B)): MapPairBinRel[A, B] = { + new MapPairBinRel[A, B]( domMap + (kv._1 -> (domMap.getOrElse(kv._1, TreeSet()(codOrd)) + kv._2)), codMap + (kv._2 -> (codMap.getOrElse(kv._2, TreeSet()(domOrd)) + kv._1)) ) } - def contains(kv: (A,B)) = domf.get(kv._1).contains(kv._2) + def contains(kv: (A, B)) : Boolean = domf.get(kv._1).contains(kv._2) - def unmap(kv: (A,B)) = { - new MapPairBinRel[A,B]( + def unmapDom(d: A): MapPairBinRel[A, B] = domf(d).foldLeft(this) { (rel, c) => rel unmap(d, c) } + + def domf : Map[A, TreeSet[B]] = domMap.withDefaultValue(TreeSet()(codOrd)) + + def unmap(kv: (A, B)): MapPairBinRel[A, B] = { + new MapPairBinRel[A, B]( domMap.get(kv._1) match { - case Some(xs) if (xs.size == 1 && xs.contains(kv._2)) => domMap - kv._1 + case Some(xs) if xs.size == 1 && xs.contains(kv._2) => domMap - kv._1 case Some(xs) => domMap + (kv._1 -> (xs - kv._2)) case None => domMap }, codMap.get(kv._2) match { - case Some(xs) if (xs.size == 1 && xs.contains(kv._1)) => codMap - kv._2 + case Some(xs) if xs.size == 1 && xs.contains(kv._1) => codMap - kv._2 case Some(xs) => codMap + (kv._2 -> (xs - kv._1)) case None => codMap } ) } - def unmapDom(d: A) = domf(d).foldLeft(this) { (rel,c) => rel unmap (d, c) } - def unmapCod(c: B) = codf(c).foldLeft(this) { (rel,d) => rel unmap (d, c) } + def unmapCod(c: B): MapPairBinRel[A, B] = codf(c).foldLeft(this) { (rel, d) => rel unmap(d, c) } + + def codf: Map[B, TreeSet[A]] = codMap.withDefaultValue(TreeSet()(domOrd)) + + def seq: Seq[(A, B)] = iterator.toSeq /** Returns an iterator over pairs '''(a,b)''' of the relation */ - def iterator = domMap.foldLeft(Iterator[(A,B)]()) { case (iter, (domElement, codSet)) => + def iterator: Iterator[(A, B)] = domMap.foldLeft(Iterator[(A, B)]()) { case (iter, (domElement, codSet)) => iter ++ (Iterator.continually(domElement) zip codSet.iterator) } - protected[this] def newBuilder = new mutable.Builder[(A,B),BinRel[A,B]] { - val s = collection.mutable.Buffer[(A,B)]() + protected[this] def newBuilder : mutable.Builder[(A, B), BinRel[A, B]] = new mutable.Builder[(A, B), BinRel[A, B]] { + val s: mutable.Buffer[(A, B)] = collection.mutable.Buffer[(A, B)]() + def result() = BinRel(s: _*) - def clear() = s.clear() - def +=(elem: (A,B)) = { s += elem; this } - } - def seq = iterator.toSeq + def clear(): Unit = s.clear() + + def +=(elem: (A, B)): this.type = { + s += elem + this + } + } } /** - * Companion object for the BinRel trait - * - * @author Aleks Kissinger - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BinRel.scala Source code]] - */ + * Companion object for the BinRel trait + * + * @author Aleks Kissinger + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/BinRel.scala Source code]] + */ object BinRel { /** Construct a binary relation from a sequence of pairs '''(a,b)''' */ - def apply[A,B](kvs: (A,B)*)(implicit domOrd: Ordering[A], codOrd: Ordering[B]) : BinRel[A,B] = { - kvs.foldLeft(new MapPairBinRel[A,B](Map(),Map())(domOrd,codOrd)){ (rel, kv) => rel + kv } + def apply[A, B](kvs: (A, B)*)(implicit domOrd: Ordering[A], codOrd: Ordering[B]): BinRel[A, B] = { + kvs.foldLeft(new MapPairBinRel[A, B](Map(), Map())(domOrd, codOrd)) { (rel, kv) => rel + kv } } } diff --git a/scala/src/main/scala/quanto/data/CompositeExpression.scala b/scala/src/main/scala/quanto/data/CompositeExpression.scala new file mode 100644 index 00000000..5c87ba58 --- /dev/null +++ b/scala/src/main/scala/quanto/data/CompositeExpression.scala @@ -0,0 +1,207 @@ +package quanto.data + +import quanto.data.Theory.ValueType +import quanto.util.Rational + +import scala.util.parsing.combinator.RegexParsers + + +case class MismatchedPhaseData(data1: Vector[ValueType], data2: Vector[ValueType]) + extends Exception(data1.toString + " != " + data2.toString) + +case class GenericParseException(message: String) extends Exception(message) + +case class TypeNotFoundException(message: String) extends Exception(message) + + +// The types are stored as a vector of enums +// The data is stored as a vector of PhaseExpressions, which are then cast according to the type data +case class CompositeExpression(valueTypes: Vector[ValueType], values: Vector[PhaseExpression]) { + + lazy val varsWithType: Set[(ValueType, String)] = + values.zipWithIndex.flatMap(valueWithIndex => + valueWithIndex._1.vars.map(variableName => + (valueTypes(valueWithIndex._2), variableName)) + ).toSet + val vars: Set[String] = values.flatMap(_.vars).toSet + val description: ValueType = ValueType.Empty + + // Addition + def +(e: CompositeExpression): CompositeExpression = { + + if (valueTypes != e.valueTypes) { + throw MismatchedPhaseData(valueTypes, e.valueTypes) + } + + + val summedValues: Vector[PhaseExpression] = + valueTypes.zipWithIndex.map(vi => values(vi._2) + e.values(vi._2)) + CompositeExpression(valueTypes, summedValues) + } + + // Subtraction + def -(e: CompositeExpression): CompositeExpression = { + + + val negatedValues: Vector[PhaseExpression] = + valueTypes.zipWithIndex.map(vi => values(vi._2) - e.values(vi._2)) + CompositeExpression(valueTypes, negatedValues) + } + + // Combine strings of each subvalue + override def toString: String = { + + if (values.forall(e => e == PhaseExpression.zero(e.description))) { + "" + } else { + val stringValues = values.zipWithIndex.map(pi => (pi._1.description, pi._2) match { + case (ValueType.String, _) => + pi._1.toString // Always render strings directly + case (_, 0) => // Always render the first entry directly + pi._1.toString + case (_, _) => // Put a space before anything else to aid legibiliy + " " + pi._1.toString + }) + stringValues.mkString(",") + } + } + + def *(r: Rational): CompositeExpression = { + + val scaledValues: Vector[PhaseExpression] = values.map(v => v * r) + CompositeExpression(valueTypes, scaledValues) + } + + def firstOrError[T <: PhaseExpression](valueType: ValueType): T = { + val typeIndex = valueTypes.zipWithIndex.find(x => x._1 == valueType) + if (typeIndex.nonEmpty) { + values(typeIndex.get._2).asInstanceOf[T] + } else { + throw TypeNotFoundException(valueType.toString + " was not present in " + valueTypes.mkString(",")) + } + } + + def first[T <: PhaseExpression](valueType: ValueType): Option[T] = { + try { + Some(this.firstOrError(valueType)) + } + catch { + case _: Throwable => None + } + } + + def substSubValues(mp: Map[String, PhaseExpression]): CompositeExpression = + mp.foldLeft(this) { case (e, (v, e1)) => e.substSubValue(v, e1) } + + def substSubVariables(mp: Map[(ValueType, String), String]): CompositeExpression = { + substSubValues(mp.map(vss => vss._1._2 -> PhaseExpression.parse(vss._2, vss._1._1) )) + } + + def substSubValue(variableName: String, phase: PhaseExpression): CompositeExpression = { + // Apply substitution to subvalues with the correct valueType + + val newValues: Vector[PhaseExpression] = + valueTypes.zipWithIndex.map(vi => { + val current = values(vi._2) + if (phase.description == current.description) { + current.subst(variableName, phase) + } + else { + current + } + }) + CompositeExpression(valueTypes, newValues) + } + +} + +object CompositeExpression { + + implicit val modulus: Option[Int] = None + + def empty: CompositeExpression = { + CompositeExpression(Vector(), Vector()) + } + + def zero(valueTypes: Vector[ValueType]): CompositeExpression = { + val zeroValues: Vector[PhaseExpression] = valueTypes.map(t => PhaseExpression.zero(t)) + CompositeExpression(valueTypes, zeroValues) + } + + def one(valueTypes: Vector[ValueType]): CompositeExpression = { + val oneValues: Vector[PhaseExpression] = valueTypes.map(t => PhaseExpression.one(t)) + CompositeExpression(valueTypes, oneValues) + } + + def parse(types: String, values: String): CompositeExpression = { + val typeVector = parseTypes(types) + CompositeExpression(typeVector, parseKnowingTypes(values, typeVector)) + } + + def parseKnowingTypes(s: String, v: Vector[ValueType]): Vector[PhaseExpression] = { + // Will fill with empties if more types requested than string elements given + val split: Array[String] = s.split(",") + v.zipWithIndex.map(si => parseSingle(split.lift(si._2).getOrElse(""), si._1)) + } + + def parseSingle(s: String, v: ValueType): PhaseExpression = PhaseExpression.parse(s, v) + + def parseTypes(s: String): Vector[ValueType] = TypeExpressionParser.p(s).toVector + + def wrap[T <: PhaseExpression](expression: T): CompositeExpression = { + CompositeExpression(Vector(expression.description), Vector(expression)) + } + + private object TypeExpressionParser extends RegexParsers { + override def skipWhitespace = true + + // Partial matches will confuse things! + // Make sure a supermatch comes before a submatch + def ANGLE: Parser[ValueType] = + """(angle_expr|(LinRat|)[Aa]ngle)""".r ^^ { _ => ValueType.AngleExpr } + + def BOOL: Parser[ValueType] = """[bB]ool(ean|)""".r ^^ { _ => ValueType.Boolean } + + def RATIONAL: Parser[ValueType] = """[Rr]ational""".r ^^ { _ => ValueType.Rational } + + def INTEGER: Parser[ValueType] = """[Ii]nt(eger|)""".r ^^ { _ => ValueType.Integer } + + def STRING: Parser[ValueType] = """[Ss]tring""".r ^^ { _ => ValueType.String } + + def LONG: Parser[ValueType] = """long(_string|)""".r ^^ { _ => ValueType.Long } + + def ENUM: Parser[ValueType] = """enum""".r ^^ { _ => ValueType.Enum } + + def EMPTY: Parser[ValueType] = """[Ee]mpty""".r ^^ { _ => ValueType.Empty } + + + def term: Parser[ValueType] = + ANGLE | + BOOL | + RATIONAL | + INTEGER | + STRING | + LONG | + ENUM | + EMPTY + + + def terms: Parser[List[ValueType]] = + "(" ~ terms ~ ")" ^^ { case _ ~ t ~ _ => t } | + term ~ "," ~ terms ^^ { case s ~ _ ~ t => s :: t } | + term ^^ { t => List(t) } + + def expr: Parser[List[ValueType]] = + terms | + term ^^ { t => List(t) } | + "" ^^ { _ => List(ValueType.Empty) } + + def p(s: String): List[ValueType] = parseAll(expr, s) match { + case Success(e, _) => e + case Failure(msg, _) => throw GenericParseException(msg) + case Error(msg, _) => throw GenericParseException(msg) + } + } + +} + diff --git a/scala/src/main/scala/quanto/data/Derivation.scala b/scala/src/main/scala/quanto/data/Derivation.scala index 055e666b..eb515168 100644 --- a/scala/src/main/scala/quanto/data/Derivation.scala +++ b/scala/src/main/scala/quanto/data/Derivation.scala @@ -1,11 +1,11 @@ package quanto.data +import quanto.gui.{DeriveState, HeadState, StepState} +import quanto.layout.ForceLayout +import quanto.util.TreeSeq import quanto.util.json._ -import javax.management.remote.rmi._RMIConnectionImpl_Tie + import scala.collection.SortedSet -import quanto.gui.{StepState, HeadState, DeriveState} -import quanto.util.TreeSeq -import quanto.layout.ForceLayout trait DerivationException @@ -13,20 +13,9 @@ case class DerivationLoadException(message: String, cause: Throwable = null) extends Exception(message, cause) with DerivationException -sealed abstract class RuleVariant - -case object RuleNormal extends RuleVariant { - override def toString = "normal" -} - -case object RuleInverse extends RuleVariant { - override def toString = "inverse" -} - case class DStep(name: DSName, ruleName: String, rule: Rule, - variant: RuleVariant, graph: Graph) { def layout: DStep = { val layoutProc = new ForceLayout @@ -36,7 +25,7 @@ case class DStep(name: DSName, //layoutProc.edgeLength = 0.1 layoutProc.gravity = 0.0 - val rhsi = rule.rhs.verts.filter(!rule.rhs.isBoundary(_)) + val rhsi = rule.rhs.verts.filter(!rule.rhs.isTerminalWire(_)) graph.verts.foreach { v => if (!rhsi.contains(v)) layoutProc.lockVertex(v) @@ -52,19 +41,15 @@ case class DStep(name: DSName, def copy(name: DSName = name, ruleName: String = ruleName, rule: Rule = rule, - variant: RuleVariant = variant, graph: Graph = graph) - = DStep(name, ruleName, rule, variant, graph) + = DStep(name, ruleName, rule, graph) } object DStep { - implicit def booleanToVariant(isInverted : Boolean) : RuleVariant = { - if (isInverted) RuleNormal else RuleInverse - } - def apply(name: DSName, rule: Rule, graph: Graph) : DStep = - DStep(name, rule.name, rule, rule.description.inverse, graph) + def apply(name: DSName, rule: Rule, graph: Graph): DStep = + DStep(name, rule.name, rule, graph) def toJson(dstep: DStep, parent: Option[DSName], thy: Theory = Theory.DefaultTheory): Json = { JsonObject( @@ -72,23 +57,25 @@ object DStep { "parent" -> parent.map(_.toString), "rule_name" -> dstep.ruleName, "rule" -> Rule.toJson(dstep.rule, thy), - "rule_variant" -> (dstep.variant match { - case RuleNormal => JsonNull; - case v => v.toString + "rule_variant" -> (if (dstep.rule.description.inverse) { + "inverse" + } else { + "forwards" }), "graph" -> Graph.toJson(dstep.graph, thy) - ).noEmpty + ) } def fromJson(name: DSName, json: Json, thy: Theory = Theory.DefaultTheory): DStep = try { + val baseRule = Rule.fromJson(json / "rule", thy) + val rule: Rule = json ? "rule_variant" match { + case JsonString("inverse") => baseRule.inverse + case _ => baseRule + } DStep( name = name, ruleName = (json / "rule_name").stringValue, - rule = Rule.fromJson(json / "rule", thy), - variant = json ? "rule_variant" match { - case JsonString("inverse") => RuleInverse; - case _ => RuleNormal - }, + rule = rule, graph = Graph.fromJson(json / "graph", thy) ) } catch { @@ -104,8 +91,7 @@ object DStep { } } -case class Derivation(theory: Theory, - root: Graph, +case class Derivation(root: Graph, steps: Map[DSName, DStep] = Map(), heads: SortedSet[DSName] = SortedSet(), parentMap: PFun[DSName, DSName] = PFun()) @@ -118,7 +104,7 @@ case class Derivation(theory: Theory, def stepsTo(head: DSName): Array[DSName] = (parentMap.get(head) match { case Some(p) => stepsTo(p) - case None => Array() + case None => Array.empty[DSName] }) :+ head def graphsTo(head: DSName): Array[Graph] = root +: stepsTo(head).map(s => steps(s).graph) @@ -129,12 +115,6 @@ case class Derivation(theory: Theory, copy(steps = steps + (s -> s1)) } - def copy(theory: Theory = theory, - root: Graph = root, - steps: Map[DSName, DStep] = steps, - heads: SortedSet[DSName] = heads, - parent: PFun[DSName, DSName] = parentMap) = Derivation(theory, root, steps, heads, parent) - def allChildren(s: DSName): Set[DSName] = children(s).foldLeft(Set[DSName]()) { case (set, c) => set union allChildren(c) } + s @@ -146,6 +126,12 @@ case class Derivation(theory: Theory, def addHead(h: DSName): Derivation = copy(heads = heads + h) + def copy( + root: Graph = root, + steps: Map[DSName, DStep] = steps, + heads: SortedSet[DSName] = heads, + parent: PFun[DSName, DSName] = parentMap) = Derivation(root, steps, heads, parent) + def deleteHead(h: DSName): Derivation = copy(heads = heads - h) def addStep(parentOpt: Option[DSName], step: DStep): Derivation = parentOpt match { @@ -169,7 +155,7 @@ case class Derivation(theory: Theory, copy( steps = steps1, heads = parentOpt match { - case Some(s) => heads1 + s; + case Some(`s`) => heads1 + s; case _ => heads1 }, parent = parent1 @@ -222,7 +208,7 @@ case class Derivation(theory: Theory, object Derivation { type DerivationWithHead = (Derivation, Option[DSName]) - def fromJson(json: Json, thy: Theory = Theory.DefaultTheory) = try { + def fromJson(json: Json, thy: Theory = Theory.DefaultTheory): Derivation = try { val parent = (json ? "steps").asObject.foldLeft(PFun[DSName, DSName]()) { case (pf, (step, obj)) => obj.get("parent") match { case Some(JsonString(p)) => pf + (DSName(step) -> DSName(p)) @@ -237,23 +223,22 @@ object Derivation { val heads = (json ? "heads").asArray.foldLeft(SortedSet[DSName]()) { case (set, h) => set + DSName(h.stringValue) } Derivation( - theory = thy, root = Graph.fromJson(json / "root", thy), steps = steps, heads = heads, parentMap = parent ) } catch { - case e: JsonAccessException => throw new DerivationLoadException(e.getMessage, e) + case e: JsonAccessException => throw DerivationLoadException(e.getMessage, e) case e: GraphLoadException => - throw new DerivationLoadException("Graph 'root': " + e.getMessage, e) + throw DerivationLoadException("Graph 'root': " + e.getMessage, e) case e: DerivationLoadException => throw e case e: Exception => e.printStackTrace() - throw new DerivationLoadException("Error reading JSON", e) + throw DerivationLoadException("Error reading JSON", e) } - def toJson(derive: Derivation, thy: Theory = Theory.DefaultTheory) = { + def toJson(derive: Derivation, thy: Theory = Theory.DefaultTheory): JsonObject = { val steps = derive.steps.map { case (k, v) => (k.toString, DStep.toJson(v, derive.parentMap.get(k), thy)) } JsonObject( "root" -> Graph.toJson(derive.root, thy), diff --git a/scala/src/main/scala/quanto/data/EData.scala b/scala/src/main/scala/quanto/data/EData.scala index 7644e342..6ff8e3f6 100644 --- a/scala/src/main/scala/quanto/data/EData.scala +++ b/scala/src/main/scala/quanto/data/EData.scala @@ -3,73 +3,85 @@ package quanto.data import quanto.util.json._ /** - * An abstract class providing an interface for accessing edge data - * - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - * @author Aleks Kissinger - */ + * An abstract class providing an interface for accessing edge data + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + * @author Aleks Kissinger + */ abstract class EData extends GraphElementData { def isDirected: Boolean + def typeInfo = theory.edgeTypes(typ) + /** type of the edge */ - def typ = (data / "type").stringValue + def typ: String = (data / "type").stringValue - def typeInfo = theory.edgeTypes(typ) - def label = data.getOrElse("label","").stringValue - def value = data ? "value" + def label: String = data.getOrElse("label", "").stringValue + + def value: Json = data ? "value" /** Create a copy of the current edge data, but with the new value */ def withValue(v: String): EData def toDirEdge = DirEdge(data, annotation, theory) + def toUndirEdge = UndirEdge(data, annotation, theory) } /** - * A class which represents directed edge data. - * - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - */ + * A class which represents directed edge data. + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + */ case class DirEdge( - data: JsonObject = Theory.DefaultTheory.defaultEdgeData, - annotation: JsonObject = JsonObject(), - theory: Theory = Theory.DefaultTheory -) extends EData { + data: JsonObject = Theory.DefaultTheory.defaultEdgeData, + annotation: JsonObject = JsonObject(), + theory: Theory = Theory.DefaultTheory + ) extends EData { def isDirected = true + def withValue(v: String): DirEdge = copy(data = data.setPath("$.value", v).setPath("$.label", v).asObject) - override def toJson = DirEdge.toJson(this, theory) + + override def toJson: JsonObject = DirEdge.toJson(this, theory) } /** - * A class which represents undirected edge data. - * - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - */ + * A class which represents undirected edge data. + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + */ case class UndirEdge( - data: JsonObject = Theory.DefaultTheory.defaultEdgeData, - annotation: JsonObject = JsonObject(), - theory: Theory = Theory.DefaultTheory -) extends EData { + data: JsonObject = Theory.DefaultTheory.defaultEdgeData, + annotation: JsonObject = JsonObject(), + theory: Theory = Theory.DefaultTheory + ) extends EData { def isDirected = false + def withValue(v: String): UndirEdge = copy(data = data.setPath("$.value", v).setPath("$.label", v).asObject) - override def toJson = UndirEdge.toJson(this, theory) + + override def toJson: JsonObject = UndirEdge.toJson(this, theory) } /** - * Companion object for the DirEdge class. Contains methods to convert to/from - * JSON - * - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - */ + * Companion object for the DirEdge class. Contains methods to convert to/from + * JSON + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + */ object DirEdge { - def toJson(d: EData, theory: Theory) = JsonObject( - "data" -> (if (d.data == theory.edgeTypes(d.typ).defaultData) JsonNull else d.data), + def toJson(d: EData, theory: Theory): JsonObject = JsonObject( + "data" -> (if (d.typ == theory.defaultEdgeType && d.data == theory.edgeTypes(d.typ).defaultData) { + JsonNull + } else { + d.data + }), "annotation" -> d.annotation).noEmpty - def fromJson(json: Json, theory: Theory) : DirEdge = { + + def fromJson(json: Json, theory: Theory): DirEdge = { val data = json.getOrElse("data", theory.defaultEdgeData).asObject val annotation = (json ? "annotation").asObject DirEdge(data, annotation, theory) @@ -77,17 +89,22 @@ object DirEdge { } /** - * Companion object for the UndirEdge class. Contains methods to convert - * to/from JSON - * - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - */ + * Companion object for the UndirEdge class. Contains methods to convert + * to/from JSON + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/EData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + */ object UndirEdge { - def toJson(d: EData, theory: Theory) = JsonObject( - "data" -> (if (d.data == theory.edgeTypes(d.typ).defaultData) JsonNull else d.data), + def toJson(d: EData, theory: Theory): JsonObject = JsonObject( + "data" -> (if (d.typ == theory.defaultEdgeType && d.data == theory.edgeTypes(d.typ).defaultData) { + JsonNull + } else { + d.data + }), "annotation" -> d.annotation).noEmpty - def fromJson(json: Json, theory: Theory) : UndirEdge = { + + def fromJson(json: Json, theory: Theory): UndirEdge = { val data = json.getOrElse("data", theory.defaultEdgeData).asObject val annotation = (json ? "annotation").asObject UndirEdge(data, annotation, theory) diff --git a/scala/src/main/scala/quanto/data/GData.scala b/scala/src/main/scala/quanto/data/GData.scala index 3f4cb887..250d989d 100644 --- a/scala/src/main/scala/quanto/data/GData.scala +++ b/scala/src/main/scala/quanto/data/GData.scala @@ -3,12 +3,13 @@ package quanto.data import quanto.util.json.JsonObject /** - * A class which represents graph data - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/GData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - */ + * A class which represents graph data + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/GData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + */ case class GData( - data: JsonObject = JsonObject(), - annotation: JsonObject = JsonObject(), - theory: Theory = Theory.DefaultTheory -) extends GraphElementData + data: JsonObject = JsonObject(), + annotation: JsonObject = JsonObject(), + theory: Theory = Theory.DefaultTheory + ) extends GraphElementData diff --git a/scala/src/main/scala/quanto/data/Graph.scala b/scala/src/main/scala/quanto/data/Graph.scala index ec812d23..91c67098 100644 --- a/scala/src/main/scala/quanto/data/Graph.scala +++ b/scala/src/main/scala/quanto/data/Graph.scala @@ -1,81 +1,113 @@ package quanto.data -import Names._ + import quanto.cosy.AdjMat +import quanto.data.Names._ +import quanto.data.Theory.ValueType +import quanto.util.json.JsonValues._ import quanto.util.json._ -import math.sqrt -import JsonValues._ -import collection.mutable.ArrayBuffer -import quanto.util._ -import java.awt.datatransfer.{DataFlavor, Transferable} -import scala.annotation.tailrec +import scala.collection.mutable.ArrayBuffer +import scala.math.sqrt +import scala.util.Random class GraphException(msg: String, cause: Throwable = null) extends Exception(msg, cause) + class PluggingException(msg: String) extends GraphException(msg) case class GraphSearchContext(exploredV: Set[VName], exploredE: Set[EName]) class GraphLoadException(message: String, cause: Throwable = null) -extends GraphException(message, cause) + extends GraphException(message, cause) sealed abstract class BBOp { def bb: BBName + def shortName: String + //override def toString: String = shortName } -case class BBExpand(bb: BBName, mp: GraphMap) extends BBOp { def shortName = "E(" + bb + ")" } -case class BBCopy(bb: BBName, mp: GraphMap) extends BBOp { def shortName = "C(" + bb + ")" } -case class BBDrop(bb: BBName) extends BBOp { def shortName = "D(" + bb + ")" } -case class BBKill(bb: BBName) extends BBOp { def shortName = "K(" + bb + ")" } +case class BBExpand(bb: BBName, mp: GraphMap) extends BBOp { + def shortName: String = "E(" + bb + ")" +} + +case class BBCopy(bb: BBName, mp: GraphMap) extends BBOp { + def shortName: String = "C(" + bb + ")" +} + +case class BBDrop(bb: BBName) extends BBOp { + def shortName: String = "D(" + bb + ")" +} + +case class BBKill(bb: BBName) extends BBOp { + def shortName: String = "K(" + bb + ")" +} case class Graph( - data: GData = GData(), - vdata: Map[VName,VData] = Map[VName,VData](), - edata: Map[EName,EData] = Map[EName,EData](), - source: PFun[EName,VName] = PFun[EName,VName](), - target: PFun[EName,VName] = PFun[EName,VName](), - bbdata: Map[BBName,BBData] = Map[BBName,BBData](), - inBBox: BinRel[VName,BBName] = BinRel[VName,BBName](), - bboxParent: PFun[BBName,BBName] = PFun[BBName,BBName]()) extends Ordered[Graph] -{ - def isInput (v: VName): Boolean = vdata(v).isWireVertex && inEdges(v).isEmpty && outEdges(v).size == 1 - def isOutput(v: VName): Boolean = vdata(v).isWireVertex && outEdges(v).isEmpty && inEdges(v).size == 1 - def isInternal(v: VName): Boolean = vdata(v).isWireVertex && outEdges(v).size == 1 && inEdges(v).size == 1 - def isBoundary(vn: VName): Boolean = - vdata(vn).isWireVertex && (inEdges(vn).size + outEdges(vn).size) <= 1 + data: GData = GData(), + vdata: Map[VName, VData] = Map[VName, VData](), + edata: Map[EName, EData] = Map[EName, EData](), + source: PFun[EName, VName] = PFun[EName, VName](), + target: PFun[EName, VName] = PFun[EName, VName](), + bbdata: Map[BBName, BBData] = Map[BBName, BBData](), + inBBox: BinRel[VName, BBName] = BinRel[VName, BBName](), + bboxParent: PFun[BBName, BBName] = PFun[BBName, BBName]()) extends Ordered[Graph] { + lazy val nodesThatAreNotWires: Set[VName] = verts.filterNot(vdata(_).isWireVertex) + protected val factory = new Graph(_, _, _, _, _, _, _, _) + + def isInternalWire(v: VName): Boolean = vdata(v).isWireVertex && outEdges(v).size == 1 && inEdges(v).size == 1 + def isCircle(vn: VName): Boolean = vdata(vn).isWireVertex && inEdges(vn).size == 1 && inEdges(vn) == outEdges(vn) - def typeOf(v: VName): String = vdata(v).typ + // convenience methods + def inEdges(vn: VName): Set[EName] = target.codf(vn) + def isAdjacentToBoundary(v: VName): Boolean = adjacentVerts(v).exists(isBoundary) + + def isBoundary(vn: VName): Boolean = isTerminalWire(vn) + def isAdjacentToType(v: VName, t: String): Boolean = adjacentVerts(v).exists(typeOf(_) == t) - def isWireVertex(v: VName) = vdata(v).isWireVertex - def representsWire(vn: VName) = vdata(vn).isWireVertex && + def typeOf(v: VName): String = vdata(v).typ + + /** Returns a set of vertex names adjacent to vn */ + def adjacentVerts(vn: VName): Set[VName] = predVerts(vn) union succVerts(vn) + + def predVerts(vn: VName): Set[VName] = inEdges(vn).map(source(_)) + + def succVerts(vn: VName): Set[VName] = outEdges(vn).map(target(_)) + + def isWireVertex(v: VName): Boolean = vdata(v).isWireVertex + + def numAdjacentBoundaries(vn: VName): Int = adjacentVerts(vn).count(isTerminalWire) + + def representsWire(vn: VName): Boolean = vdata(vn).isWireVertex && (predVerts(vn).headOption match { case None => true case Some(vn1) => vn == vn1 || !vdata(vn1).isWireVertex }) - def representsBareWire(vn: VName) = - isInput(vn) && + def representsBareWire(vn: VName): Boolean = + isInputWire(vn) && (succVerts(vn).headOption match { case None => false - case Some(vn1) => isOutput(vn1) + case Some(vn1) => isOutputWire(vn1) }) - def arity(v: VName) = adjacentEdges(v).size + def isInputWire(v: VName): Boolean = vdata(v).isWireVertex && inEdges(v).isEmpty && outEdges(v).size == 1 - def verts: Set[VName] = vdata.keySet - def edges: Set[EName] = edata.keySet - def bboxes: Set[BBName] = bbdata.keySet + def isOutputWire(v: VName): Boolean = vdata(v).isWireVertex && outEdges(v).isEmpty && inEdges(v).size == 1 + + def arity(v: VName): Int = adjacentEdges(v).size + + def inputs: Set[VName] = verts.filter(isInputWire) - def inputs: Set[VName] = verts.filter(isInput) - def outputs: Set[VName] = verts.filter(isOutput) - def boundary: Set[VName] = verts.filter(isBoundary) + def outputs: Set[VName] = verts.filter(isOutputWire) + + def boundary: Set[VName] = verts.filter(isTerminalWire) override def hashCode: Int = { var h = data.hashCode @@ -89,8 +121,6 @@ case class Graph( h } - def canEqual(other: Any): Boolean = other.isInstanceOf[Graph] - override def equals(other: Any): Boolean = other match { case that: Graph => (that canEqual this) && vdata == that.vdata && @@ -103,61 +133,72 @@ case class Graph( case _ => false } + def canEqual(other: Any): Boolean = other.isInstanceOf[Graph] - protected val factory = new Graph(_,_,_,_,_,_,_,_) + def vars: Set[String] = vdata.values.foldLeft(Set.empty[String]) { + case (vs, d: NodeV) => + vs ++ d.phaseData.vars + case (vs, _) => vs + } - def copy(data: GData = this.data, - vdata: Map[VName,VData] = this.vdata, - edata: Map[EName,EData] = this.edata, - source: PFun[EName,VName] = this.source, - target: PFun[EName,VName] = this.target, - bbdata: Map[BBName,BBData] = this.bbdata, - inBBox: BinRel[VName,BBName] = this.inBBox, - bboxParent: PFun[BBName,BBName] = this.bboxParent): Graph = - factory(data,vdata,edata,source,target,bbdata,inBBox,bboxParent) + // The set of nodes and wire-boundaires at ends of wires attached to vn + def adjacentNodesAndBoundaries(vn: VName): Set[VName] = adjacentEdges(vn).flatMap(e => edgeEndPoints(e)._1) - vn - // convenience methods - def inEdges(vn: VName): Set[EName] = target.codf(vn) - def outEdges(vn: VName): Set[EName] = source.codf(vn) - def predVerts(vn: VName): Set[VName] = inEdges(vn).map(source(_)) - def succVerts(vn: VName): Set[VName] = outEdges(vn).map(target(_)) - def contents(bbn: BBName): Set[VName] = inBBox.codf(bbn) - def bboxesContaining(vn: VName): Set[BBName] = inBBox.domf(vn) + /** + * Traverse the wire nodes to find the non-wire endpoints of an edge + * + * @param edge Edge to focus on + * @return endPoints, edgesBetweenThem, wireNodesBetweenThem + */ + def edgeEndPoints(edge: EName): (Set[VName], Set[EName], Set[VName]) = { + // Given an edge find its two end points, + // where wire nodes of arity 2 do not count as endpoints + // If given a loop then returns an empty set + var edgesVisited: Set[EName] = Set() + var nodesVisited: Set[VName] = Set() - def vars: Set[String] = vdata.values.foldLeft(Set.empty[String]) { - case (vs, d: NodeV) => - vs ++ d.angle.vars - case (vs,_) => vs - } + var endPointsFound: Set[VName] = Set() - /** Returns a set of vertex names adjacent to vn */ - def adjacentVerts(vn: VName): Set[VName] = predVerts(vn) union succVerts(vn) + // slightly different to arity, since arity counts self-loops as one edge not two + def inPlusOut(vertex: VName): Int = { + source.codf(vertex).size + target.codf(vertex).size + } + + def considerNode(v2: VName) { + if (!nodesVisited.contains(v2)) { + if (!vdata(v2).isWireVertex || inPlusOut(v2) != 2) { + endPointsFound += v2 + } else { + nodesVisited += v2 + adjacentEdges(v2).foreach(considerEdge) + } + } + } + + def considerEdge(e2: EName) { + if (!edgesVisited.contains(e2)) { + edgesVisited += e2 + considerNode(source(e2)) + considerNode(target(e2)) + } + } + + considerEdge(edge) + (endPointsFound, edgesVisited, nodesVisited) + } /** Returns a set of vertex names adjacent to, and including, vset */ def extendToAdjacentVerts(vset: Set[VName]): Set[VName] = - vset.foldRight(Set[VName]()) { (v,vs) => (vs union adjacentVerts(v)) + v } + vset.foldRight(Set[VName]()) { (v, vs) => (vs union adjacentVerts(v)) + v } /** Returns a set of vertex names adjacent to, but not including, vset */ def adjacentVerts(vset: Set[VName]): Set[VName] = - vset.foldRight(Set[VName]()) { (v,vs) => vs union adjacentVerts(v) } -- vset - - /** Returns a set of edge names adjacent to vn */ - def adjacentEdges(vn: VName): Set[EName] = source.codf(vn) union target.codf(vn) + vset.foldRight(Set[VName]()) { (v, vs) => vs union adjacentVerts(v) } -- vset /** Returns a set of edge names adjacent to any vertex in vset */ def adjacentEdges(vset: Set[VName]): Set[EName] = - vset.foldRight(Set[EName]()) { (v,es) => es union adjacentEdges(v) } - - /** Returns a set of edge names which connect v1 to v2 or vice versa */ - def edgesBetween(v1: VName, v2: VName): Set[EName] = { - if (v1 == v2) { - source.codf(v1) intersect target.codf(v1) - } - else { - adjacentEdges(v1) intersect adjacentEdges(v2) - } - } + vset.foldRight(Set[EName]()) { (v, es) => es union adjacentEdges(v) } /** * If "e" is not a self-loop, get the vertex connected @@ -170,6 +211,7 @@ case class Graph( /** * Get the other edge connected to this wire vertex, if there is one + * * @param w a wire vertex * @param e an edge * @return an edge, optionally @@ -180,43 +222,49 @@ case class Graph( else throw new GraphException("Wire: " + w + " is not connected to edge: " + e) } + /** Returns a set of edge names adjacent to vn */ + def adjacentEdges(vn: VName): Set[EName] = source.codf(vn) union target.codf(vn) + /** - * Partition of all edges into sets, s.t. they connect the same two vertices - * regardless of edge direction - */ - def edgePartition() : List[Set[EName]] = { - var res : List[Set[EName]] = List() - for ((v1,_) <- vdata; (v2,_) <- vdata if v1 <= v2) { - val edgeSet = edgesBetween(v1,v2) + * Partition of all edges into sets, s.t. they connect the same two vertices + * regardless of edge direction + */ + def edgePartition: List[Set[EName]] = { + var res: List[Set[EName]] = List() + for ((v1, _) <- vdata; (v2, _) <- vdata if v1 <= v2) { + val edgeSet = edgesBetween(v1, v2) if (edgeSet.nonEmpty) res = edgeSet :: res } res } + /** Returns a set of edge names which connect v1 to v2 or vice versa */ + def edgesBetween(v1: VName, v2: VName): Set[EName] = { + if (v1 == v2) { + source.codf(v1) intersect target.codf(v1) + } + else { + adjacentEdges(v1) intersect adjacentEdges(v2) + } + } + def isBBoxed(v: VName): Boolean = inBBox.domf(v).nonEmpty /// by song // to compute whether two vertices are in the same bbox. - def isInSameBBox(v1:VName, v2:VName): Boolean = (inBBox.domf(v1) & inBBox.domf(v2)).nonEmpty - - def addVertex(vn: VName, data: VData): Graph = { - if (vdata contains vn) - throw new DuplicateVertexNameException(vn) - - copy(vdata = vdata + (vn -> data)) - } + def isInSameBBox(v1: VName, v2: VName): Boolean = (inBBox.domf(v1) & inBBox.domf(v2)).nonEmpty /** - * @return A new graph where all vertices have coordinates which align to a - * grid - */ + * @return A new graph where all vertices have coordinates which align to a + * grid + */ def snapToGrid(): Graph = { - def roundCoord(d : Double) = { + def roundCoord(d: Double) = { math.rint(d * 4.0) / 4.0 // rounds to .25 } - val snapped_vdata = vdata.mapValues {vd => + val snapped_vdata = vdata.mapValues { vd => val coord = vd.coord vd.withCoord(roundCoord(coord._1), roundCoord(coord._2)) } @@ -228,19 +276,11 @@ case class Graph( (addVertex(vn, data), vn) } - def addEdge(en: EName, data: EData, vns: (VName, VName)): Graph = { - if (edata contains en) - throw new DuplicateEdgeNameException(en) - if (!vdata.contains(vns._1)) - throw new GraphException("Edge: " + en + " has no endpoint: " + vns._1 + " in graph") - if (!vdata.contains(vns._1)) - throw new GraphException("Edge: " + en + " has no endpoint: " + vns._2 + " in graph") + def addVertex(vn: VName, data: VData): Graph = { + if (vdata contains vn) + throw new DuplicateVertexNameException(vn) - copy( - edata = edata + (en -> data), - source = source + (en -> vns._1), - target = target + (en -> vns._2) - ) + copy(vdata = vdata + (vn -> data)) } def newEdge(data: EData, vns: (VName, VName)): (Graph, EName) = { @@ -254,7 +294,7 @@ case class Graph( val g1 = copy( bbdata = bbdata + (bbn -> data), - inBBox = contents.foldLeft(inBBox){ (x,v) => x + (v -> bbn) } + inBBox = contents.foldLeft(inBBox) { (x, v) => x + (v -> bbn) } ) parent match { @@ -265,10 +305,11 @@ case class Graph( /** * A list of bbox parents, with the closest ancestor first + * * @param bb a bbox * @return */ - def bboxParentList(bb : BBName): List[BBName] = + def bboxParentList(bb: BBName): List[BBName] = bboxParent.get(bb) match { case Some(bb1) => bb1 :: bboxParentList(bb1) case None => List() @@ -276,7 +317,8 @@ case class Graph( /** * The set of all parents for a given bbox - * @param bb + * + * @param bb Bounding Box * @return */ def bboxParents(bb: BBName): Set[BBName] = @@ -285,16 +327,15 @@ case class Graph( case None => Set() } - def bboxChildren(bb: BBName) : Set[BBName] = + def bboxChildren(bb: BBName): Set[BBName] = bboxParent.codf(bb) - def addToBBox(v: VName, bb: BBName): Graph = { - copy(inBBox = inBBox + (v,bb)) + copy(inBBox = inBBox + (v, bb)) } def addToBBoxes(v: VName, bboxes: Set[BBName]): Graph = { - copy(inBBox = bboxes.foldRight(inBBox) { (bb,mp) => mp + (v,bb) }) + copy(inBBox = bboxes.foldRight(inBBox) { (bb, mp) => mp + (v, bb) }) } /** Replace the contents of a bang box with new ones @@ -306,11 +347,11 @@ case class Graph( var inBB = inBBox for (bb1 <- updateBB) { - oldContents.foreach {v => inBB -= (v -> bb1) } - newContents.foreach {v => inBB += (v -> bb1) } + oldContents.foreach { v => inBB -= (v -> bb1) } + newContents.foreach { v => inBB += (v -> bb1) } } - copy( inBBox = inBB ) + copy(inBBox = inBB) } /** Change bbox parent. All contents will be removed from old parents and added to @@ -334,17 +375,16 @@ case class Graph( } - for (bbp <- oldParents) { - cont.foreach {v => inBB -= (v -> bbp) } + cont.foreach { v => inBB -= (v -> bbp) } } for (bbp <- newParents) { - cont.foreach {v => inBB += (v -> bbp) } + cont.foreach { v => inBB += (v -> bbp) } } - copy( inBBox = inBB , bboxParent = bbP ) + copy(inBBox = inBB, bboxParent = bbP) } def newBBox(data: BBData, contents: Set[VName] = Set[VName](), parent: Option[BBName] = None): (Graph, BBName) = { @@ -360,15 +400,7 @@ case class Graph( ) } - def deleteEdge(en: EName): Graph = { - copy( - edata = edata - en, - source = source.unmapDom(en), - target = target.unmapDom(en) - ) - } - - def deleteEdges(es: Set[EName]): Graph = es.foldRight(this) { (e,g) => g.deleteEdge(e) } + def deleteEdges(es: Set[EName]): Graph = es.foldRight(this) { (e, g) => g.deleteEdge(e) } def safeDeleteVertex(vn: VName): Graph = { if (source.codf(vn).nonEmpty || target.codf(vn).nonEmpty) @@ -378,54 +410,146 @@ case class Graph( copy(vdata = vdata - vn, inBBox = inBBox.unmapDom(vn)) } - def deleteVertex(vn: VName): Graph = { + def deleteVertices(vs: Set[VName]): Graph = vs.foldRight(this) { (v, g) => g.deleteVertex(v) } + + def cutVertex(vertexName: VName): (Graph, Set[VName], Set[VName]) = cutVertex(vertexName, verts.filter(isBoundary)) + /** + * Delete a vertex, but leave edges dangling if they were attached to another non-boundary node, + * removes boundaries adjacent to the cut vertex, + * dangling edges have a newly created boundary at one end. + * + * If trying to remove a boundary, consider just using deleteVertex instead. + * + * @param vertexName Vertex Name + * @param removingBoundaries If any of these are neighbours then remove whole-cloth + * @return (Cut graph, new boundaries, removed boundaries) + */ + def cutVertex(vertexName: VName, removingBoundaries: Set[VName]): (Graph, Set[VName], Set[VName]) = { var g = this - for (e <- source.codf(vn)) g = g.deleteEdge(e) - for (e <- target.codf(vn)) g = g.deleteEdge(e) - g.copy(vdata = vdata - vn, inBBox = inBBox.unmapDom(vn)) - } + var newBoundaries: Set[VName] = Set() + var oldBoundaries: Set[VName] = Set() - def deleteVertices(vs: Set[VName]): Graph = vs.foldRight(this) { (v, g) => g.deleteVertex(v) } + def midPoint(v1: VData, v2: VData): (Double, Double) = { + val c1 = v1.coord + val c2 = v2.coord + val x = (c1._1 + c2._1) / 2.0 + val y = (c1._2 + c2._2) / 2.0 + (x, y) + } + + def breakEdge(g: Graph, e: EName): Graph = { + var g2 = g + val (ends, edges, wireNodes) = edgeEndPoints(e) + if (ends.size == 2) { + val joinNode = (ends - vertexName).head + if (g2.isTerminalWire(joinNode) && removingBoundaries.contains(joinNode)) { + oldBoundaries += joinNode + // Delete boundaries that would otherwise float after vertex removal + g2 = g2.deleteVertex(joinNode) + } else { + // Create a boundary where we cut the wire + val bName = g.verts.freshWithSuggestion(VName("c-" + vertexName.s + "-b")) + newBoundaries += bName + val oldSource : VName = g.source(e) + val direction = if (oldSource == joinNode) joinNode -> bName else bName -> joinNode + g2 = g2.addVertex(bName, WireV()). + addEdge(g.edges.freshWithSuggestion(e), g.edata(e), direction) + // Add a coordinate to our new boundary + val newCoordinate = midPoint(g.vdata(vertexName), g.vdata(joinNode)) + g2 = g2.updateVData(bName) { vd => vd.withCoord(newCoordinate) } + } + g2 = g2.deleteEdges(edges) + g2 = g2.deleteVertices(wireNodes) + } else { + // the edge given was part of a loop + // should never end up here. + } + g2 + } + + g = source.codf(vertexName).foldLeft(g) { (g, e) => breakEdge(g, e) } + g = target.codf(vertexName).foldLeft(g) { (g, e) => breakEdge(g, e) } + g = g.deleteVertex(vertexName) + (g, newBoundaries, oldBoundaries) + } // data updaters def updateData(f: GData => GData): Graph = copy(data = f(data)) - def updateVData(vn: VName)(f: VData => VData): Graph = copy(vdata = vdata + (vn -> f(vdata(vn)))) + def updateEData(en: EName)(f: EData => EData): Graph = copy(edata = edata + (en -> f(edata(en)))) + def updateBBData(bbn: BBName)(f: BBData => BBData): Graph = copy(bbdata = bbdata + (bbn -> f(bbdata(bbn)))) - def rename(vrn: Map[VName,VName], ern: Map[EName, EName], brn: Map[BBName,BBName]): Graph = { + def rename(vrn: Map[VName, VName] = Map(), ern: Map[EName, EName] = Map(), brn: Map[BBName, BBName] = Map()): Graph = { // compute inverses -// val vrni = vrn.foldLeft(Map[VName,VName]()) { case (mp, (k,v)) => mp + (v -> k) } -// val erni = ern.foldLeft(Map[EName,EName]()) { case (mp, (k,v)) => mp + (v -> k) } -// val brni = brn.foldLeft(Map[BBName,BBName]()) { case (mp, (k,v)) => mp + (v -> k) } - - val vdata1 = vdata.foldLeft(Map[VName,VData]()) { case (mp, (k,v)) => mp + (vrn(k) -> v)} - val edata1 = edata.foldLeft(Map[EName,EData]()) { case (mp, (k,v)) => mp + (ern(k) -> v)} - val bbdata1 = bbdata.foldLeft(Map[BBName,BBData]()) { case (mp, (k,v)) => mp + (brn(k) -> v)} - val source1 = source.foldLeft(PFun[EName,VName]()) { case (mp, (k,v)) => mp + (ern(k) -> vrn(v))} - val target1 = target.foldLeft(PFun[EName,VName]()) { case (mp, (k,v)) => mp + (ern(k) -> vrn(v))} - val inBBox1 = inBBox.foldLeft(BinRel[VName,BBName]()) { case (mp, (k,v)) => mp + (vrn(k) -> brn(v))} - val bboxParent1 = bboxParent.foldLeft(PFun[BBName,BBName]()) { case (mp, (k,v)) => mp + (brn(k) -> brn(v))} - - copy(vdata=vdata1,edata=edata1,source=source1,target=target1, - bbdata=bbdata1,inBBox=inBBox1,bboxParent=bboxParent1) + // val vrni = vrn.foldLeft(Map[VName,VName]()) { case (mp, (k,v)) => mp + (v -> k) } + // val erni = ern.foldLeft(Map[EName,EName]()) { case (mp, (k,v)) => mp + (v -> k) } + // val brni = brn.foldLeft(Map[BBName,BBName]()) { case (mp, (k,v)) => mp + (v -> k) } + + val vdata1 = vdata.foldLeft(Map[VName, VData]()) { case (mp, (k, v)) => mp + (vrn.getOrElse(k, k) -> v) } + val edata1 = edata.foldLeft(Map[EName, EData]()) { case (mp, (k, v)) => mp + (ern.getOrElse(k, k) -> v) } + val bbdata1 = bbdata.foldLeft(Map[BBName, BBData]()) { case (mp, (k, v)) => mp + (brn.getOrElse(k, k) -> v) } + val source1 = source.foldLeft(PFun[EName, VName]()) { case (mp, (k, v)) => mp + (ern.getOrElse(k, k) -> vrn.getOrElse(v, v)) } + val target1 = target.foldLeft(PFun[EName, VName]()) { case (mp, (k, v)) => mp + (ern.getOrElse(k, k) -> vrn.getOrElse(v, v)) } + val inBBox1 = inBBox.foldLeft(BinRel[VName, BBName]()) { case (mp, (k, v)) => mp + (vrn.getOrElse(k, k) -> brn.getOrElse(v, v)) } + val bboxParent1 = bboxParent.foldLeft(PFun[BBName, BBName]()) { case (mp, (k, v)) => mp + (brn.getOrElse(k, k) -> brn.getOrElse(v, v)) } + + copy(vdata = vdata1, edata = edata1, source = source1, target = target1, + bbdata = bbdata1, inBBox = inBBox1, bboxParent = bboxParent1) } // get a subgraph consisting of the given vertices and bboxes, with any edges/nesting between them def fullSubgraph(vs: Set[VName], bbs: Set[BBName]): Graph = { val es = edges.filter { e => vs.contains(source(e)) && vs.contains(target(e)) } - val vdata1 = vdata.filter { case (v,_) => vs.contains(v) } - val edata1 = edata.filter { case (e,_) => es.contains(e) } - val source1 = source.filter { case (e,_) => es.contains(e) } - val target1 = target.filter { case (e,_) => es.contains(e) } - val inBBox1 = inBBox.filter{ case (v,b) => vs.contains(v) && bbs.contains(b) } - val bbdata1 = bbdata.filter { case (b,_) => bbs.contains(b) } - val bboxParent1 = bboxParent.filter { case (b1,b2) => bbs.contains(b1) && bbs.contains(b2) } + val vdata1 = vdata.filter { case (v, _) => vs.contains(v) } + val edata1 = edata.filter { case (e, _) => es.contains(e) } + val source1 = source.filter { case (e, _) => es.contains(e) } + val target1 = target.filter { case (e, _) => es.contains(e) } + val inBBox1 = inBBox.filter { case (v, b) => vs.contains(v) && bbs.contains(b) } + val bbdata1 = bbdata.filter { case (b, _) => bbs.contains(b) } + val bboxParent1 = bboxParent.filter { case (b1, b2) => bbs.contains(b1) && bbs.contains(b2) } - copy(data=GData(),vdata=vdata1,edata=edata1,source=source1,target=target1, - bbdata=bbdata1,inBBox=inBBox1,bboxParent=bboxParent1) + copy(data = GData(), vdata = vdata1, edata = edata1, source = source1, target = target1, + bbdata = bbdata1, inBBox = inBBox1, bboxParent = bboxParent1) + } + + def renameAvoiding(g: Graph): Graph = makeRenaming(g.verts, g.edges, g.bboxes).image(this) + + // form a new graph by merging the given (non-empty) set of vertices into a new vertex with the given name + def mergeVertices(vs: Set[VName], newV: VName): Graph = { + val rep = vs.head + val bboxes = inBBox.directImage(vs) + + val source1 = source.inverseImage(vs).foldRight(source) { (e, mp) => mp + (e -> newV) } + val target1 = target.inverseImage(vs).foldRight(target) { (e, mp) => mp + (e -> newV) } + + val g1 = if (verts.contains(newV)) this else addVertex(newV, vdata(rep)) + + g1.copy(source = source1, target = target1) + .deleteVertices(vs - newV) + .addToBBoxes(newV, bboxes) + } + + // add g to this graph, plugging 'b' in this graph into 'bg' in g. + def plugGraph(g: Graph, b: VName, bg: VName): Graph = { + // freshen target graph w.r.t. source + val mp = makeRenaming(verts, edges, bboxes) + val g1 = mp.image(g) + val bg1 = mp.v(bg) + + // re-position g relative to this graph, using the boundaries as a guide + val bgcoord = g1.vdata(bg1).coord + val bcoord = vdata(b).coord + val dx = bcoord._1 - bgcoord._1 + val dy = bcoord._2 - bgcoord._2 + + val g2 = g1.verts.foldLeft(g1) { (g1, v) => + g1.updateVData(v) { d => d.withCoord(d.coord._1 + dx, d.coord._2 + dy) } + } + + this.appendGraph(g2).plugBoundaries(b, bg1) } def makeRenaming(avoidVerts: Set[VName], @@ -458,22 +582,26 @@ case class Graph( mp } - def renameAvoiding(g: Graph): Graph = makeRenaming(g.verts, g.edges, g.bboxes).image(this) + def edges: Set[EName] = edata.keySet + + def bboxes: Set[BBName] = bbdata.keySet // append the given graph. note that its names should already be fresh - def appendGraph(g: Graph): Graph = { + def appendGraph(g: Graph, noOverlap : Boolean = true): Graph = { val coords = verts.map(vdata(_).coord) // Pick any vertex in g and offset until that vertex is not sitting exactly // on top of another. var offset = 0.0 - g.verts.headOption.foreach { v1 => - val (x,y) = g.vdata(v1).coord - while (coords.contains((x + offset, y))) offset += 1.0 + if(noOverlap) { + g.verts.headOption.foreach { v1 => + val (x, y) = g.vdata(v1).coord + while (coords.contains((x + offset, y))) offset += 1.0 + } } - val g1 = g.verts.foldLeft(g) { (g1,v) => - g1.updateVData(v) { d => d.withCoord (d.coord._1 + offset, d.coord._2) } + val g1 = g.verts.foldLeft(g) { (g1, v) => + g1.updateVData(v) { d => d.withCoord(d.coord._1 + offset, d.coord._2) } } copy( @@ -487,65 +615,55 @@ case class Graph( ) } + def updateVData(vn: VName)(f: VData => VData): Graph = copy(vdata = vdata + (vn -> f(vdata(vn)))) + + def updateAllVData(f: VData => VData): Graph = { + val graph = this + graph.verts.foldLeft(graph) { (g, v) => + g.updateVData(v)(f) + } + } + def plugBoundaries(b1: VName, b2: VName): Graph = { // pull a boundary edge, which we'll inherit the data from val be = inEdges(b2).headOption.getOrElse( - outEdges(b2).headOption.getOrElse ( + outEdges(b2).headOption.getOrElse( throw new PluggingException("Target boundary is an isolated point."))) val beData = edata(be) // figure out who should be the source and target - val (s,t) = ( + val (s, t) = ( predVerts(b1).headOption, succVerts(b1).headOption, predVerts(b2).headOption, succVerts(b2).headOption) match { - case (None, Some(t1), Some(s1), None) => (s1,t1) - case (Some(s1), None, None, Some(t1)) => (s1,t1) - case (Some(s1), None, Some(t1), None) if !beData.isDirected => (s1,t1) - case (None, Some(s1), None, Some(t1)) if !beData.isDirected => (s1,t1) + case (None, Some(t1), Some(s1), None) => (s1, t1) + case (Some(s1), None, None, Some(t1)) => (s1, t1) + case (Some(s1), None, Some(t1), None) if !beData.isDirected => (s1, t1) + case (None, Some(s1), None, Some(t1)) if !beData.isDirected => (s1, t1) case _ => throw new PluggingException("Bad boundary arity") } - this.deleteVertex(b1).deleteVertex(b2).addEdge(be, beData, (s,t)) + this.deleteVertex(b1).deleteVertex(b2).addEdge(be, beData, (s, t)) } - // form a new graph by merging the given (non-empty) set of vertices into a new vertex with the given name - def mergeVertices(vs: Set[VName], newV: VName): Graph = { - val rep = vs.head - val bboxes = inBBox.directImage(vs) - - val source1 = source.inverseImage(vs).foldRight(source) { (e, mp) => mp + (e -> newV) } - val target1 = target.inverseImage(vs).foldRight(target) { (e, mp) => mp + (e -> newV) } - - val g1 = if (verts.contains(newV)) this else addVertex(newV, vdata(rep)) + def deleteVertex(vn: VName): Graph = { + var g = this + for (e <- source.codf(vn)) g = g.deleteEdge(e) + for (e <- target.codf(vn)) g = g.deleteEdge(e) - g1.copy(source = source1, target = target1) - .deleteVertices(vs - newV) - .addToBBoxes(newV, bboxes) + g.copy(vdata = vdata - vn, inBBox = inBBox.unmapDom(vn)) } - // add g to this graph, plugging 'b' in this graph into 'bg' in g. - def plugGraph(g: Graph, b: VName, bg: VName): Graph = { - // freshen target graph w.r.t. source - val mp = makeRenaming(verts, edges, bboxes) - val g1 = mp.image(g) - val bg1 = mp.v(bg) - - // re-position g relative to this graph, using the boundaries as a guide - val bgcoord = g1.vdata(bg1).coord - val bcoord = vdata(b).coord - val dx = bcoord._1 - bgcoord._1 - val dy = bcoord._2 - bgcoord._2 - - val g2 = g1.verts.foldLeft(g1) { (g1,v) => - g1.updateVData(v) { d => d.withCoord (d.coord._1 + dx, d.coord._2 + dy) } - } - - this.appendGraph(g2).plugBoundaries(b, bg1) + def deleteEdge(en: EName): Graph = { + copy( + edata = edata - en, + source = source.unmapDom(en), + target = target.unmapDom(en) + ) } def compare(that: Graph): Int = { - val x : Graph = this - val y : Graph = that + val x: Graph = this + val y: Graph = that if (x.verts.size > y.verts.size) { 1 } else { @@ -572,78 +690,28 @@ case class Graph( | bboxes: %s, | nesting: %s |}""".stripMargin.format( - data, vdata, - edata.map(kv => kv._1 -> "(%s => %s)::%s".format(source(kv._1), target(kv._1), kv._2)), - bbdata.map(kv => kv._1 -> "%s::%s".format(inBBox.codf(kv._1), kv._2)), - bboxParent.toString - ) - } - - private def dftSuccessors[T](fromV: VName, exploredV: Set[VName], exploredE: Set[EName])(base: T) - (f: (T, EName, GraphSearchContext) => T): (T, Set[VName], Set[EName]) = - { - val nextEs = outEdges(fromV).filter(!exploredE.contains(_)) - - if (nextEs.nonEmpty) { - val e = nextEs.min - val nextV = target(e) - - val (base1, exploredV1, exploredE1) = - dftSuccessors(nextV, exploredV + nextV, exploredE + e)(base)(f) - val (base2, exploredV2, exploredE2) = - dftSuccessors(fromV, exploredV1, exploredE1)(base1)(f) - - val context = GraphSearchContext(exploredV2, exploredE2) - (f(base2, e, context), exploredV2, exploredE2) - } else { - (base, exploredV, exploredE) - } - } - - private def dftComponents[T](exploredV: Set[VName], exploredE: Set[EName])(base: T) - (f: (T, EName, GraphSearchContext) => T) : T = - { - val nextVs = vdata.keySet.filter(!exploredV.contains(_)) - val initialVs = nextVs.filter(inEdges(_).isEmpty) - - // Try to start with the minimal unexplored vertex with no in-edges. Failing that, start with the - // minimal unexplored vertex. - val vOpt = if (initialVs.nonEmpty) Some(initialVs.min) - else if (nextVs.nonEmpty) Some(nextVs.min) - else None - - vOpt match { - case Some(v) => - val (base1, exploredV1, exploredE1) = dftSuccessors(v, exploredV + v, exploredE)(base)(f) - dftComponents[T](exploredV1, exploredE1)(base1)(f) - case None => base - } + data, vdata, + edata.map(kv => kv._1 -> "(%s => %s)::%s".format(source(kv._1), target(kv._1), kv._2)), + bbdata.map(kv => kv._1 -> "%s::%s".format(inBBox.codf(kv._1), kv._2)), + bboxParent.toString + ) } def dft[T](base: T)(f: (T, EName, GraphSearchContext) => T): T = dftComponents(Set[VName](), Set[EName]())(base)(f) - - private def bbDft(bb : BBName, bbSeq : collection.mutable.Buffer[BBName], bbs : collection.mutable.Set[BBName]) { - for (ch <- bboxParent.codf(bb)) bbDft(ch, bbSeq, bbs) - if (bbs.contains(bb)) { - bbs.remove(bb) - bbSeq += bb - } - } - def bboxesChildrenFirst: Seq[BBName] = { val bbSeq = collection.mutable.Buffer[BBName]() - val bbs = collection.mutable.Set[BBName](bboxes.toSeq : _*) - while (bbs.nonEmpty) bbDft(bbs.iterator.next(),bbSeq,bbs) + val bbs = collection.mutable.Set[BBName](bboxes.toSeq: _*) + while (bbs.nonEmpty) bbDft(bbs.iterator.next(), bbSeq, bbs) - bbSeq.toSeq + bbSeq } // returns a topo ordering. If graph is a dag, all edges will be consistent with this ordering def topologicalOrdering: PartialOrdering[VName] = { val visited = collection.mutable.Set[VName]() - var ordMap = Map[VName,Int]() + var ordMap = Map[VName, Int]() var max = 0 def visit(v: VName) { @@ -663,9 +731,9 @@ case class Graph( case _ => None } - def lteq(x: VName, y: VName): Boolean = tryCompare(x,y) match { - case Some(c) => c != 1 - case None => false + def lteq(x: VName, y: VName): Boolean = tryCompare(x, y) match { + case Some(c) => c != 1 + case None => false } } } @@ -673,14 +741,14 @@ case class Graph( def dagCopy: Graph = { // make a copy with no edges val noEdges = copy( - edata = Map[EName,EData](), - source = PFun[EName,VName](), - target = PFun[EName,VName]() + edata = Map[EName, EData](), + source = PFun[EName, VName](), + target = PFun[EName, VName]() ) val ord = this.topologicalOrdering - dft(noEdges) { (graph, e, context) => + dft(noEdges) { (graph, e, _) => val s = source(e) val t = target(e) @@ -688,20 +756,20 @@ case class Graph( else { // reverse back-edges to break cycles graph.addEdge(e, edata(e), - if (ord.lteq(s,t)) (s,t) else (t,s)) + if (ord.lteq(s, t)) (s, t) else (t, s)) } } } def simpleCopy: Graph = { var g = copy( - edata = Map[EName,EData](), - source = PFun[EName,VName](), - target = PFun[EName,VName]() + edata = Map[EName, EData](), + source = PFun[EName, VName](), + target = PFun[EName, VName]() ) for (v <- verts; w <- verts) { outEdges(v).find(target(_) == w) match { - case Some(e) => g = g.addEdge(e, edata(e), (v,w)) + case Some(e) => g = g.addEdge(e, edata(e), (v, w)) case None => () } } @@ -709,6 +777,35 @@ case class Graph( g } + def verts: Set[VName] = vdata.keySet + + def copy(data: GData = this.data, + vdata: Map[VName, VData] = this.vdata, + edata: Map[EName, EData] = this.edata, + source: PFun[EName, VName] = this.source, + target: PFun[EName, VName] = this.target, + bbdata: Map[BBName, BBData] = this.bbdata, + inBBox: BinRel[VName, BBName] = this.inBBox, + bboxParent: PFun[BBName, BBName] = this.bboxParent): Graph = + factory(data, vdata, edata, source, target, bbdata, inBBox, bboxParent) + + def outEdges(vn: VName): Set[EName] = source.codf(vn) + + def addEdge(en: EName, data: EData, vns: (VName, VName)): Graph = { + if (edata contains en) + throw new DuplicateEdgeNameException(en) + if (!vdata.contains(vns._1)) + throw new GraphException("Edge: " + en + " has no endpoint: " + vns._1 + " in graph") + if (!vdata.contains(vns._1)) + throw new GraphException("Edge: " + en + " has no endpoint: " + vns._2 + " in graph") + + copy( + edata = edata + (en -> data), + source = source + (en -> vns._1), + target = target + (en -> vns._2) + ) + } + def expandWire(w: VName): (Graph, (VName, VName, EName)) = { val ed = adjacentEdges(w).headOption match { case Some(e) => edata(e) @@ -725,7 +822,7 @@ case class Graph( outEdges(w).headOption match { case None => // 'w' is an output, so it should stay an output g = g.addEdge(newE, ed, newW -> w) - inEdges(w).headOption.foreach{e => + inEdges(w).headOption.foreach { e => g = g.deleteEdge(e).addEdge(e, ed, source(e) -> newW) } @@ -770,8 +867,8 @@ case class Graph( // direct edge in the same direction as the minimum of the two edges connected to w val endPoints = - if (target(e1) == w) (source(e1), edgeGetOtherVertex(e2,w)) - else (edgeGetOtherVertex(e2,w), target(e1)) + if (target(e1) == w) (source(e1), edgeGetOtherVertex(e2, w)) + else (edgeGetOtherVertex(e2, w), target(e1)) this .deleteVertex(w) @@ -780,26 +877,54 @@ case class Graph( } /** - * Put graph in normal form, where each (non-bare) wire has exactly 1 wire vertex - * @return + * Change all wires to boundaries or boundaries to wires depending on how many edges coincide with it + * + * @return Graph + */ + def coerceWiresAndBoundaries: Graph = { + val graph = this + graph.verts.foldLeft(graph) { (g, v) => + g.updateVData(v) { d => + if (d.isWireVertex) { + d.asInstanceOf[WireV].makeBoundary( + graph.isTerminalWire(v) + ) + } else d + } + } + } + + def isTerminalWire(vn: VName): Boolean = + vdata(vn).isWireVertex && (inEdges(vn).size + outEdges(vn).size) <= 1 + + /** + * Put graph in normal form, where each wire has exactly 1 wire vertex + * + * @return Graph */ def normalise: Graph = { var ch = false var g = this for (e <- edges) { - val s = source(e) - val t = target(e) - (vdata(s), vdata(t)) match { - case (_: NodeV, _: NodeV) => - g = g.edgeToWire(e) - ch = true - case (_: WireV, _: WireV) if s != t => - if (!isBoundary(s) || !isBoundary(t)) { - g = g.collapseWire(e) + if (!ch) { + val s = source(e) + val t = target(e) + (vdata(s), vdata(t)) match { + case (_: NodeV, _: NodeV) => + g = g.edgeToWire(e) ch = true - } - case _ => // do nothing + case (_: WireV, _: WireV) if s != t => + + /** + * Collapse if between two internal wires, unless going in or out of a bbox + */ + if (!isTerminalWire(s) && !isTerminalWire(t)) { + g = g.collapseWire(e) + ch = true + } + case _ => // do nothing + } } } @@ -812,10 +937,8 @@ case class Graph( } } - g - -// if (ch) g.normalise -// else g + if (ch) g.normalise + else g } def minimise: Graph = { @@ -828,6 +951,7 @@ case class Graph( /** * make a copy of the given bbox's contents, without copying the bbox itself + * * @param bb the bbox to be expanded * @return the new graph and a record containing relevant data for replaying the expansion */ @@ -839,8 +963,8 @@ case class Graph( var gfr = mp1.image(g) // add each expanded vertex to the bboxes that it was already in - g.verts.foreach{v => - ((bboxesContaining(v) -- bboxChildren(bb)) - bb).foreach{ bb1 => gfr = gfr.addToBBox(mp1.v(v), bb1) } + g.verts.foreach { v => + ((bboxesContaining(v) -- bboxChildren(bb)) - bb).foreach { bb1 => gfr = gfr.addToBBox(mp1.v(v), bb1) } } var g1 = appendGraph(gfr) @@ -853,21 +977,22 @@ case class Graph( val e1 = freshE.freshWithSuggestion(e) freshE = freshE + e1 mp1 = mp1.addEdge(e -> e1) - g1 = g1.addEdge(e1, edata(e), mp1.v.getOrElse(s,s) -> mp1.v.getOrElse(t,t)) + g1 = g1.addEdge(e1, edata(e), mp1.v.getOrElse(s, s) -> mp1.v.getOrElse(t, t)) } - (g1, BBExpand(bb,mp1)) + (g1, BBExpand(bb, mp1)) } /** * make a copy of the given bbox + * * @param bb the bbox to be copied * @return the new graph and a record containing relevant data for replaying the copy */ def copyBBox(bb: BBName, avoidV: Set[VName] = Set(), mp: GraphMap = GraphMap()): (Graph, BBCopy) = { var (g1, bbe) = expandBBox(bb, avoidV, mp) val mp1 = if (bbe.mp.bb.domSet contains bb) bbe.mp - else bbe.mp.copy(bb = bbe.mp.bb + (bb -> g1.bboxes.freshWithSuggestion(bb))) + else bbe.mp.copy(bb = bbe.mp.bb + (bb -> g1.bboxes.freshWithSuggestion(bb))) val bb1 = mp1.bb(bb) g1 = g1.addBBox(bb1, bbdata(bb), mp1.v.codSet) @@ -878,6 +1003,7 @@ case class Graph( /** * drop the given bbox, keeping the contents intact + * * @param bb the bbox to be dropped * @return the new graph and a record containing relevant data for replaying the drop */ @@ -887,6 +1013,7 @@ case class Graph( /** * kill the given bbox, also deleting child nodes and bboxes + * * @param bb the bbox to be dropped * @return the new graph and a record containing relevant data for replaying the drop */ @@ -898,29 +1025,26 @@ case class Graph( (g1, BBKill(bb)) } - private def freshMap(mp: PFun[VName, VName], avoid: Set[VName]) = { - mp - } - /** * Apply the given bbox operation - * @param bbop a !-box operation + * + * @param bbop a !-box operation * @param avoidV an optional set of (extra) vertices to avoid when copying !-box contents * @return */ def applyBBOp(bbop: BBOp, avoidV: Set[VName] = Set()): Graph = bbop match { case BBExpand(bb, mp) => - val mp1 = GraphMap(v = mp.v.filterKeys(v => verts.contains(v) && isBoundary(v)), bb = mp.bb) + val mp1 = GraphMap(v = mp.v.filterKeys(v => verts.contains(v) && isTerminalWire(v)), bb = mp.bb) expandBBox(bb, avoidV, mp1)._1 case BBCopy(bb, mp) => - val mp1 = GraphMap(v = mp.v.filterKeys(v => verts.contains(v) && isBoundary(v)), bb = mp.bb) + val mp1 = GraphMap(v = mp.v.filterKeys(v => verts.contains(v) && isTerminalWire(v)), bb = mp.bb) copyBBox(bb, avoidV, mp1)._1 case BBDrop(bb) => dropBBox(bb)._1 case BBKill(bb) => killBBox(bb)._1 } - def freeVars: Set[String] = vdata.foldRight(Set[String]()) { - case ((_,d: NodeV), s) => s union d.angle.vars + def freeVars: Set[(ValueType, String)] = vdata.foldRight(Set[(ValueType, String)]()) { + case ((_, d: NodeV), s) => s union d.phaseData.varsWithType case (_, s) => s } @@ -930,58 +1054,117 @@ case class Graph( contents(bb).forall(v => bboxesContaining(v).size > 1) } - def toJson(theory: Theory) : Json = Graph.toJson(this, theory) -} + def contents(bbn: BBName): Set[VName] = inBBox.codf(bbn) -object Graph { -// val Flavor = new DataFlavor(Graph.getClass, "X-quantoderive/qgraph; class=;") -// class GraphPacket(graph: Graph, val theory: Theory) extends Transferable { -// def getTransferData(f: DataFlavor) = this -// def isDataFlavorSupported(f: DataFlavor) = { f == Graph.Flavor } -// def getTransferDataFlavors = Array(Graph.Flavor) -// } + def bboxesContaining(vn: VName): Set[BBName] = inBBox.domf(vn) - implicit def qGraphAndNameToQGraph[N <: Name[N]](t: (Graph, Name[N])) : Graph = t._1 + def toJson(theory: Theory = Theory.DefaultTheory): Json = Graph.toJson(this, theory) - def apply(theory: Theory): Graph = Graph(data = GData(theory = theory)) + private def dftSuccessors[T](fromV: VName, exploredV: Set[VName], exploredE: Set[EName])(base: T) + (f: (T, EName, GraphSearchContext) => T): (T, Set[VName], Set[EName]) = { + val nextEs = outEdges(fromV).filter(!exploredE.contains(_)) + if (nextEs.nonEmpty) { + val e = nextEs.min + val nextV = target(e) - def fromJson(s: String, thy: Theory): Graph = - try { fromJson(Json.parse(s), thy) } - catch { case e:JsonParseException => throw new GraphLoadException("Error parsing JSON", e) } + val (base1, exploredV1, exploredE1) = + dftSuccessors(nextV, exploredV + nextV, exploredE + e)(base)(f) + val (base2, exploredV2, exploredE2) = + dftSuccessors(fromV, exploredV1, exploredE1)(base1)(f) + + val context = GraphSearchContext(exploredV2, exploredE2) + (f(base2, e, context), exploredV2, exploredE2) + } else { + (base, exploredV, exploredE) + } + } + + private def dftComponents[T](exploredV: Set[VName], exploredE: Set[EName])(base: T) + (f: (T, EName, GraphSearchContext) => T): T = { + val nextVs = vdata.keySet.filter(!exploredV.contains(_)) + val initialVs = nextVs.filter(inEdges(_).isEmpty) + // Try to start with the minimal unexplored vertex with no in-edges. Failing that, start with the + // minimal unexplored vertex. + val vOpt = if (initialVs.nonEmpty) Some(initialVs.min) + else if (nextVs.nonEmpty) Some(nextVs.min) + else None + + vOpt match { + case Some(v) => + val (base1, exploredV1, exploredE1) = dftSuccessors(v, exploredV + v, exploredE)(base)(f) + dftComponents[T](exploredV1, exploredE1)(base1)(f) + case None => base + } + } + + private def bbDft(bb: BBName, bbSeq: collection.mutable.Buffer[BBName], bbs: collection.mutable.Set[BBName]) { + for (ch <- bboxParent.codf(bb)) bbDft(ch, bbSeq, bbs) + if (bbs.contains(bb)) { + bbs.remove(bb) + bbSeq += bb + } + } + + private def freshMap(mp: PFun[VName, VName], avoid: Set[VName]) = { + mp + } +} + +object Graph { + // val Flavor = new DataFlavor(Graph.getClass, "X-quantoderive/qgraph; class=;") + // class GraphPacket(graph: Graph, val theory: Theory) extends Transferable { + // def getTransferData(f: DataFlavor) = this + // def isDataFlavorSupported(f: DataFlavor) = { f == Graph.Flavor } + // def getTransferDataFlavors = Array(Graph.Flavor) + // } + + implicit def qGraphAndNameToQGraph[N <: Name[N]](t: (Graph, Name[N])): Graph = t._1 def fromJson(s: String): Graph = fromJson(s, Theory.DefaultTheory) + def fromJson(s: String, thy: Theory): Graph = + try { + fromJson(Json.parse(s), thy) + } + catch { + case e: JsonParseException => throw new GraphLoadException("Error parsing JSON", e) + } + def fromJson(json: Json, thy: Theory = Theory.DefaultTheory): Graph = try { Function.chain[Graph](Seq( - (json ? "wire_vertices").asObject.foldLeft(_) { (g,v) => + (json ? "wire_vertices").asObject.foldLeft(_) { (g, v) => g.addVertex(v._1, WireV.fromJson(v._2, thy)) }, - (json ? "node_vertices").asObject.foldLeft(_) { (g,v) => + (json ? "node_vertices").asObject.foldLeft(_) { (g, v) => g.addVertex(v._1, NodeV.fromJson(v._2, thy)) }, - (json ? "dir_edges").asObject.foldLeft(_) { (g,e) => + (json ? "dir_edges").asObject.foldLeft(_) { (g, e) => val data = e._2.getOrElse("data", thy.defaultEdgeData).asObject val annotation = (e._2 ? "annotation").asObject g.addEdge(e._1, DirEdge(data, annotation, thy), ((e._2 / "src").stringValue, (e._2 / "tgt").stringValue)) }, - (json ? "undir_edges").asObject.foldLeft(_) { (g,e) => + (json ? "undir_edges").asObject.foldLeft(_) { (g, e) => val data = e._2.getOrElse("data", thy.defaultEdgeData).asObject val annotation = (e._2 ? "annotation").asObject g.addEdge(e._1, UndirEdge(data, annotation, thy), ((e._2 / "src").stringValue, (e._2 / "tgt").stringValue)) }, - (json ? "bang_boxes").asObject.foldLeft(_) { (g,bb) => + (json ? "bang_boxes").asObject.foldLeft(_) { (g, bb) => val data = (bb._2 ? "data").asObject val annotation = (bb._2 ? "annotation").asObject - val contains = (bb._2 ? "contents").vectorValue map { VName(_) } - val parent = bb._2.get("parent") map { BBName(_) } + val contains = (bb._2 ? "contents").vectorValue map { + VName(_) + } + val parent = bb._2.get("parent") map { + BBName(_) + } g.addBBox(bb._1, BBData(data, annotation), contains.toSet, parent) } @@ -999,14 +1182,12 @@ object Graph { } def toJson(graph: Graph, thy: Theory = Theory.DefaultTheory): Json = { - val (wireVertices, nodeVertices) = graph.vdata.foldLeft((JsonObject(), JsonObject())) - { - case ((objW,objN), (v,w: WireV)) => (objW + (v.toString -> w.toJson), objN) - case ((objW,objN), (v,n: NodeV)) => (objW, objN + (v.toString -> n.toJson)) + val (wireVertices, nodeVertices) = graph.vdata.foldLeft((JsonObject(), JsonObject())) { + case ((objW, objN), (v, w: WireV)) => (objW + (v.toString -> w.toJson), objN) + case ((objW, objN), (v, n: NodeV)) => (objW, objN + (v.toString -> n.toJson)) } - val (dirEdges, undirEdges) = graph.edata.foldLeft((JsonObject(), JsonObject())) - { case ((objD,objU), (e,d)) => + val (dirEdges, undirEdges) = graph.edata.foldLeft((JsonObject(), JsonObject())) { case ((objD, objU), (e, d)) => val entry = e.toString -> (d.toJson + ("src" -> graph.source(e).toString, "tgt" -> graph.target(e).toString)) if (d.isDirected) (objD + entry, objU) else (objD, objU + entry) } @@ -1014,11 +1195,12 @@ object Graph { val bangBoxes = graph.bbdata.foldLeft(JsonObject()) { case (obj, (bb, d)) => obj + (bb.toString -> JsonObject( - "contents" -> JsonArray(graph.contents(bb)), - "parent" -> (graph.bboxParent.get(bb) match { + "contents" -> JsonArray(graph.contents(bb)), + "parent" -> (graph.bboxParent.get(bb) match { case Some(p) => JsonString(p.toString) - case None => JsonNull }), - "data" -> d.data, + case None => JsonNull + }), + "data" -> d.data, "annotation" -> d.annotation ).noEmpty) } @@ -1026,18 +1208,18 @@ object Graph { JsonObject( "wire_vertices" -> wireVertices.asObjectOrKeyArray, "node_vertices" -> nodeVertices.asObjectOrKeyArray, - "dir_edges" -> dirEdges, - "undir_edges" -> undirEdges, - "bang_boxes" -> bangBoxes, - "data" -> graph.data.data, - "annotation" -> graph.data.annotation + "dir_edges" -> dirEdges, + "undir_edges" -> undirEdges, + "bang_boxes" -> bangBoxes, + "data" -> graph.data.data, + "annotation" -> graph.data.annotation ).noEmpty } def random(nverts: Int, nedges: Int, nbboxes: Int = 0): Graph = { val rand = new util.Random var randomGraph = Graph() - for (i <- 1 to nverts) { + for (_ <- 1 to nverts) { val p = (rand.nextDouble * 6.0 - 3.0, rand.nextDouble * 6.0 - 3.0) if (rand.nextBoolean()) randomGraph = randomGraph.newVertex(NodeV(p)) else randomGraph = randomGraph.newVertex(WireV(p)) @@ -1046,21 +1228,21 @@ object Graph { if (nverts != 0) { val sources = new ArrayBuffer[VName](randomGraph.vdata.keys.size) val targets = new ArrayBuffer[VName](randomGraph.vdata.keys.size) - randomGraph.vdata.keys.foreach{k => sources += k; targets += k} - for(j <- 1 to nedges if sources.nonEmpty && targets.nonEmpty) { - val (si,ti) = (rand.nextInt(sources.size), rand.nextInt(targets.size)) + randomGraph.vdata.keys.foreach { k => sources += k; targets += k } + for (_ <- 1 to nedges if sources.nonEmpty && targets.nonEmpty) { + val (si, ti) = (rand.nextInt(sources.size), rand.nextInt(targets.size)) val s = sources(si) val t = targets(ti) if (randomGraph.vdata(s).isWireVertex) sources -= s if (randomGraph.vdata(t).isWireVertex) targets -= t - randomGraph = randomGraph.newEdge(DirEdge(), (s,t)) + randomGraph = randomGraph.newEdge(DirEdge(), (s, t)) } val varray = randomGraph.vdata.keys.toArray - for (i <- 1 to nbboxes) { - val randomVSet = (1 to sqrt(nverts).toInt).foldLeft(Set[VName]()) { (s,_) => + for (_ <- 1 to nbboxes) { + val randomVSet = (1 to sqrt(nverts).toInt).foldLeft(Set[VName]()) { (s, _) => s + varray(rand.nextInt(varray.length)) } @@ -1074,7 +1256,7 @@ object Graph { def randomDag(nverts: Int, nedges: Int): Graph = { val rand = new util.Random var randomGraph = Graph() - for (i <- 1 to nverts) { + for (_ <- 1 to nverts) { val p = (rand.nextDouble * 6.0 - 3.0, rand.nextDouble * 6.0 - 3.0) randomGraph = randomGraph.newVertex(NodeV(p)) } @@ -1082,17 +1264,39 @@ object Graph { // must have at least two verts to add edges since no self-loops allowed if (nverts > 1) - for(j <- 1 to nedges) { + for (_ <- 1 to nedges) { val x = rand.nextInt(varray.length) val y = rand.nextInt(varray.length - 1) val s = varray(x) - val t = varray(if (y >= x) y+1 else y) - randomGraph = randomGraph.newEdge(DirEdge(), if (s <= t) (s,t) else (t,s)) + val t = varray(if (y >= x) y + 1 else y) + randomGraph = randomGraph.newEdge(DirEdge(), if (s <= t) (s, t) else (t, s)) } randomGraph } + def variablesUsed(theory: Theory, graph: Graph): Set[String] = { + graph.vdata.foldLeft(Set[String]()) { (names, vnd) => { + val nodeDataType = theory.vertexTypes(vnd._2.typ).value.typ + val phases = CompositeExpression.parseKnowingTypes(vnd._2.data.toString(), nodeDataType) + val compositeExpression = CompositeExpression(nodeDataType, phases) + names union compositeExpression.vars + } + } + } + + def variablesUsedWithType(theory: Theory, graph: Graph): Set[(ValueType, String)] = { + graph.vdata.foldLeft(Set[(ValueType, String)]()) { (names, vnd) => + vnd._2 match { + case node: NodeV => + val nodeDataType = theory.vertexTypes(node.typ).value.typ + val phases = CompositeExpression.parseKnowingTypes(node.data ? "value", nodeDataType) + val compositeExpression = CompositeExpression(nodeDataType, phases) + names union compositeExpression.varsWithType + case _ => names + } + } + } def fromAdjMat(amat: AdjMat, rdata: Vector[NodeV], gdata: Vector[NodeV]): Graph = { val thy = @@ -1100,18 +1304,31 @@ object Graph { else if (rdata.nonEmpty) rdata(0).theory else throw new GraphException("Must give at least one piece of node data") + var g = Graph(thy) - for (i <- 0 until amat.numBoundaries) g = g.addVertex(VName("v"+i), WireV(theory = thy)) + for (i <- 0 until amat.numBoundaries) g = g.addVertex(VName("v" + i), WireV(theory = thy).withCoord(i,0)) var i = amat.numBoundaries - for (t <- 0 until amat.numRedTypes; v <- 0 until amat.red(t)) { - g = g.addVertex(VName("v" + i), rdata(t)) + def red(i: Int): Int = if (amat.red.size > i) { + amat.red(i) + } else { + 0 + } + + def green(i: Int): Int = if (amat.green.size > i) { + amat.green(i) + } else { + 0 + } + + for (t <- 0 until amat.numRedTypes; _ <- 0 until red(t)) { + g = g.addVertex(VName("v" + i), rdata(t).withCoord(i, math.sin(i + t))) i += 1 } - for (t <- 0 until amat.numGreenTypes; v <- 0 until amat.green(t)) { - g = g.addVertex(VName("v" + i), gdata(t)) + for (t <- 0 until amat.numGreenTypes; _ <- 0 until green(t)) { + g = g.addVertex(VName("v" + i), gdata(t).withCoord(i, math.sin(i + t))) i += 1 } @@ -1125,4 +1342,6 @@ object Graph { g } + + def apply(theory: Theory): Graph = Graph(data = GData(theory = theory)) } diff --git a/scala/src/main/scala/quanto/data/GraphElementData.scala b/scala/src/main/scala/quanto/data/GraphElementData.scala index 36f36014..e8212ef4 100644 --- a/scala/src/main/scala/quanto/data/GraphElementData.scala +++ b/scala/src/main/scala/quanto/data/GraphElementData.scala @@ -3,15 +3,19 @@ package quanto.data import quanto.util.json.JsonObject /** - * An abstract class providing a general interface for accessing - * information contained in its different components - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/GraphElementData.scala Source code]] - * @see Known Subclasses below - * @author Aleks Kissinger - */ + * An abstract class providing a general interface for accessing + * information contained in its different components + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/GraphElementData.scala Source code]] + * @see Known Subclasses below + * @author Aleks Kissinger + */ abstract class GraphElementData { def theory: Theory + def data: JsonObject + def annotation: JsonObject + def toJson: JsonObject = JsonObject("data" -> data, "annotation" -> annotation).noEmpty } diff --git a/scala/src/main/scala/quanto/data/GraphMap.scala b/scala/src/main/scala/quanto/data/GraphMap.scala index 2eb5b148..3d3784d6 100644 --- a/scala/src/main/scala/quanto/data/GraphMap.scala +++ b/scala/src/main/scala/quanto/data/GraphMap.scala @@ -32,14 +32,16 @@ case class GraphMap( bb.codf.forall(_._2.size == 1) } - def addVertex(p: (VName,VName)): GraphMap = copy(v = v + p) - def addEdge(p: (EName,EName)): GraphMap = copy(e = e + p) - def addBBox(p: (BBName,BBName)): GraphMap = copy(bb = bb + p) + def addVertex(p: (VName, VName)): GraphMap = copy(v = v + p) + + def addEdge(p: (EName, EName)): GraphMap = copy(e = e + p) + + def addBBox(p: (BBName, BBName)): GraphMap = copy(bb = bb + p) def isTotal(g: Graph): Boolean = - v.domSet == g.verts && - e.domSet == g.edges && - bb.domSet == g.bboxes + v.domSet == g.verts && + e.domSet == g.edges && + bb.domSet == g.bboxes def image(g: Graph): Graph = g.rename(v.toMap, e.toMap, bb.toMap) } diff --git a/scala/src/main/scala/quanto/data/HasGraph.scala b/scala/src/main/scala/quanto/data/HasGraph.scala index ebefa047..02ce4479 100644 --- a/scala/src/main/scala/quanto/data/HasGraph.scala +++ b/scala/src/main/scala/quanto/data/HasGraph.scala @@ -4,18 +4,21 @@ import scala.swing.Publisher import scala.swing.event.Event abstract class GraphEvent extends Event + case class GraphChanged(sender: HasGraph) extends GraphEvent // will cause any graph views to invalidate and repaint the graph case class GraphReplaced(sender: HasGraph, clearSelection: Boolean) extends GraphEvent trait HasGraph extends Publisher { - protected def gr: Graph - protected def gr_=(g: Graph) + def graph: Graph = gr - def graph = gr def graph_=(g: Graph) { gr = g publish(GraphChanged(this)) } + + protected def gr: Graph + + protected def gr_=(g: Graph) } diff --git a/scala/src/main/scala/quanto/data/Names.scala b/scala/src/main/scala/quanto/data/Names.scala index f1c4ac6f..8d697563 100644 --- a/scala/src/main/scala/quanto/data/Names.scala +++ b/scala/src/main/scala/quanto/data/Names.scala @@ -3,15 +3,10 @@ package quanto.data import quanto.util.json.JsonString import scala.collection._ -import quanto.util.StringNamer - - trait Name[This <: Name[This]] extends Ordered[This] { val s: String - protected val mk: String => This - val (prefix, suffix) = { var intIndex = s.length while (intIndex > 0 && s.charAt(intIndex - 1) >= '0' && s.charAt(intIndex - 1) <= '9') intIndex -= 1 @@ -24,41 +19,64 @@ trait Name[This <: Name[This]] extends Ordered[This] { if (intIndex == s.length) -1 else s.substring(intIndex, s.length).toInt ) } + protected val mk: String => This - def compare(that: This) = if (prefix < that.prefix) -1 - else if (prefix > that.prefix) 1 - else suffix compare that.suffix + def compare(that: This): Int = if (prefix < that.prefix) -1 + else if (prefix > that.prefix) 1 + else suffix compare that.suffix def succ: This = mk(prefix + (suffix + 1)) - override def toString = s + override def toString: String = s } -case class GName(s: String) extends Name[GName] { protected val mk = GName(_) } -case class VName(s: String) extends Name[VName] { protected val mk = VName(_) } -case class EName(s: String) extends Name[EName] { protected val mk = EName(_) } -case class BBName(s: String) extends Name[BBName] { protected val mk = BBName(_) } -case class DSName(s: String) extends Name[DSName] { protected val mk = DSName(_) } +case class GName(s: String) extends Name[GName] { + protected val mk = GName(_) +} + +case class VName(s: String) extends Name[VName] { + protected val mk = VName(_) +} + +case class EName(s: String) extends Name[EName] { + protected val mk = EName(_) +} + +case class BBName(s: String) extends Name[BBName] { + protected val mk = BBName(_) +} + +case class DSName(s: String) extends Name[DSName] { + protected val mk = DSName(_) +} class DuplicateNameException[N <: Name[N]](ty: String, val name: N) extends Exception("Duplicate " + ty + " name: '" + name + "'") + class DuplicateVertexNameException(override val name: VName) extends DuplicateNameException("vertex", name) + class DuplicateEdgeNameException(override val name: EName) extends DuplicateNameException("edge", name) + class DuplicateBBoxNameException(override val name: BBName) extends DuplicateNameException("bang box", name) object Names { + class NameSet[N <: Name[N]](val set: Set[N]) { - def fresh(implicit default: N) : N = if (set.isEmpty) default else set.max.succ - def freshWithSuggestion(s : N) : N = { var t = s; while (set.contains(t)) t = t.succ; t } + def fresh(implicit default: N): N = if (set.isEmpty) default else set.max.succ + + def freshWithSuggestion(s: N): N = { + var t = s; while (set.contains(t)) t = t.succ; t + } } - class NameMap[N <: Name[N], T](val map: Map[N,T]) { - def fresh(implicit default: N) : N = if (map.isEmpty) default else map.keys.max.succ - def freshWithSuggestion(s : N) : N = { + class NameMap[N <: Name[N], T](val map: Map[N, T]) { + def fresh(implicit default: N): N = if (map.isEmpty) default else map.keys.max.succ + + def freshWithSuggestion(s: N): N = { val set = map.keySet var t = s while (set.contains(t)) t = t.succ @@ -66,48 +84,60 @@ object Names { } } -// class NamePFun[N <: Name[N], T](val pf: PFun[N,T]) { -// def fresh(implicit default: N) : N = if (pf.isEmpty) default else pf.dom.max.succ -// def freshWithSuggestion(s : N) : N = { -// val set = pf.domSet -// var t = s -// while (set.contains(t)) t = t.succ -// t -// } -// } + // class NamePFun[N <: Name[N], T](val pf: PFun[N,T]) { + // def fresh(implicit default: N) : N = if (pf.isEmpty) default else pf.dom.max.succ + // def freshWithSuggestion(s : N) : N = { + // val set = pf.domSet + // var t = s + // while (set.contains(t)) t = t.succ + // t + // } + // } // TODO: overkill with implicits? - implicit def setToNameSet[N <: Name[N]](set : Set[N]):NameSet[N] = + implicit def setToNameSet[N <: Name[N]](set: Set[N]): NameSet[N] = new NameSet(set) - implicit def mapToNameMap[N <: Name[N], T](map : Map[N,T]):NameMap[N,T] = + + implicit def mapToNameMap[N <: Name[N], T](map: Map[N, T]): NameMap[N, T] = new NameMap(map) // these support general-purpose string-for-name substitution - implicit def stringToGName(s: String): GName = GName(s) - implicit def stringToVName(s: String): VName = VName(s) - implicit def stringToEName(s: String): EName = EName(s) + implicit def stringToGName(s: String): GName = GName(s) + + implicit def stringToVName(s: String): VName = VName(s) + + implicit def stringToEName(s: String): EName = EName(s) + implicit def stringToBBName(s: String): BBName = BBName(s) + implicit def stringToDSName(s: String): DSName = DSName(s) implicit def stringSetToGNameSet(set: Set[String]): Set[GName] = set map GName.apply + implicit def stringSetToVNameSet(set: Set[String]): Set[VName] = set map VName.apply + implicit def stringSetToENameSet(set: Set[String]): Set[EName] = set map EName.apply + implicit def stringSetToBBNameSet(set: Set[String]): Set[BBName] = set map BBName.apply + implicit def stringSetToDSNameSet(set: Set[String]): Set[DSName] = set map DSName.apply // edge creation methods take a pair of vertices - implicit def stringPairToVNamePair(t: (String,String)): (VName, VName) = (VName(t._1), VName(t._2)) + implicit def stringPairToVNamePair(t: (String, String)): (VName, VName) = (VName(t._1), VName(t._2)) // these can be used to save names into JSON without conversion implicit def gNameToJsonString(n: GName): JsonString = quanto.util.json.JsonString(n.toString) + implicit def vNameToJsonString(n: VName): JsonString = quanto.util.json.JsonString(n.toString) + implicit def eNameToJsonString(n: EName): JsonString = quanto.util.json.JsonString(n.toString) + implicit def bbNameToJsonString(n: BBName): JsonString = quanto.util.json.JsonString(n.toString) - implicit val defaultVName = VName("v0") - implicit val defaultEName = EName("e0") - implicit val defaultGName = GName("g0") - implicit val defaultBBName = BBName("bx0") - implicit val defaultDSName = DSName("0") + implicit val defaultVName: VName = VName("v0") + implicit val defaultEName: EName = EName("e0") + implicit val defaultGName: GName = GName("g0") + implicit val defaultBBName: BBName = BBName("bx0") + implicit val defaultDSName: DSName = DSName("0") } diff --git a/scala/src/main/scala/quanto/data/PFun.scala b/scala/src/main/scala/quanto/data/PFun.scala index 975660ee..4cc2bfb2 100644 --- a/scala/src/main/scala/quanto/data/PFun.scala +++ b/scala/src/main/scala/quanto/data/PFun.scala @@ -1,51 +1,37 @@ package quanto.data -import collection.immutable.TreeSet -import scala.collection.{TraversableLike, GenTraversableOnce, mutable, IterableLike} -import scala.collection.generic.CanBuildFrom -import Names._ +import scala.collection.immutable.TreeSet +import scala.collection.{GenTraversableOnce, IterableLike, mutable} /** - * Basically a map, but with cached inverse images - * - * @tparam A type of domain elements - * @tparam B type of codomain elements - * - * @constructor Create a new instance by specifying the partial function and - * the inverse image function - * @param f The partial function - * @param finv The inverse image function '''f ^-1^ ''' - * @param keyOrd Order on the domain elements (keys) - * - * @author Aleks Kissinger - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/PFun.scala Source code]] - */ + * Basically a map, but with cached inverse images + * + * @tparam A type of domain elements + * @tparam B type of codomain elements + * @constructor Create a new instance by specifying the partial function and + * the inverse image function + * @param f The partial function + * @param finv The inverse image function '''f ^-1^ ''' + * @param keyOrd Order on the domain elements (keys) + * @author Aleks Kissinger + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/PFun.scala Source code]] + */ class PFun[A, B] -(f : Map[A,B], finv: Map[B,TreeSet[A]]) +(f: Map[A, B], finv: Map[B, TreeSet[A]]) (implicit keyOrd: Ordering[A]) - extends BinRel[A,B] with IterableLike[(A,B), PFun[A,B]] - with GenTraversableOnce[(A,B)] { + extends BinRel[A, B] with IterableLike[(A, B), PFun[A, B]] + with GenTraversableOnce[(A, B)] { - def domf = f.mapValues(Set(_)) + def domf: Map[A, Set[B]] = f.mapValues(Set(_)) - def codf = finv.withDefaultValue(Set[A]()) + def codf: Map[B, Set[A]] = finv.withDefaultValue(Set[A]()) - def +(kv: (A,B)) : PFun[A,B] = { - val finv1 = - (f.get(kv._1) match { - case Some(oldV) => if (finv(oldV).size == 1) finv - oldV - else finv + (oldV -> (finv(oldV) - kv._1)) - case None => finv - }) + (kv._2 -> (finv.getOrElse(kv._2, TreeSet[A]()) + kv._1)) - new PFun(f + kv,finv1) - } + def contains(kv: (A, B)): Boolean = f.get(kv._1).contains(kv._2) - def contains(kv: (A,B)) = f.get(kv._1).contains(kv._2) - - def unmap(kv: (A, B)) = f.get(kv._1) match { + def unmap(kv: (A, B)): PFun[A, B] = f.get(kv._1) match { case Some(v) if v == kv._2 => val domSet = finv(v) - new PFun[A,B]( + new PFun[A, B]( f - kv._1, if (domSet.size == 1) finv - kv._2 else finv + (kv._2 -> (domSet - kv._1)) ) @@ -53,85 +39,106 @@ class PFun[A, B] case None => this } - def unmapDom(k: A) :PFun[A,B] = { - f.get(k) match { - case None => this // do nothing - case Some(v) => - val domSet = finv(v) - new PFun[A,B]( - f - k, - if (domSet.size == 1) finv - v else finv + (v -> (domSet - k)) - ) - } - } - - def unmapCod(v: B) :PFun[A,B] = { + def unmapCod(v: B): PFun[A, B] = { finv.get(v) match { case None => this // do nothing case Some(domSet) => - new PFun[A,B]( - domSet.foldLeft(f) { _ - _ }, + new PFun[A, B]( + domSet.foldLeft(f) { + _ - _ + }, finv - v ) } } - def restrictDom(s: Set[A]): PFun[A,B] = { - s.foldRight(PFun[A,B]()) { (k,mp) => get(k) match { case Some(v) => mp + (k -> v); case None => mp } } + def restrictDom(s: Set[A]): PFun[A, B] = { + s.foldRight(PFun[A, B]()) { (k, mp) => get(k) match { + case Some(v) => mp + (k -> v); + case None => mp + } + } } + /** + * Similar to '''apply''', but returns an option instead + * + * @param k Domain element where function should be evaluated + * @return Optionally returns the function value at '''k''' + */ + def get(k: A): Option[B] = f.get(k) + /** Same as '''unmapDom''' */ - def -(k: A) = unmapDom(k) + def -(k: A): PFun[A, B] = unmapDom(k) - /** Creates an iterator (same as f.iterator) */ - def iterator = f.iterator + def unmapDom(k: A): PFun[A, B] = { + f.get(k) match { + case None => this // do nothing + case Some(v) => + val domSet = finv(v) + new PFun[A, B]( + f - k, + if (domSet.size == 1) finv - v else finv + (v -> (domSet - k)) + ) + } + } /** Returns '''f''' */ - def toMap = f - - /** - * Specifies the behaviour for elements of the domain where the - * function is not defined. Children have the option to override - * this to give a default value. - * - * @param key Domain element - * @return Alway throws an exception - * @throws NoSuchElementException Exception indicates which key is not found - */ - def default(key: A): B = - throw new NoSuchElementException("key not found: " + key) - - /** - * Similar to '''apply''', but returns an option instead - * - * @param k Domain element where function should be evaluated - * @return Optionally returns the function value at '''k''' - */ - def get(k: A) = f.get(k) + def toMap: Map[A, B] = f /** - * Get the value of the function at the specified domain element - * '''( F(k) )''' - * - * @param k Domain element where function should be evaluated - * @return Function value at '''k''' if it is defined there, otherwise - * '''default(k)''' - */ - def apply(k: A) = f.get(k) match { + * Get the value of the function at the specified domain element + * '''( F(k) )''' + * + * @param k Domain element where function should be evaluated + * @return Function value at '''k''' if it is defined there, otherwise + * '''default(k)''' + */ + def apply(k: A): B = f.get(k) match { case Some(v) => v case None => default(k) } - protected[this] def newBuilder = new mutable.Builder[(A,B),PFun[A,B]] { - val s = collection.mutable.Buffer[(A,B)]() - def result() = PFun(s: _*) - def clear() = s.clear() - def +=(elem: (A,B)) = { s += elem; this } + /** + * Specifies the behaviour for elements of the domain where the + * function is not defined. Children have the option to override + * this to give a default value. + * + * @param key Domain element + * @return Alway throws an exception + * @throws NoSuchElementException Exception indicates which key is not found + */ + def default(key: A): B = + throw new NoSuchElementException("key not found: " + key) + + def seq: Seq[(A, B)] = iterator.toSeq + + /** Creates an iterator (same as f.iterator) */ + def iterator: Iterator[(A, B)] = f.iterator + + def ++(pf: PFun[A, B]): PFun[A, B] = pf.foldLeft(this) { case (mp, kv) => mp + kv } + + def +(kv: (A, B)): PFun[A, B] = { + val finv1 = + (f.get(kv._1) match { + case Some(oldV) => if (finv(oldV).size == 1) finv - oldV + else finv + (oldV -> (finv(oldV) - kv._1)) + case None => finv + }) + (kv._2 -> (finv.getOrElse(kv._2, TreeSet[A]()) + kv._1)) + new PFun(f + kv, finv1) } - def seq = iterator.toSeq + protected[this] def newBuilder : mutable.Builder[(A, B), PFun[A, B]] = new mutable.Builder[(A, B), PFun[A, B]] { + val s: mutable.Buffer[(A, B)] = collection.mutable.Buffer[(A, B)]() + + def result() = PFun(s: _*) - def ++(pf:PFun[A,B]) = pf.foldLeft(this) { case (mp, kv) => mp + kv } + def clear(): Unit = s.clear() + + def +=(elem: (A, B)): this.type = { + s += elem; this + } + } // PFun inherits equality from its member "f" // override def hashCode = f.hashCode() @@ -148,17 +155,18 @@ class PFun[A, B] } /** - * Companion object for the PFun class - * - * @author Aleks Kissinger - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/PFun.scala Source code]] - */ + * Companion object for the PFun class + * + * @author Aleks Kissinger + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/PFun.scala Source code]] + */ object PFun { /** Create an instance of PFun from a sequence of pairs */ - def apply[A, B](kvs: (A,B)*)(implicit keyOrd: Ordering[A]) : PFun[A,B] = { - kvs.foldLeft(new PFun[A,B](Map(),Map())){ (pf: PFun[A,B], kv: (A,B)) => pf + kv } + def apply[A, B](kvs: (A, B)*)(implicit keyOrd: Ordering[A]): PFun[A, B] = { + kvs.foldLeft(new PFun[A, B](Map(), Map())) { (pf: PFun[A, B], kv: (A, B)) => pf + kv } } - implicit def mapToPFun[A, B](mp: Map[A,B])(implicit keyOrd: Ordering[A]): PFun[A,B] = PFun(mp.toSeq:_*) - implicit def pFunToMap[A, B](f: PFun[A,B]): Map[A,B] = f.toMap + implicit def mapToPFun[A, B](mp: Map[A, B])(implicit keyOrd: Ordering[A]): PFun[A, B] = PFun(mp.toSeq: _*) + + implicit def pFunToMap[A, B](f: PFun[A, B]): Map[A, B] = f.toMap } diff --git a/scala/src/main/scala/quanto/data/PhaseExpression.scala b/scala/src/main/scala/quanto/data/PhaseExpression.scala new file mode 100644 index 00000000..c8f052cc --- /dev/null +++ b/scala/src/main/scala/quanto/data/PhaseExpression.scala @@ -0,0 +1,473 @@ +package quanto.data + +import quanto.data.Theory.ValueType +import quanto.util.{Rational, UserAlerts} + +import scala.util.matching.Regex +import scala.util.parsing.combinator.RegexParsers + +// Phases are just groups with added symbolic variables, +// Except we also need to be able to do Gaussian elimination on them +// So need to be able to interpret integers inside the phase group +// (i.e. needs to know what '1' is) +// We use the language of Abelian groups +// For things like Strings you just need to fudge it to allow people to put notes into nodes + + +case class PhaseParseException(msg: String, valueType: ValueType) + extends Exception(s"Attempting parse $ValueType threw message $msg") + +case class PhaseEvaluationException(msg: String) extends Exception(msg) + + +abstract class RationalWithSymbols[Group <: RationalWithSymbols[Group]](val coefficients: Map[String, Rational], val constant: Rational) { + + val description: ValueType + // Any symbolic variables should be in a set, + // However since we need these things to give consistent string outputs + // we sort the list here, and convert back into a set if needed + // to avoid "x + y" != "y + x" + val vars: List[String] = coefficients.keySet.toList.sorted + + // The companion object should hold the zero element, but link to it from the class + def zero: Group + + // Needs a 1, also held on the companion object + def one: Group + + // Addition + def +(g: Group): Group + + // Subtraction + def -(g: Group): Group + + // Multiplication - only needs to be able to do multiplication by possible variable coefficients + def *(r: Rational): Group + + // Substitution of single variable + def subst(s: String, e: Group): Group + + // Substitution of map of variables + def subst(mp: Map[String, Group]): Group + + def evaluate(mp: Map[String, Double]): Double = { + try { + constant + coefficients.foldLeft(0.0) { (a, b) => a + (mp(b._1) * Rational.rationalToDouble(b._2)) } + } catch { + case _: Exception => PhaseEvaluationException("Given arguments do not match those in the coefficient list") + 0 + } + } + +} + +class PhaseExpression(const: Rational, + coeff: Map[String, Rational], + val modulus: Option[Int], + val finiteField: Boolean, + override val description: ValueType) + extends RationalWithSymbols[PhaseExpression](coeff, const) { + + // Apply the modulus at creation + override val constant: Rational = mod(const) + // filter out any variables with zero as their coefficient + // If restricting to integers then we also apply the modulo to the coefficients + // e.g. 2*alpha = false in Bool (yes, that's a horrible mix, but it's the sort of thing we expect to see parsed) + override val coefficients: Map[String, Rational] = coeff.map( + sr => sr._1 -> (if (finiteField) mod(sr._2) else sr._2) + ).filterNot(sr => sr._2.isZero) + + def mod(r: Rational): Rational = if (modulus.nonEmpty) { + val reduced = Rational(r.n % (modulus.get * r.d), r.d) + // Recall that -1 % 2 is still -1, for reasons. + if (reduced < 0) { + reduced + modulus.get + } else { + reduced + } + } else { + r + } + + override def equals(that: Any): Boolean = that match { + case e: PhaseExpression => + description == e.description && constant == e.constant && coefficients == e.coefficients + case _ => false + } + + def zero: PhaseExpression = PhaseExpression.zero(description) + + def one: PhaseExpression = PhaseExpression.one(description) + + def subst(mp: Map[String, PhaseExpression]): PhaseExpression = + mp.foldLeft(this) { case (e, (v, e1)) => e.subst(v, e1) } + + def subst(v: String, e: PhaseExpression): PhaseExpression = { + val c = coefficients.getOrElse(v, Rational(0)) + this - PhaseExpression(Rational(0), Map(v -> c), description) + (e * c) + } + + def -(e: PhaseExpression): PhaseExpression = this + (e * -1) + + def *(i: Int): PhaseExpression = this * Rational(i) + + def +(e: PhaseExpression) = PhaseExpression(mod(constant + e.constant), + e.coefficients.foldLeft(coefficients) { + case (m, (k, v)) => m + (k -> (v + m.getOrElse(k, Rational(0)))) + }, description) + + def *(r: Rational): PhaseExpression = + PhaseExpression(mod(constant * r), coefficients.mapValues(x => x * r), description) + + override def toString: String = PhaseExpression.toString(description, this) + + override def evaluate(mp: Map[String, Double]): Double = mod(super.evaluate(mp)) + + def mod(d: Double): Double = if (modulus.nonEmpty) { + d % modulus.get + } else { + d + } + + def as(valueType: ValueType): PhaseExpression = PhaseExpression(constant, coefficients, valueType) + + def convertTo(valueType: ValueType): PhaseExpression = PhaseExpression(this.constant, this.coefficients, valueType) +} + +case class FieldData(modulus: Option[Int], finiteField: Boolean) + + +object PhaseExpression { + + def apply(r: Rational, valueType: ValueType): PhaseExpression = { + PhaseExpression(r, Map(), valueType) + } + + def parse(s: String, valueType: ValueType): PhaseExpression = { + try { + valueType match { + case ValueType.AngleExpr => AngleExpressionParser.p(s) + case ValueType.Boolean => BooleanExpressionParser.p(s) + case ValueType.Rational => RationalExpressionParser.p(s) + case ValueType.Empty => PhaseExpression(0, Map(), ValueType.Empty) + case ValueType.String => StringExpressionParser.p(s) + case v => throw PhaseParseException(s"Asked to parse unexpected '$s' for type $v", v) + } + }catch{ + case PhaseParseException(m, v) => + UserAlerts.alert(s"Could not parse $s for type $v, regex error: $m", UserAlerts.Elevation.ERROR) + PhaseExpression.zero(v) + } + } + + def one(valueType: ValueType): PhaseExpression = PhaseExpression(1, Map(), valueType) + + def apply(r: Rational, m: Map[String, Rational], valueType: ValueType): PhaseExpression = { + val fData = fieldData(valueType) + new PhaseExpression(r, m, fData.modulus, fData.finiteField, valueType) + } + + def fieldData(valueType: ValueType): FieldData = valueType match { + case ValueType.AngleExpr => FieldData(Some(2), finiteField = false) + case ValueType.Boolean => FieldData(Some(2), finiteField = true) + case _ => FieldData(None, finiteField = false) + } + + def toString(valueType: ValueType, phaseExpression: PhaseExpression): String = { + valueType match { + case ValueType.AngleExpr => + writeAsAngle(phaseExpression) + case ValueType.Boolean => + writeAsBoolean(phaseExpression) + case ValueType.String => + writeAsString(phaseExpression) + case ValueType.Empty => + "" + case ValueType.Rational => + writeAsRational(phaseExpression) + } + } + + private def writeAsRational(phaseExpression: PhaseExpression): String = { + val constant = phaseExpression.constant + val coefficients = phaseExpression.coefficients + var fst = true + var s = "" + if (!constant.isZero) { + + fst = false + val (n, sgn) = + if (constant.n > constant.d && constant.n < 2 * constant.d) (2 * constant.d - constant.n, "-") + else (constant.n, "") + if (n == 1) s += sgn + "1/" + constant.d + else s += sgn + n.toString + "/" + constant.d + + } + + def rStr(c: Rational) : String = writeAsRational(PhaseExpression(c, Map(), ValueType.Rational)) + + coefficients.keys.toList.sorted.foreach { variableName => + val c = coefficients(variableName) + if (!c.isZero) { + if (fst) { + fst = false + s = s + (if (c == Rational(1)) "" else rStr(c) + " ") + variableName + } else { + if (c < Rational(0)) { + s = s + " - " + (if (c == Rational(-1)) "" else rStr(c * -1) + " ") + variableName + } else { + s = s + " + " + (if (c == Rational(1)) "" else rStr(c) + " ") + variableName + } + } + } + } + if (phaseExpression == PhaseExpression.zero(ValueType.Rational)) s = "0" + + s + } + + private def writeAsString(phaseExpression: PhaseExpression): String = { + val constant = phaseExpression.constant + val vars = phaseExpression.vars + if (vars.nonEmpty) { + vars.mkString("", ", ", "") + } else { + "" + } + } + + private def writeAsBoolean(phaseExpression: PhaseExpression): String = { + val constant = phaseExpression.constant + val vars = phaseExpression.vars + val t = "\\True" + val f = "\\False" + if (vars.nonEmpty) { + vars.mkString(if (constant == Rational(1)) { + s"$t + " + } else { + "" + }, " + ", "") + } else { + if (constant == Rational(1)) t else f + } + } + + private def writeAsAngle(phaseExpression: PhaseExpression): String = { + val constant = phaseExpression.constant + val coefficients = phaseExpression.coefficients + var fst = true + var s = "" + if (!constant.isZero) { + + fst = false + if (constant.isOne) s += "\\pi" + else { + val (n, sgn) = + if (constant.n > constant.d && constant.n < 2 * constant.d) (2 * constant.d - constant.n, "-") + else (constant.n, "") + if (n == 1) s += sgn + "\\pi/" + constant.d + else s += sgn + n.toString + "\\pi/" + constant.d + } + } + + coefficients.keys.toList.sorted.foreach { variableName => + val c = coefficients(variableName) + if (!c.isZero) { + if (fst) { + fst = false + s = s + (if (c == Rational(1)) "" else c.toString + " ") + variableName + } else { + if (c < Rational(0)) { + s = s + " - " + (if (c == Rational(-1)) "" else (c * -1).toString + " ") + variableName + } else { + s = s + " + " + (if (c == Rational(1)) "" else c.toString + " ") + variableName + } + } + } + } + if (phaseExpression == PhaseExpression.zero(ValueType.AngleExpr)) s = "0" + + s + } + + def zero(valueType: ValueType): PhaseExpression = PhaseExpression(0, Map(), valueType) + + private object AngleExpressionParser extends CommonParser(ValueType.AngleExpr) { + + def angleExpression(r: Rational): PhaseExpression = PhaseExpression(r, ValueType.AngleExpr) + + def angleExpression(r: Rational, m: Map[String, Rational]): PhaseExpression = PhaseExpression(r, m, ValueType.AngleExpr) + + def INT_OPT: Parser[Int] = INT.? ^^ { + _.getOrElse(1) + } + + def PI: Parser[Unit] = """\\?[pP][iI]""".r ^^ { _ => Unit } + + + def coeff: Parser[Rational] = + INT ~ "/" ~ INT ^^ { case n ~ _ ~ d => Rational(n, d) } | + "(" ~ coeff ~ ")" ^^ { case _ ~ c ~ _ => c } | + INT ^^ { n => Rational(n) } + + + def frac: Parser[PhaseExpression] = + INT_OPT ~ "*".? ~ PI ~ "/" ~ INT ^^ { case n ~ _ ~ _ ~ _ ~ d => angleExpression(Rational(n, d)) } | + INT_OPT ~ "*".? ~ SYMBOL ~ "/" ~ INT ^^ { + case n ~ _ ~ x ~ _ ~ d => angleExpression(Rational(0), Map(x -> Rational(n, d))) + } + + def term: Parser[PhaseExpression] = + frac | + "-" ~ term ^^ { case _ ~ t => t * -1 } | + coeff ~ "*".? ~ PI ^^ { case c ~ _ ~ _ => angleExpression(c) } | + PI ^^ { _ => one } | + coeff ~ "*".? ~ SYMBOL ^^ { case c ~ _ ~ x => angleExpression(Rational(0), Map(x -> c)) } | + SYMBOL ~ "*" ~ coeff ^^ { case x ~ _ ~ c => angleExpression(Rational(0), Map(x -> c)) } | + SYMBOL ^^ { x => angleExpression(Rational(0), Map(x -> Rational(1))) } | + coeff ^^ angleExpression | + "(" ~ expr ~ ")" ^^ { case _ ~ t ~ _ => t } + + def term1: Parser[PhaseExpression] = + "+" ~ term ^^ { case _ ~ t => t } | + "-" ~ term ^^ { case _ ~ t => t * -1 } + + def terms: Parser[PhaseExpression] = + term1 ~ terms ^^ { case s ~ t => s + t } | + term1 + + def expr: Parser[PhaseExpression] = + term ~ terms ^^ { case s ~ t => s + t } | + term | + "" ^^ { _ => zero } + + } + + + private object RationalExpressionParser extends CommonParser(ValueType.AngleExpr) { + def rationalExpression(r: Rational): PhaseExpression = PhaseExpression(r, ValueType.Rational) + + def rationalExpression(r: Rational, m: Map[String, Rational]): PhaseExpression = + PhaseExpression(r, m, ValueType.Rational) + + def INT_OPT: Parser[Int] = INT.? ^^ { + _.getOrElse(1) + } + + def coeff: Parser[Rational] = + INT ~ "/" ~ INT ^^ { case n ~ _ ~ d => Rational(n, d) } | + "(" ~ coeff ~ ")" ^^ { case _ ~ c ~ _ => c } | + INT ^^ { n => Rational(n) } + + def frac: Parser[PhaseExpression] = + INT_OPT ~ "*".? ~ SYMBOL ~ "/" ~ INT ^^ { + case n ~ _ ~ x ~ _ ~ d => rationalExpression(Rational(0), Map(x -> Rational(n, d))) + } + + def term: Parser[PhaseExpression] = + frac | + "-" ~ term ^^ { case _ ~ t => t * -1 } | + coeff ~ "*".? ~ SYMBOL ^^ { case c ~ _ ~ x => rationalExpression(Rational(0), Map(x -> c)) } | + SYMBOL ^^ { x => rationalExpression(Rational(0), Map(x -> Rational(1))) } | + SYMBOL ~ "*" ~ coeff ^^ { case x ~ _ ~ c => rationalExpression(Rational(0), Map(x -> c)) } | + coeff ^^ rationalExpression | + "(" ~ expr ~ ")" ^^ { case _ ~ t ~ _ => t } + + def term1: Parser[PhaseExpression] = + "+" ~ term ^^ { case _ ~ t => t } | + "-" ~ term ^^ { case _ ~ t => t * -1 } + + def terms: Parser[PhaseExpression] = + term1 ~ terms ^^ { case s ~ t => s + t } | + term1 + + def expr: Parser[PhaseExpression] = + term ~ terms ^^ { case s ~ t => s + t } | + term | + "" ^^ { _ => zero } + + } + + + private object StringExpressionParser { + def zero: PhaseExpression = PhaseExpression.zero(ValueType.String) + + def p(s: String): PhaseExpression = s match { + case "" => zero + case t => PhaseExpression(0, Map(t -> Rational(1,1)), ValueType.String) + } + } + + private abstract class CommonParser(T: ValueType) extends RegexParsers{ + + val zero: PhaseExpression = PhaseExpression.zero(T) + val one: PhaseExpression = PhaseExpression.one(T) + + def INT: Parser[Int] = + """[0-9]+""".r ^^ { + _.toInt + } + + def SYMBOL: Parser[String] = + """[\\a-zA-Z_][a-zA-Z0-9_']*""".r ^^ { + _.toString + } + + final override def skipWhitespace = true + + def expr: Parser[PhaseExpression] + + def p(s: String): PhaseExpression = parseAll(expr, s) match { + case Success(e, _) => e + case Failure(msg, _) => throw PhaseParseException(msg, T) + case Error(msg, _) => throw PhaseParseException(msg, T) + } + } + + private object BooleanExpressionParser extends CommonParser(ValueType.Boolean) { + + def BooleanExpression(i: Int): PhaseExpression = PhaseExpression(Rational(i), ValueType.Boolean) + + def BooleanExpression(i: Int, m: Map[String, Rational]): PhaseExpression = + PhaseExpression(Rational(i), m, ValueType.Boolean) + + def TRUE: Parser[Int] = """\\?[Tt](rue|RUE)?""".r ^^ { _ => 1 } + + def FALSE: Parser[Int] = """\\?[Ff](alse|ALSE)?""".r ^^ { _ => 0 } + + def coeff: Parser[Int] = + TRUE | FALSE | INT + + def INT_OPT: Parser[Int] = INT.? ^^ { + _.getOrElse(0) + } + + + def term: Parser[PhaseExpression] = + "-" ~ term ^^ { case _ ~ t => t } | + TRUE ^^ { _ => one } | + FALSE ^^ { _ => zero } | + coeff ~ "*" ~ coeff ^^ { case c ~ _ ~ x => BooleanExpression(x * c) } | + coeff ~ "*".? ~ SYMBOL ^^ { case c ~ _ ~ x => BooleanExpression(0, Map(x -> 1)) * c } | + SYMBOL ^^ { x => BooleanExpression(0, Map(x -> 1)) } | + coeff ^^ { x => BooleanExpression(x) } | + "(" ~ expr ~ ")" ^^ { case _ ~ t ~ _ => t } + + def term1: Parser[PhaseExpression] = + "+" ~ term ^^ { case _ ~ t => t } | + "-" ~ term ^^ { case _ ~ t => t * -1 } + + def terms: Parser[PhaseExpression] = + term1 ~ terms ^^ { case s ~ t => s + t } | + term1 + + def expr: Parser[PhaseExpression] = + term ~ terms ^^ { case s ~ t => s + t } | + term | + terms | + "" ^^ { _ => zero } + + } + + +} diff --git a/scala/src/main/scala/quanto/data/Project.scala b/scala/src/main/scala/quanto/data/Project.scala index 28371797..fc294409 100644 --- a/scala/src/main/scala/quanto/data/Project.scala +++ b/scala/src/main/scala/quanto/data/Project.scala @@ -1,46 +1,78 @@ package quanto.data +import java.io.File + import quanto.rewrite.Simproc +import quanto.util.FileHelper import quanto.util.json._ -import java.io.File class ProjectLoadException(message: String, cause: Throwable) extends Exception(message, cause) -case class Project(theory: Theory, rootFolder: String, name : String) { - def rules: Vector[String] = rulesInPath(rootFolder) +case class Project(theory: Theory, projectFile: File, name: String) { + require(!projectFile.isDirectory) + + FileHelper.printJson(projectFile.getAbsolutePath, Project.toJson(this)) + + val rootFolder: String = projectFile.getParent + private val rootFolderFile = projectFile.getParentFile + var simprocs: Map[String, Simproc] = Map() + var lastRunPythonFilePath: Option[String] = None + + def relativePath(f: File): String = { + rootFolderFile.toURI.relativize(f.toURI).getPath + } - private def rulesInPath(p: String): Vector[String] = { - val f = new File(p) - if (f.isDirectory) f.listFiles().toVector.flatMap(f => rulesInPath(f.getPath)) - else if (f.getPath.endsWith(".qrule")) { - val fname = f.getPath - Vector(fname.substring(rootFolder.length + 1, fname.length - 6)) + //Scans the given folder for filenames that end in the given extension + def filesEndingIn(ext: String, path: String = rootFolderFile.getAbsolutePath): List[String] = { + val f = new File(path) + if (f.isDirectory) f.listFiles().toList.flatMap(f => filesEndingIn(ext, f.getPath)) + else if (f.getPath.endsWith(ext)) { + val fileName = relativePath(f) + List(fileName.substring(0, fileName.length - ext.length)) } - else Vector() + else List() } } object Project { - def fromTheoryOrProjectFile(theoryOrProjectFile: String, rootFolder: String = "", name : String = "") : Project = { - println(s"Asked to load prjoect from: $theoryOrProjectFile") + def fromJson(json: Json, projectFile: File): Project = try { + Project( + Theory.fromJson(json / "theory"), + projectFile, + (json / "name").stringValue + ) + } catch { + // First try loading as though old format + case _: Exception => + try { + println("Attempting to update project file from old format.") + Project(Theory.fromJson(json / "theory"), projectFile, projectFile.getName) + } + catch { + case e: Exception => throw new ProjectLoadException("Error loading project", e) + } + } + + def fromTheoryOrProjectFile(theoryOrProjectFile: File, rootFolder: File = new File("."), name: String = "main"): Project = { + println(s"Asked to load project from: $theoryOrProjectFile") val theory: Theory = Theory.fromJson({ - val extension: String = theoryOrProjectFile.replaceAll(".*\\.", "") + val extension: String = theoryOrProjectFile.getAbsolutePath.replaceAll(".*\\.", "") try { extension match { case "qtheory" => - Json.parse(new File(theoryOrProjectFile)) + Json.parse(theoryOrProjectFile) case "qproject" => - Json.parse(new File(theoryOrProjectFile)) / "theory" + Json.parse(theoryOrProjectFile) / "theory" case _ => - try{ + try { val theoryStream = Theory.getClass.getResourceAsStream(theoryOrProjectFile + ".qtheory") Json.parse(new Json.Input(theoryStream)) - } catch { - case e: Exception => - throw new ProjectLoadException(s"Could not parse the resource $theoryOrProjectFile", e) - } + } catch { + case e: Exception => + throw new ProjectLoadException(s"Could not parse the resource $theoryOrProjectFile", e) + } } } catch { case e: Exception => @@ -48,28 +80,11 @@ object Project { } } ) - new Project(theory, rootFolder, name) - } - - def fromJson(json: Json, rootFolder: String): Project = try { - Project( - Theory.fromJson(json / "theory"), - rootFolder, - (json / "name").stringValue - ) - } catch { - // First try loading as though old format - case e: Exception => - try { - println("Attempting to update project file from old format.") - Project.fromTheoryOrProjectFile((json / "theory").stringValue, rootFolder) - } - catch { - case e: Exception => throw new ProjectLoadException("Error loading project", e) - } + val suggestedFilename = rootFolder + "/" + name + ".qproject" + new Project(theory, new File(suggestedFilename), name) } - def toJson(project: Project) : Json = { + def toJson(project: Project): Json = { JsonObject( "name" -> project.name, "theory" -> Theory.toJson(project.theory) diff --git a/scala/src/main/scala/quanto/data/ResultSet.scala b/scala/src/main/scala/quanto/data/ResultSet.scala index 444ef48c..a5c00438 100644 --- a/scala/src/main/scala/quanto/data/ResultSet.scala +++ b/scala/src/main/scala/quanto/data/ResultSet.scala @@ -1,59 +1,64 @@ package quanto.data case class ResultLine(rule: RuleDesc, index: Int, total: Int) { - override def toString = { - rule.name + (if (rule.inverse) "[inverse]" else "") + " (" + index + "/" + total + ")" + override def toString: String = { + (if (rule.inverse) "<- " else "-> ") + rule.name + " (" + index + "/" + total + ")" } } -case class ResultSet(rules: Vector[RuleDesc], results: Map[RuleDesc,(Int,Vector[DStep])]) { - def copy(rules: Vector[RuleDesc] = rules, results: Map[RuleDesc,(Int,Vector[DStep])] = results) = - ResultSet(rules, results) - +case class ResultSet(rules: Vector[RuleDesc], results: Map[RuleDesc, (Int, Vector[DStep])]) { def currentResult(rule: RuleDesc): Option[DStep] = - results.get(rule).flatMap { case (i,vec) => if (i == 0) None else Some(vec(i-1)) } + results.get(rule).flatMap { case (i, vec) => if (i == 0) None else Some(vec(i - 1)) } - def resultIndex(rule: RuleDesc) = results(rule)._1 - def setResultIndex(rule: RuleDesc, i: Int) = - copy(results = results + (rule -> (results(rule) match { case (_, vec) => (i, vec) }))) - - def nextResult(rule: RuleDesc) = { + def nextResult(rule: RuleDesc): ResultSet = { val i = resultIndex(rule) + 1 if (i <= numResults(rule)) setResultIndex(rule, i) else this } - def previousResult(rule: RuleDesc) = { + def setResultIndex(rule: RuleDesc, i: Int): ResultSet = + copy(results = results + (rule -> (results(rule) match { + case (_, vec) => (i, vec) + }))) + + def previousResult(rule: RuleDesc): ResultSet = { val i = resultIndex(rule) - 1 if (i > 0) setResultIndex(rule, i) else this } - def replaceGraph(rule: RuleDesc, i: Int, graph: Graph) = { + def replaceGraph(rule: RuleDesc, i: Int, graph: Graph): ResultSet = { val x = results(rule) val res = x._2(i - 1).copy(graph = graph) copy(results = results + (rule -> (x._1, x._2.updated(i - 1, res)))) } - def graph(rule: RuleDesc, i: Int) = results(rule)._2(i - 1).graph + def graph(rule: RuleDesc, i: Int): Graph = results(rule)._2(i - 1).graph - def numResults(rule: RuleDesc) = results(rule)._2.size - def +(res: (RuleDesc,DStep)) = { + def +(res: (RuleDesc, DStep)): ResultSet = { val rs = results(res._1) copy(results = results + (res._1 -> (if (rs._1 == 0) 1 else rs._1, rs._2 :+ res._2))) } - def -(rule: RuleDesc) = { + def -(rule: RuleDesc): ResultSet = { copy(rules = rules.filter(_ != rule), results = results - rule) } - def resultLines = rules.map { r => ResultLine(r, resultIndex(r), numResults(r)) } + def copy(rules: Vector[RuleDesc] = rules, results: Map[RuleDesc, (Int, Vector[DStep])] = results) = + ResultSet(rules, results) + + def resultLines: Vector[ResultLine] = rules.map { r => ResultLine(r, resultIndex(r), numResults(r)) } + + def resultIndex(rule: RuleDesc): Int = results(rule)._1 + + def numResults(rule: RuleDesc): Int = results(rule)._2.size } object ResultSet { def apply(rules: Vector[RuleDesc]): ResultSet = { - val results = rules.foldLeft(Map[RuleDesc,(Int,Vector[DStep])]()) { - case (rs, rule) => rs + (rule -> (0, Vector())) } + val results = rules.foldLeft(Map[RuleDesc, (Int, Vector[DStep])]()) { + case (rs, rule) => rs + (rule -> (0, Vector())) + } ResultSet(rules, results) } } diff --git a/scala/src/main/scala/quanto/data/Rule.scala b/scala/src/main/scala/quanto/data/Rule.scala index 1e643918..3ad1fff9 100644 --- a/scala/src/main/scala/quanto/data/Rule.scala +++ b/scala/src/main/scala/quanto/data/Rule.scala @@ -11,19 +11,46 @@ case class RuleLoadException(message: String, cause: Throwable = null) case class Rule(private val _lhs: Graph, private val _rhs: Graph, derivation: Option[String] = None, - description: RuleDesc = RuleDesc("unnamed")) { + description: RuleDesc = RuleDesc()) { + + val lhs: Graph = if (description.inverse) _rhs else _lhs + val rhs: Graph = if (description.inverse) _lhs else _rhs + val name: String = description.name + (if (description.inverse) " inverted" else "") def inverse: Rule = { Rule(lhs, rhs, derivation, description.invert) } - val lhs: Graph = if (description.inverse) _rhs else _lhs + def hasBBoxes: Boolean = lhs.bboxes.nonEmpty || rhs.bboxes.nonEmpty - val rhs: Graph = if (description.inverse) _lhs else _rhs + def map(f: Graph => Graph): Rule = { + new Rule(f(lhs), f(rhs)) + } - val name: String = description.name + (if (description.inverse) " inverted" else "") + def colourSwap(changes: Map[String, String]): Rule = { + def safeChanges(s: String) : String = { + changes.get(s) match { + case Some(t) => t + case None => s + } + } + map(graph => { + graph.verts.foldLeft(graph) { (g, v) => + g.updateVData(v)(f = { + case n: NodeV => + n.copy(data = JsonObject( + "type" -> safeChanges((n.data / "type").stringValue), + "value" -> (n.data / "value") + )) + case m => + m + } + ) + } + }) + } - override def toString: String = name + " := "+ _lhs.toString + + override def toString: String = name + " := " + _lhs.toString + (if (description.inverse) { "<--" } else { @@ -36,6 +63,10 @@ case class RuleDesc(name: String = "unnamed", inverse: Boolean = false) { def invert: RuleDesc = RuleDesc(name, !inverse) } +object RuleDesc { + implicit def fromString(string: String) : RuleDesc = RuleDesc(string) +} + object Rule { def fromJson(json: Json, thy: Theory = Theory.DefaultTheory, description: Option[RuleDesc] = None): Rule = try { Rule(_lhs = Graph.fromJson(json / "lhs", thy), @@ -46,7 +77,7 @@ object Rule { }, description = if (description.isDefined) description.get else json.get("description") match { case Some(JsonString(s)) => RuleDesc(s); - case _ => RuleDesc("unnamed") + case _ => RuleDesc() }) } catch { case e: JsonAccessException => @@ -69,4 +100,11 @@ object Rule { case None => obj } } + + def namesUsed(rule: Rule, theory: Theory) : Set[String] = { + + val namesUsedInLHS = Graph.variablesUsed(theory, rule.lhs) + val namesUsedInRHS = Graph.variablesUsed(theory, rule.rhs) + namesUsedInLHS union namesUsedInRHS + } } \ No newline at end of file diff --git a/scala/src/main/scala/quanto/data/Theory.scala b/scala/src/main/scala/quanto/data/Theory.scala index 87744727..cd495ba7 100644 --- a/scala/src/main/scala/quanto/data/Theory.scala +++ b/scala/src/main/scala/quanto/data/Theory.scala @@ -4,96 +4,261 @@ import quanto.util.json._ import JsonValues._ import java.awt.{Color, Shape} + /** - * Exception thrown when theory cannot be created for some reason - * - * @author Aleks Kissinger - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/Theory.scala Source code]] - */ + * Exception thrown when theory cannot be created for some reason + * + * @author Aleks Kissinger + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/Theory.scala Source code]] + */ class TheoryLoadException(message: String, cause: Throwable = null) extends Exception(message, cause) /** - * A class which represents a theory - * - * @author Aleks Kissinger - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/Theory.scala Source code]] - */ + * A class which represents a theory + * + * @author Aleks Kissinger + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/Theory.scala Source code]] + */ case class Theory( - name: String, - coreName: String, - vertexTypes: Map[String, Theory.VertexDesc], - edgeTypes: Map[String, Theory.EdgeDesc] = Map("plain" -> Theory.PlainEdgeDesc), - defaultVertexType: String, - defaultEdgeType: String = "plain") -{ - def defaultVertexData = vertexTypes(defaultVertexType).defaultData - def defaultEdgeData = edgeTypes(defaultEdgeType).defaultData - override def toString = coreName + name: String, + coreName: String, + vertexTypes: Map[String, Theory.VertexDesc], + edgeTypes: Map[String, Theory.EdgeDesc] = Map("plain" -> Theory.PlainEdgeDesc), + defaultVertexType: String, + defaultEdgeType: String = "plain") { + def defaultVertexData: JsonObject = vertexTypes(defaultVertexType).defaultData + + def defaultEdgeData: JsonObject = edgeTypes(defaultEdgeType).defaultData + + def mixin(that: Theory, overwriteName: Option[String]): Theory = { + // New overwrites old! + Theory(overwriteName.getOrElse(name), + overwriteName.getOrElse(coreName), + vertexTypes ++ that.vertexTypes, + edgeTypes ++ that.edgeTypes, + that.defaultVertexType, that.defaultEdgeType) + } + + def mixin( + newVertexTypes: Map[String, Theory.VertexDesc] = Map(), + newEdgeTypes: Map[String, Theory.EdgeDesc] = Map(), + newName: Option[String] = None): Theory = { + // Overwrites existing types of the same name + Theory(newName.getOrElse(name), + newName.getOrElse(coreName), + vertexTypes ++ newVertexTypes, + edgeTypes ++ newEdgeTypes, + defaultVertexType, defaultEdgeType) + } + + def copy(name: String = name, + coreName: String = coreName, + vertexTypes: Map[String, Theory.VertexDesc] = vertexTypes, + edgeTypes: Map[String, Theory.EdgeDesc] = edgeTypes, + defaultVertexType: String = defaultVertexType, + defaultEdgeType: String = defaultEdgeType): Theory = Theory( + name, coreName, vertexTypes, edgeTypes, defaultVertexType, defaultEdgeType + ) + + override def toString: String = coreName } /** - * Companion object for the Theory class. Contains useful methods for - * converting a Theory to/from JSON object - * - * @author Aleks Kissinger - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/Theory.scala Source code]] - */ + * Companion object for the Theory class. Contains useful methods for + * converting a Theory to/from JSON object + * + * @author Aleks Kissinger + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/Theory.scala Source code]] + */ object Theory { - private implicit def jsonToColor(json: Json) = json match { - case JsonArray(Vector(r,g,b,a)) => new Color(r.floatValue,g.floatValue,b.floatValue,a.floatValue) - case JsonArray(Vector(r,g,b)) => new Color(r.floatValue,g.floatValue,b.floatValue) - case _ => throw new JsonParseException("Expected array of 3 or 4 doubles, got: " + json) - } - private implicit def jsonOptToColorOpt(jopt: Option[Json]) = jopt.map(jsonToColor(_)) - private implicit def colorToJson(c: Color) = + private implicit def jsonToColor(json: Json): Color = json match { + case JsonArray(Vector(r, g, b, a)) => new Color(r.floatValue, g.floatValue, b.floatValue, a.floatValue) + case JsonArray(Vector(r, g, b)) => new Color(r.floatValue, g.floatValue, b.floatValue) + case _ => throw new JsonParseException("Expected array of 3 or 4 doubles, got: " + json) + } + + private implicit def jsonOptToColorOpt(jopt: Option[Json]): Option[Color] = jopt.map(jsonToColor) + + private implicit def colorToJson(c: Color): JsonArray = c.getRGBComponents(null) match { - case Array(r,g,b,1.0) => JsonArray(r,g,b) - case Array(r,g,b,a) => JsonArray(r,g,b,a) + case Array(r, g, b, 1.0) => JsonArray(r, g, b) + case Array(r, g, b, a) => JsonArray(r, g, b, a) + } + + type ValueType = ValueType.Value + type VertexShape = VertexShape.Value + type VertexLabelPosition = VertexLabelPosition.Value + type EdgeLabelPosition = EdgeLabelPosition.Value + val PlainEdgeDesc = EdgeDesc( + value = ValueDesc(typ = Vector(ValueType.Empty)), + style = EdgeStyleDesc(), + defaultData = JsonObject("type" -> "plain") + ) + val DefaultTheory = Theory( + name = "String theory", + coreName = "string_theory", + vertexTypes = Map( + "string" -> VertexDesc( + value = ValueDesc( + typ = Vector(ValueType.String) + ), + style = VertexStyleDesc( + shape = VertexShape.Rectangle, + labelPosition = VertexLabelPosition.Inside + ), + defaultData = JsonObject("type" -> "string", "value" -> "") + ) + ), + defaultVertexType = "string" + ) + + /** + * Same as '''fromJson(json : Json)''', but tries to parse a string to a Json object first + * + * @throws TheoryLoadException Exception thrown when theory cannot be created for some reason + */ + def fromJson(s: String): Theory = + try { + fromJson(Json.parse(s)) + } + catch { + case e: JsonParseException => throw new TheoryLoadException("Error parsing JSON", e) + } + + /** Convert the theory to a JSON object */ + def toJson(thy: Theory): Json = JsonObject( + "name" -> thy.name, + "core_name" -> thy.coreName, + "vertex_types" -> JsonObject(thy.vertexTypes.mapValues(x => x: Json)), + "edge_types" -> JsonObject(thy.edgeTypes.mapValues(x => x: Json)), + "default_vertex_type" -> thy.defaultVertexType, + "default_edge_type" -> thy.defaultEdgeType + ).noEmpty + + /** + * Load a built-in theory from JSON file + * + * @param theoryFile name of the .qtheory file, without extension + * @return a theory object + */ + def fromFile(theoryFile: String): Theory = { + Theory.fromJson(Json.parse( + new Json.Input(Theory.getClass.getResourceAsStream(theoryFile + ".qtheory")))) + } + + /** + * Create a theory instance from a Json object + */ + def fromJson(json: Json): Theory = { + try { + val name = (json / "name").stringValue + val coreName = (json / "core_name").stringValue + val vertexTypes = (json / "vertex_types").asObject.mapValues(x => x: VertexDesc) + val defaultVertexType = json / "default_vertex_type" + if (!vertexTypes.contains(defaultVertexType)) + throw new TheoryLoadException("Default vertex type: " + defaultVertexType + " not in list.") + + val edgeTypes = json.get("edge_types") match { + case Some(et) => et.asObject.mapValues(x => x: EdgeDesc) + case None => Map("plain" -> PlainEdgeDesc) + } + val defaultEdgeType = json.getOrElse("default_edge_type", "plain").stringValue + + if (!edgeTypes.contains(defaultEdgeType)) + throw new TheoryLoadException("Default edge type: " + defaultEdgeType + " not in list.") + + Theory( + name, + coreName, + vertexTypes, + edgeTypes, + defaultVertexType, + defaultEdgeType + ) + } catch { + case e: JsonAccessException => throw new TheoryLoadException("Error reading JSON", e) } + } + + case class ValueDesc( + typ: Vector[ValueType] = Vector(ValueType.Empty), + enumOptions: Vector[String] = Vector[String](), + latexConstants: Boolean = false, + validateWithCore: Boolean = false + ) + + case class VertexStyleDesc( + shape: VertexShape, + customShape: Option[Shape] = None, + strokeWidth: Int = 1, + strokeColor: Color = Color.BLACK, + fillColor: Color = Color.WHITE, + labelPosition: VertexLabelPosition = VertexLabelPosition.Center, + labelForegroundColor: Color = Color.BLACK, + labelBackgroundColor: Option[Color] = None + ) + + case class EdgeStyleDesc( + strokeColor: Color = Color.BLACK, + strokeWidth: Int = 1, + labelPosition: EdgeLabelPosition = EdgeLabelPosition.Auto, + labelForegroundColor: Color = Color.BLACK, + labelBackgroundColor: Option[Color] = None + ) + + case class VertexDesc( + value: ValueDesc, + style: VertexStyleDesc, + defaultData: JsonObject + ) + + case class EdgeDesc( + value: ValueDesc, + style: EdgeStyleDesc, + defaultData: JsonObject + ) object ValueType extends Enumeration with JsonEnumConversions { - val String = Value("string") - val AngleExpr = Value("angle_expr") - val LongString = Value("long_string") - val Enum = Value("enum") - val Empty = Value("empty") + val String: ValueType = Value("string") + val AngleExpr: ValueType = Value("angle_expr") + val Boolean: ValueType = Value("boolean") + val Rational: ValueType = Value("rational") + val Integer: ValueType = Value("integer") + val Long: ValueType = Value("long") + val Enum: ValueType = Value("enum") + val Empty: ValueType = Value("empty") } - type ValueType = ValueType.Value object VertexShape extends Enumeration with JsonEnumConversions { - val Circle = Value("circle") - val Rectangle = Value("rectangle") - val Custom = Value("custom") + val Circle: VertexShape = Value("circle") + val Rectangle: VertexShape = Value("rectangle") + val Custom: VertexShape = Value("custom") + + + def fromName(name: String): Option[VertexShape] = this.values.find(v => v.toString == name) } - type VertexShape = VertexShape.Value object VertexLabelPosition extends Enumeration with JsonEnumConversions { - val Center = Value("center") - val Inside = Value("inside") - val Below = Value("below") + val Center: VertexLabelPosition = Value("center") + val Inside: VertexLabelPosition = Value("inside") + val Below: VertexLabelPosition = Value("below") + + def fromName(name: String): Option[VertexLabelPosition] = this.values.find(v => v.toString == name) } - type VertexLabelPosition = VertexLabelPosition.Value object EdgeLabelPosition extends Enumeration with JsonEnumConversions { - val Center = Value("center") - val Auto = Value("auto") + val Center: EdgeLabelPosition = Value("center") + val Auto: EdgeLabelPosition = Value("auto") } - type EdgeLabelPosition = EdgeLabelPosition.Value - case class ValueDesc( - typ: ValueType = ValueType.Empty, - enumOptions: Vector[String] = Vector[String](), - latexConstants: Boolean = false, - validateWithCore: Boolean = false - ) object ValueDesc { implicit def fromJson(json: Json): ValueDesc = ValueDesc( - typ = json / "type", + typ = CompositeExpression.parseTypes(json / "type"), enumOptions = (json ? "enum_options").vectorValue.map(_.stringValue), latexConstants = json.getOrElse("latex_constants", false), validateWithCore = json.getOrElse("validate_with_core", false) @@ -101,27 +266,19 @@ object Theory { implicit def toJson(v: ValueDesc): JsonObject = JsonObject( - "type" -> v.typ, + "type" -> v.typ.mkString(","), "enum_options" -> v.enumOptions, "latex_constants" -> v.latexConstants, "validate_with_core" -> v.validateWithCore ).noEmpty } - case class VertexStyleDesc( - shape: VertexShape, - customShape: Option[Shape] = None, - strokeColor: Color = Color.BLACK, - fillColor: Color = Color.WHITE, - labelPosition: VertexLabelPosition = VertexLabelPosition.Center, - labelForegroundColor: Color = Color.BLACK, - labelBackgroundColor: Option[Color] = None - ) // TODO: implement custom shapes object VertexStyleDesc { - implicit def fromJson(json: Json) = VertexStyleDesc( - shape = (json / "shape"), + implicit def fromJson(json: Json): VertexStyleDesc = VertexStyleDesc( + shape = json / "shape", customShape = None, + strokeWidth = json.getOrElse("stroke_width", 1), strokeColor = json.getOrElse("stroke_color", Color.BLACK), fillColor = json.getOrElse("fill_color", Color.WHITE), labelPosition = (json ? "label").getOrElse("position", VertexLabelPosition.Center), @@ -129,30 +286,23 @@ object Theory { labelBackgroundColor = (json ? "label").get("bg_color") ) - implicit def toJson(v: VertexStyleDesc) = + implicit def toJson(v: VertexStyleDesc): JsonObject = JsonObject( "shape" -> v.shape, "custom_shape" -> JsonNull, + "stroke_width" -> v.strokeWidth, "stroke_color" -> v.strokeColor, "fill_color" -> v.fillColor, "label" -> JsonObject( "position" -> v.labelPosition, "fg_color" -> v.labelForegroundColor, - "bg_color" -> v.labelBackgroundColor.map(x=>x:Json).getOrElse(JsonNull) + "bg_color" -> v.labelBackgroundColor.map(x => x: Json).getOrElse(JsonNull) ).noEmpty ).noEmpty } - case class EdgeStyleDesc( - strokeColor: Color = Color.BLACK, - strokeWidth: Int = 1, - labelPosition: EdgeLabelPosition = EdgeLabelPosition.Auto, - labelForegroundColor: Color = Color.BLACK, - labelBackgroundColor: Option[Color] = None - ) - object EdgeStyleDesc { - implicit def fromJson(json: Json) = EdgeStyleDesc( + implicit def fromJson(json: Json): EdgeStyleDesc = EdgeStyleDesc( strokeColor = json.getOrElse("stroke_color", Color.BLACK), strokeWidth = json.getOrElse("stroke_width", 1), labelPosition = (json ? "label").getOrElse("position", EdgeLabelPosition.Auto), @@ -160,139 +310,43 @@ object Theory { labelBackgroundColor = (json ? "label").get("bg_color") ) - implicit def toJson(v: EdgeStyleDesc) = + implicit def toJson(v: EdgeStyleDesc): JsonObject = JsonObject( "stroke_color" -> v.strokeColor, "stroke_width" -> v.strokeWidth, "label" -> JsonObject( "position" -> v.labelPosition, "fg_color" -> v.labelForegroundColor, - "bg_color" -> v.labelBackgroundColor.map(x=>x:Json).getOrElse(JsonNull) + "bg_color" -> v.labelBackgroundColor.map(x => x: Json).getOrElse(JsonNull) ).noEmpty ).noEmpty } - case class VertexDesc( - value: ValueDesc, - style: VertexStyleDesc, - defaultData: JsonObject - ) object VertexDesc { - implicit def fromJson(json: Json) = VertexDesc( + implicit def fromJson(json: Json): VertexDesc = VertexDesc( value = json / "value", style = json / "style", - defaultData = (json / "default_data").asObject - ) - implicit def toJson(v: VertexDesc) = JsonObject( + defaultData = json.getOrElse("default_data", "").asObject) + + implicit def toJson(v: VertexDesc): JsonObject = JsonObject( "value" -> v.value, "style" -> v.style, "default_data" -> v.defaultData ) } - case class EdgeDesc( - value: ValueDesc, - style: EdgeStyleDesc, - defaultData: JsonObject - ) object EdgeDesc { - implicit def fromJson(json: Json) = EdgeDesc( - value = (json / "value"), - style = (json / "style"), + implicit def fromJson(json: Json): EdgeDesc = EdgeDesc( + value = json / "value", + style = json / "style", defaultData = (json / "default_data").asObject ) - implicit def toJson(v: EdgeDesc) = JsonObject( + + implicit def toJson(v: EdgeDesc): JsonObject = JsonObject( "value" -> v.value, "style" -> v.style, "default_data" -> v.defaultData ) } - /** - * Same as '''fromJson(json : Json)''', but tries to parse a string to a Json object first - * @throws TheoryLoadException - */ - def fromJson(s: String): Theory = - try { fromJson(Json.parse(s)) } - catch { case e:JsonParseException => throw new TheoryLoadException("Error parsing JSON", e) } - - /** - * Create a theory instance from a Json object - * @throws TheoryLoadException - */ - def fromJson(json: Json): Theory = { - try { - val name = (json / "name").stringValue - val coreName = (json / "core_name").stringValue - val vertexTypes = (json / "vertex_types").asObject.mapValues(x => x:VertexDesc) - val defaultVertexType = (json / "default_vertex_type") - if (!vertexTypes.contains(defaultVertexType)) - throw new TheoryLoadException("Default vertex type: " + defaultVertexType + " not in list.") - - val edgeTypes = json.get("edge_types") match { - case Some(et) => et.asObject.mapValues(x => x:EdgeDesc) - case None => Map("plain"->PlainEdgeDesc) - } - val defaultEdgeType = json.getOrElse("default_edge_type", "plain").stringValue - - if (!edgeTypes.contains(defaultEdgeType)) - throw new TheoryLoadException("Default edge type: " + defaultEdgeType + " not in list.") - - Theory( - name, - coreName, - vertexTypes, - edgeTypes, - defaultVertexType, - defaultEdgeType - ) - } catch { - case e: JsonAccessException => throw new TheoryLoadException("Error reading JSON", e) - } - } - - /** Convert the theory to a JSON object */ - def toJson(thy: Theory): Json = JsonObject( - "name" -> thy.name, - "core_name" -> thy.coreName, - "vertex_types" -> JsonObject(thy.vertexTypes.mapValues(x => x:Json)), - "edge_types" -> JsonObject(thy.edgeTypes.mapValues(x => x:Json)), - "default_vertex_type" -> thy.defaultVertexType, - "default_edge_type" -> thy.defaultEdgeType - ).noEmpty - - val PlainEdgeDesc = EdgeDesc( - value = ValueDesc(typ = ValueType.Empty), - style = EdgeStyleDesc(), - defaultData = JsonObject("type" -> "plain") - ) - - val DefaultTheory = Theory( - name = "String theory", - coreName = "string_theory", - vertexTypes = Map( - "string" -> VertexDesc( - value = ValueDesc( - typ = ValueType.String - ), - style = VertexStyleDesc( - shape = VertexShape.Rectangle, - labelPosition = VertexLabelPosition.Inside - ), - defaultData = JsonObject("type" -> "string", "value" -> "") - ) - ), - defaultVertexType = "string" - ) - - /** - * Load a built-in theory from JSON file - * - * @param theoryFile name of the .qtheory file, without extension - * @return a theory object - */ - def fromFile(theoryFile: String): Theory = { - Theory.fromJson(Json.parse( - new Json.Input(Theory.getClass.getResourceAsStream(theoryFile + ".qtheory")))) - } } diff --git a/scala/src/main/scala/quanto/data/VData.scala b/scala/src/main/scala/quanto/data/VData.scala index 7d825298..d31b860e 100644 --- a/scala/src/main/scala/quanto/data/VData.scala +++ b/scala/src/main/scala/quanto/data/VData.scala @@ -1,125 +1,136 @@ package quanto.data -import quanto.data.Theory.ValueType import quanto.util.json._ /** - * An abstract class which provides a general interface for accessing - * vertex data - * - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - * @author Aleks Kissinger - */ + * An abstract class which provides a general interface for accessing + * vertex data + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + * @author Aleks Kissinger + */ abstract class VData extends GraphElementData { - def annotation : JsonObject + def annotation: JsonObject /** - * Get coordinates of vertex - * @throws JsonAccessException - * @return actual coordinates of vertex or (0,0) if none are specified - */ + * Get coordinates of vertex + * + * @return actual coordinates of vertex or (0,0) if none are specified + */ def coord: (Double, Double) = annotation.get("coord") match { - case Some(JsonArray(Vector(x,y))) => (x.doubleValue, y.doubleValue) + case Some(JsonArray(Vector(x, y))) => (x.doubleValue, y.doubleValue) case Some(otherJson) => throw new JsonAccessException("Expected: array with 2 elements", otherJson) - case None => (0,0) + case None => (0, 0) } /** Create a copy of the current vertex with the new coordinates */ - def withCoord(c: (Double,Double)): VData + def withCoord(c: (Double, Double)): VData + def typ: String def isWireVertex: Boolean - def isBoundary : Boolean + + def isBoundary: Boolean } /** - * Companion object for the VData class. Contains a method getCoord which has - * the same behaviour as VData.coord, but is static. - * - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - * @author Aleks Kissinger - */ + * Companion object for the VData class. Contains a method getCoord which has + * the same behaviour as VData.coord, but is static. + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + * @author Aleks Kissinger + */ object VData { - def getCoord(annotation: Json): (Double,Double) = annotation.get("coord") match { - case Some(JsonArray(Vector(x,y))) => (x.doubleValue, y.doubleValue) + def getCoord(annotation: Json): (Double, Double) = annotation.get("coord") match { + case Some(JsonArray(Vector(x, y))) => (x.doubleValue, y.doubleValue) case Some(otherJson) => throw new JsonAccessException("Expected: array with 2 elements", otherJson) - case None => (0,0) + case None => (0, 0) } } /** - * A class which represents node vertex data. - * - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - */ + * A class which represents node vertex data. + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + */ case class NodeV( - data: JsonObject = Theory.DefaultTheory.defaultVertexData, - annotation: JsonObject = JsonObject(), - theory: Theory = Theory.DefaultTheory) extends VData -{ + data: JsonObject = Theory.DefaultTheory.defaultVertexData, + annotation: JsonObject = JsonObject(), + theory: Theory = Theory.DefaultTheory) extends VData { /** Type of the vertex */ - val typ = (data / "type").stringValue - -// def label = data.getOrElse("label","").stringValue - def typeInfo = theory.vertexTypes(typ) - + val typ: String = (data / "type").stringValue // support input of old-style graphs, where data may be stored at value/pretty val value: String = data ? "value" match { - case str : JsonString => str.stringValue - case obj : JsonObject => obj.getOrElse("pretty", JsonString("")).stringValue + case str: JsonString => str.stringValue + case obj: JsonObject => obj.getOrElse("pretty", JsonString("")).stringValue case _ => "" } + def newValue(value: String): NodeV = NodeV(data = JsonObject( + "type" -> typ, + "value" -> value + ), annotation = annotation, theory = theory) + + // if the theory says this node should have a value, try to parse it, + // and store it in "phaseData". If it should have a value, but parsing fails, set + // it to empty. + lazy val (phaseData: CompositeExpression, hasValue: Boolean) = + try { + val phaseTypes = theory.vertexTypes(typ).value.typ + val phaseValues = CompositeExpression.parseKnowingTypes(value, phaseTypes) + (CompositeExpression(phaseTypes, phaseValues), true) + } + catch { + // YOU WILL END UP HERE IF THE PARSER WAS HANDED AN UNFAMILIAR VALUETYPE + // See uses of PhaseParseException to pinpoint where + case _: PhaseParseException => (CompositeExpression(Vector(), Vector()), false) + } + + // def label = data.getOrElse("label","").stringValue + def typeInfo = theory.vertexTypes(typ) - // if the theory says this node should have an angle, try to parse it from value, - // and store it in "angle". If it should have an angle, but parsing fails, set - // angle to "0". - val (angle: AngleExpression, hasAngle: Boolean) = - if (theory.vertexTypes(typ).value.typ == ValueType.AngleExpr) - try { (AngleExpression.parse(value), true) } - catch { case _: AngleParseException => (AngleExpression(), true) } - else (AngleExpression(), false) + def withCoord(c: (Double, Double)): NodeV = + copy(annotation = annotation + ("coord" -> JsonArray(c._1, c._2))) - def withCoord(c: (Double,Double)) = - copy(annotation = annotation + ("coord" -> JsonArray(c._1, c._2))) + /** Create a copy of the current vertex with the new value */ + def withValue(s: String): NodeV = + copy(data = data.setPath("$.value", s).asObject) - /** Create a copy of the current vertex with the new value */ - def withValue(s: String) = - copy(data = data.setPath("$.value", s).asObject) + def withTyp(s: String): NodeV = + copy(data = data.setPath("$.type", s).asObject) - def withTyp(s: String) = - copy(data = data.setPath("$.type", s).asObject) + def isWireVertex = false - def isWireVertex = false - def isBoundary = false + def isBoundary = false - override def toJson = + override def toJson: JsonObject = if (data == theory.defaultVertexData) JsonObject("annotation" -> annotation).noEmpty else JsonObject( "data" -> data, "annotation" -> annotation).noEmpty - } +} /** - * Companion object for the NodeV class. Contains methods to convert to/from - * JSON and a factory method to create instances of NodeV from a pair of - * coordinates. - * - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - * @author Aleks Kissinger - */ + * Companion object for the NodeV class. Contains methods to convert to/from + * JSON and a factory method to create instances of NodeV from a pair of + * coordinates. + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + * @author Aleks Kissinger + */ object NodeV { - def apply(coord: (Double,Double)): NodeV = NodeV(annotation = JsonObject("coord" -> JsonArray(coord._1,coord._2))) + def apply(coord: (Double, Double)): NodeV = NodeV(annotation = JsonObject("coord" -> JsonArray(coord._1, coord._2))) - def toJson(d: NodeV, theory: Theory) = JsonObject( + def toJson(d: NodeV, theory: Theory): JsonObject = JsonObject( "data" -> (if (d.data == theory.vertexTypes(d.typ).defaultData) JsonNull else d.data), "annotation" -> d.annotation).noEmpty + def fromJson(json: Json, thy: Theory = Theory.DefaultTheory): NodeV = { val data = json.getOrElse("data", thy.defaultVertexData).asObject val annotation = (json ? "annotation").asObject @@ -133,36 +144,45 @@ object NodeV { } /** - * A class which represents wire vertex data - * - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - */ + * A class which represents wire vertex data + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + */ case class WireV( - data: JsonObject = JsonObject(), - annotation: JsonObject = JsonObject(), - theory: Theory = Theory.DefaultTheory) extends VData -{ + data: JsonObject = JsonObject(), + annotation: JsonObject = JsonObject(), + theory: Theory = Theory.DefaultTheory) extends VData { def typ = "wire" + def isWireVertex = true - def isBoundary = annotation.get("boundary") match { case Some(JsonBool(b)) => b; case _ => false } - def withCoord(c: (Double,Double)) = + + def isBoundary: Boolean = annotation.get("boundary") match { + case Some(JsonBool(b)) => b; + case _ => false + } + + def withCoord(c: (Double, Double)): WireV = copy(annotation = annotation + ("coord" -> JsonArray(c._1, c._2))) + + def makeBoundary(b: Boolean): WireV = + copy(annotation = annotation + ("boundary" -> JsonBool(b))) } /** - * A companion object for the WireV class. Contains methods to convert to/from - * JSON and a factory method to create instances of WireV from a pair of - * coordinates - * - * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]] - * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] - */ + * A companion object for the WireV class. Contains methods to convert to/from + * JSON and a factory method to create instances of WireV from a pair of + * coordinates + * + * @see [[https://github.com/Quantomatic/quantomatic/blob/scala-frontend/scala/src/main/scala/quanto/data/VData.scala Source code]] + * @see [[https://github.com/Quantomatic/quantomatic/blob/integration/docs/json_formats.txt json_formats.txt]] + */ object WireV { - def apply(c: (Double,Double)): WireV = WireV(annotation = JsonObject("coord" -> JsonArray(c._1,c._2))) + def apply(c: (Double, Double)): WireV = WireV(annotation = JsonObject("coord" -> JsonArray(c._1, c._2))) - def toJson(d: NodeV, theory: Theory) = JsonObject( + def toJson(d: NodeV, theory: Theory): JsonObject = JsonObject( "data" -> d.data, "annotation" -> d.annotation).noEmpty + def fromJson(json: Json, thy: Theory = Theory.DefaultTheory): WireV = WireV((json ? "data").asObject, (json ? "annotation").asObject, thy) } diff --git a/scala/src/main/scala/quanto/gui/AddRuleDialog.scala b/scala/src/main/scala/quanto/gui/AddRuleDialog.scala index da3a989b..5ac80839 100644 --- a/scala/src/main/scala/quanto/gui/AddRuleDialog.scala +++ b/scala/src/main/scala/quanto/gui/AddRuleDialog.scala @@ -8,82 +8,73 @@ import quanto.util.Globals import scala.util.matching import scala.util.matching.Regex +import quanto.util.UserOptions.{scale, scaleInt} class AddRuleDialog(project: Project) extends Dialog { modal = true - implicit def buttonIsSelected(radButton: RadioButton) : Boolean = radButton.selected - + implicit def buttonIsSelected(radButton: RadioButton): Boolean = radButton.selected + var cancelled = false def result: Seq[RuleDesc] = { - val unsorted = if (MainPanel.radIncludeForwards) { - MainPanel.FilteredRuleList.selection.items.map(s => RuleDesc(s)) - } else if (MainPanel.radIncludeInverse) { - MainPanel.FilteredRuleList.selection.items.map(s => RuleDesc(s, inverse = true)) + if(!cancelled) { + val unsorted = if (MainPanel.radIncludeForwards) { + MainPanel.selection.map(s => RuleDesc(s)) + } else if (MainPanel.radIncludeInverse) { + MainPanel.selection.map(s => RuleDesc(s, inverse = true)) + } else { + MainPanel.selection.flatMap(s => Seq(RuleDesc(s), RuleDesc(s, inverse = true))) + } + unsorted.sortBy(rd => rd.name) } else { - MainPanel.FilteredRuleList.selection.items.flatMap(s => Seq(RuleDesc(s), RuleDesc(s, inverse = true))) + Seq() } - unsorted.sortBy(rd => rd.name) } val AddButton = new Button("Add") val CancelButton = new Button("Cancel") defaultButton = Some(AddButton) -// val dir = Files.newDirectoryStream(Paths.get(rootDir), "**/*.qrule") -// for (p <- dir.asScala) println(p) + // val dir = Files.newDirectoryStream(Paths.get(rootDir), "**/*.qrule") + // for (p <- dir.asScala) println(p) - val MainPanel = new BoxPanel(Orientation.Vertical) { - val Search = new TextField - val InitialRules : Vector[String] = project.rules.sorted - var FilteredRuleList : ListView[String] = new ListView[String](InitialRules) - - val RulePane = new ScrollPane(FilteredRuleList) + val MainPanel = new BoxPanel(Orientation.Vertical) { + val FList = new FilteredList(project.filesEndingIn(".qrule")) + contents += FList + def selection : List[String] = FList.ListComponent.selection.items.toList val radIncludeForwards = new RadioButton("Forwards") - radIncludeForwards.selected = true val radIncludeInverse = new RadioButton("Inverted") var radIncludeInverseAndForwards= new RadioButton("Both") + radIncludeInverseAndForwards.selected = true val radGroupIncludeInverse = new ButtonGroup(radIncludeForwards, radIncludeInverse, radIncludeInverseAndForwards) - RulePane.preferredSize = new Dimension(400,200) - contents += Swing.VStrut(10) - contents += new BoxPanel(Orientation.Horizontal) { - contents += (Swing.HStrut(10), new Label("Filter:"), Swing.HStrut(5), Search, Swing.HStrut(10)) - } contents += Swing.VStrut(5) + contents += new BoxPanel(Orientation.Horizontal) { - contents += (Swing.HStrut(10), RulePane, Swing.HStrut(10)) + contents += (AddButton, Swing.HStrut(5), CancelButton) } + contents += new BoxPanel(Orientation.Horizontal) { contents += (Swing.HStrut(10), radIncludeForwards, Swing.HStrut(10)) contents += (Swing.HStrut(10), radIncludeInverse, Swing.HStrut(10)) contents += (Swing.HStrut(10), radIncludeInverseAndForwards, Swing.HStrut(10)) } - contents += Swing.VStrut(5) - contents += new BoxPanel(Orientation.Horizontal) { - contents += (AddButton, Swing.HStrut(5), CancelButton) - } contents += Swing.VStrut(10) + } + contents = MainPanel - listenTo(AddButton, CancelButton, MainPanel.Search) + listenTo(AddButton, CancelButton) reactions += { case ButtonClicked(AddButton) => + cancelled = false close() case ButtonClicked(CancelButton) => - MainPanel.FilteredRuleList.selection.indices.clear() + cancelled = true close() - case ValueChanged(MainPanel.Search) => - try { - MainPanel.FilteredRuleList.listData = MainPanel.InitialRules.filter( - s => s.matches(".*" + MainPanel.Search.text + ".*")) - } catch { - case e: Exception => - //Exceptions here are thrown by inelligable regex from the user - } } } diff --git a/scala/src/main/scala/quanto/gui/BatchDerivationCreationDocument.scala b/scala/src/main/scala/quanto/gui/BatchDerivationCreationDocument.scala new file mode 100644 index 00000000..70ecd782 --- /dev/null +++ b/scala/src/main/scala/quanto/gui/BatchDerivationCreationDocument.scala @@ -0,0 +1,35 @@ +package quanto.gui + +import java.io.File + +import quanto.data.{Project, Theory} +import quanto.util.UserAlerts +import quanto.util.json.Json + +import scala.swing.{Component, Publisher} + +// +// This file is intentionally bare because this panel will spawn jobs, rather than display files +// +class BatchDerivationCreationDocument(val parent: Component) extends Document with Publisher { + val description = "Batch Derivation" + val fileExtension = "" + + + protected def clearDocument() { + } + + protected def saveDocument(f: File) { + } + + override def loadDocument(f: File) { + } + + override def titleDescription: String = "Batch Derivation" + + override def unsavedChanges: Boolean = false + + override protected def exportDocument(f: File) { + } + +} diff --git a/scala/src/main/scala/quanto/gui/BatchDerivationCreatorPanel.scala b/scala/src/main/scala/quanto/gui/BatchDerivationCreatorPanel.scala new file mode 100644 index 00000000..bcd9e6ba --- /dev/null +++ b/scala/src/main/scala/quanto/gui/BatchDerivationCreatorPanel.scala @@ -0,0 +1,174 @@ +package quanto.gui + +import java.awt.BorderLayout +import java.awt.event.{KeyAdapter, KeyEvent} +import java.io.File +import java.util.Calendar +import java.util.concurrent.TimeUnit + +import org.gjt.sp.jedit.{Mode, Registers} +import org.gjt.sp.jedit.textarea.StandaloneTextArea +import quanto.rewrite.Simproc +import quanto.util.UserOptions.scaleInt +import quanto.util.swing.ToolBar +import quanto.data.Graph + +import scala.swing.event.{ButtonClicked, Event, SelectionChanged} +import quanto.cosy.SimprocBatch +import quanto.util.{FileHelper, UserOptions} + +import scala.swing.{BorderPanel, BoxPanel, Button, Component, Dimension, GridPanel, Label, Orientation, Publisher, ScrollPane, Swing, TextArea} + +case class HaltBatchProcessesEvent() extends Event + +// +// This panel will create batch jobs +// The corresponding document for HasDocument is essentially empty + +class BatchDerivationCreatorPanel extends BorderPanel with HasDocument with Publisher { + val document = new BatchDerivationCreationDocument(this) + val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask + val Toolbar = new ToolBar { + //contents + } + val GraphList = new FilteredList(graphs) + val LabelNumSimprocs = new Label() + val LabelNumGraphs = new Label() + val LabelNumPairs = new Label() + val LabelTimeLimit = new Label() + val NotesTextBox: TextEditor = new TextEditor(TextEditor.Modes.blank) + + + val StartButton = new Button("Start") + val HaltButton = new Button("Halt ongoing") + val BatchDetails: BoxPanel = new BoxPanel(Orientation.Vertical) { + contents += VSpace + contents += new GridPanel(4, 2) { + + contents += new Label("Simprocs:") + contents += LabelNumSimprocs + contents += new Label("Graphs:") + contents += LabelNumGraphs + contents += new Label("Total pairs:") + contents += LabelNumPairs + contents += new Label("Time limit:") + contents += LabelTimeLimit + + maximumSize = new Dimension(scaleInt(200), scaleInt(400)) + } + } + val GraphChooser: Component = new BoxPanel(Orientation.Vertical) { + + contents += VSpace + contents += header("Graphs to include:") + contents += VSpace + contents += GraphList + } + val IgnitionButtons: BoxPanel = new BoxPanel(Orientation.Horizontal) { + contents += (HSpace, StartButton, HSpace, HaltButton, HSpace) + } + val NotesHolder : BoxPanel = new BoxPanel(Orientation.Vertical) { + contents += VSpace + contents += new Label("Notes:") + contents += VSpace + contents += NotesTextBox.Component + NotesTextBox.Component.maximumSize = new Dimension(scaleInt(600), scaleInt(300)) + } + def MainPanel: BoxPanel = new BoxPanel(Orientation.Vertical) { + contents += SimprocChooser + contents += GraphChooser + contents += NotesHolder + contents += BatchDetails + + contents += VSpace + contents += IgnitionButtons + contents += VSpace + } + //SimprocList is a var not a val, because it can then be destroyed and recreated when the simprocs in memory change + var SimprocList = new FilteredList(simprocs.keys.toList) + + val SimprocChooser: BoxPanel = new BoxPanel(Orientation.Vertical) { + // Made def not val because it references the var SimprocList + contents += VSpace + contents += header("Simprocs to include:") + contents += VSpace + contents += SimprocList + } + + def header(title: String): BoxPanel = new BoxPanel(Orientation.Horizontal) { + contents += (HSpace, new Label(title), HSpace) + } + + private def HSpace: Component = Swing.HStrut(scaleInt(10)) + + private def VSpace: Component = Swing.VStrut(scaleInt(10)) + + def TopScrollablePane = new ScrollPane(MainPanel) + + def simprocs: Map[String, Simproc] = QuantoDerive.CurrentProject.map(p => p.simprocs).getOrElse(Map()) + + def graphs: List[String] = QuantoDerive.CurrentProject.map(p => p.filesEndingIn(".qgraph")).getOrElse(List()) + + def loadGraph(name: String): Graph = { + val root: String = QuantoDerive.CurrentProject.map(p => p.rootFolder).getOrElse("") + val theory = QuantoDerive.CurrentProject.map(p => p.theory).get + FileHelper.readFile[Graph](new File(root + "/" + name + ".qgraph"), json => Graph.fromJson(json, theory)) + } + + def refreshDataDisplay() { + def numSimprocs: Int = simprocSelection.size + + def numGraphs: Int = graphSelection.size + + LabelNumSimprocs.text = numSimprocs.toString + LabelNumGraphs.text = numGraphs.toString + LabelNumPairs.text = (numGraphs * numSimprocs).toString + val totalMilliseconds = numSimprocs * numGraphs * SimprocBatch.timeout + LabelTimeLimit.text = s"${TimeUnit.MILLISECONDS.toHours(totalMilliseconds)} hrs ${ + TimeUnit.MILLISECONDS.toMinutes(totalMilliseconds) - + TimeUnit.MINUTES.toMinutes(TimeUnit.MILLISECONDS.toHours(totalMilliseconds)) + } mins" + } + + listenTo(this, + StartButton, + HaltButton, + SimprocList.ListComponent.selection, + GraphList.ListComponent.selection, + PythonEditPanel) + + + reactions += { + case SimprocsUpdated() => + SimprocChooser.contents -= SimprocList + SimprocList = new FilteredList(simprocs.keys.toList) + SimprocChooser.contents += SimprocList + refreshDataDisplay() + listenTo(SimprocList.ListComponent.selection) + case ButtonClicked(HaltButton) => + BatchDerivationCreatorPanel.jobID += 1 + case ButtonClicked(StartButton) => + val batch = SimprocBatch(simprocSelection, graphSelection.map(name => loadGraph(name)), NotesTextBox.getText) + batch.run() + case SelectionChanged(_) => + refreshDataDisplay() + } + + def simprocSelection: List[String] = SimprocList.ListComponent.selection.items.toList + + def graphSelection: List[String] = GraphList.ListComponent.selection.items.toList + + add(TopScrollablePane, BorderPanel.Position.Center) + + add(Toolbar, BorderPanel.Position.North) + + refreshDataDisplay() + +} + +object BatchDerivationCreatorPanel { + // The job ID exists so the user can cancel any currently running batch jobs + // Jobs check the current job ID, and if it has increased since the job started the job will stop + var jobID = 0 + +} \ No newline at end of file diff --git a/scala/src/main/scala/quanto/gui/BatchDerivationResultsDocument.scala b/scala/src/main/scala/quanto/gui/BatchDerivationResultsDocument.scala new file mode 100644 index 00000000..b69eaefd --- /dev/null +++ b/scala/src/main/scala/quanto/gui/BatchDerivationResultsDocument.scala @@ -0,0 +1,52 @@ +package quanto.gui + +import java.io.File + +import quanto.cosy.{SimprocBatchResult, SimprocLazyBatchResult, SimprocSingleRun} +import quanto.data.Graph + +import scala.swing.{Component, Publisher} + +// The document is read-only data generated by a batch run +// There are two ways to access the document +// - a standard loadDocument which lazy reads in what is essentially meta-data +// - a full load that puts all of the results into memory +class BatchDerivationResultsDocument(val parent: Component) extends Document with Publisher { + val description = "Batch Derivation Results" + val fileExtension = "qsbr" + + var simprocsUsed: Option[List[String]] = None + var resultsCount: Option[Int] = None + var allSimprocs: Option[Map[String, String]] = None + var results: Option[List[SimprocSingleRun]] = None + var timeTaken: Option[Double] = None + var notes: Option[String] = None + + override def loadDocument(f: File) = { + val lazyResults = quanto.util.FileHelper.readFile[SimprocLazyBatchResult](f, SimprocBatchResult.lazyFromJson) + simprocsUsed = Some(lazyResults.selectedSimprocs) + notes = Some(lazyResults.notes) + resultsCount = Some(lazyResults.resultCount) + publish(DocumentChanged(this)) + } + + def loadFullDocument(f: File): Unit = { + val batchResults = quanto.util.FileHelper.readFile[SimprocBatchResult](f, SimprocBatchResult.fromJson) + allSimprocs = Some(batchResults.allSimprocs) + results = Some(batchResults.singleResults) + timeTaken = Some(batchResults.singleResults.map(sr => sr.derivationTimings.map(tt => tt._2).sum).sum) + publish(DocumentChanged(this)) + } + + override def unsavedChanges: Boolean = false + + protected def clearDocument() { + } + + protected def saveDocument(f: File) { + } + + override protected def exportDocument(f: File) { + } + +} diff --git a/scala/src/main/scala/quanto/gui/BatchDerivationResultsPanel.scala b/scala/src/main/scala/quanto/gui/BatchDerivationResultsPanel.scala new file mode 100644 index 00000000..453fd53c --- /dev/null +++ b/scala/src/main/scala/quanto/gui/BatchDerivationResultsPanel.scala @@ -0,0 +1,118 @@ +package quanto.gui + +import java.io.File +import java.util.Calendar +import java.util.concurrent.TimeUnit + +import quanto.rewrite.Simproc +import quanto.util.UserOptions.scaleInt +import quanto.util.swing.ToolBar +import quanto.data.Graph + +import scala.swing.event.{ButtonClicked, Event, SelectionChanged} +import quanto.cosy.{SimprocBatch, SimprocBatchResult, SimprocSingleRun} +import quanto.util.FileHelper + +import scala.swing.{BorderPanel, BoxPanel, Button, Component, Dimension, GridPanel, Label, Orientation, Publisher, ScrollPane, Swing} + +// The document this panel displays has two ways to access the file: +// A lazy read, used on first pass, that just loads metadata +// A full read that loads everything into memory + +class BatchDerivationResultsPanel() + extends BorderPanel with HasDocument with Publisher { + + val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask + val Toolbar = new ToolBar { + //contents + } + + val document: BatchDerivationResultsDocument = new BatchDerivationResultsDocument(this) + val LabelNumSimprocs = new Label("Calculating") + val LabelSimprocsUsed = new Label("Calculating") + val LabelNumGraphs = new Label("Calculating") + val LabelNumPairs = new Label("Calculating") + val LabelTimeTaken = new Label("Calculating") + val SimprocsUsed: BoxPanel = new BoxPanel(Orientation.Vertical) { + contents += VSpace + contents += new Label("Simprocs used:") + contents += VSpace + contents += LabelSimprocsUsed + contents += VSpace + } + val BatchDetails: BoxPanel = new BoxPanel(Orientation.Vertical) { + contents += VSpace + contents += new GridPanel(3, 2) { + + contents += new Label("No. Simprocs:") + contents += LabelNumSimprocs + contents += new Label("No. Graphs:") + contents += LabelNumGraphs + contents += new Label("Total pairs:") + contents += LabelNumPairs + } + contents += VSpace + maximumSize = new Dimension(scaleInt(400), scaleInt(200)) + } + val NotesTextBox = new Label(document.notes.getOrElse("")) + val Notes: BoxPanel = new BoxPanel(Orientation.Vertical) { + contents += VSpace + contents += new Label("Notes:") + contents += VSpace + contents += NotesTextBox + contents += VSpace + NotesTextBox.preferredSize = new Dimension(scaleInt(600), scaleInt(300)) + } + val MainPanel: BoxPanel = new BoxPanel(Orientation.Vertical) { + contents += SimprocsUsed + contents += BatchDetails + contents += Notes + } + val TopScrollablePane = new ScrollPane(MainPanel) + + def allSimprocs: Option[Map[String, String]] = document.allSimprocs + + def results: Option[List[SimprocSingleRun]] = document.results + + def notes: Option[String] = document.notes + + def refreshData(): Unit = { + val numSimprocs: Int = simprocsUsed.getOrElse(List()).size + val numPairs: Int = resultsCount.getOrElse(0) + LabelNumSimprocs.text = numSimprocs.toString + LabelSimprocsUsed.text = simprocsUsed.getOrElse(List()).mkString("\n") + LabelNumPairs.text = numPairs.toString + LabelNumGraphs.text = (numPairs, numSimprocs) match { + case (0, _) => "0" + case (n, 0) => n.toString + case (a,b) => (a/b).toString + } + LabelTimeTaken.text = if (timeTaken.nonEmpty) { + timeTaken.get.toString + } else { + "---" + } + NotesTextBox.text = notes.getOrElse("") + } + + listenTo(document) + reactions += { + case DocumentChanged(d) => refreshData() + } + + def simprocsUsed: Option[List[String]] = document.simprocsUsed + + def resultsCount: Option[Int] = document.resultsCount + + def timeTaken: Option[Double] = document.timeTaken + + private def HSpace: Component = Swing.HStrut(scaleInt(10)) + + + private def VSpace: Component = Swing.VStrut(scaleInt(10)) + + add(TopScrollablePane, BorderPanel.Position.Center) + + add(Toolbar, BorderPanel.Position.North) + +} diff --git a/scala/src/main/scala/quanto/gui/ClosableTabbedPane.scala b/scala/src/main/scala/quanto/gui/ClosableTabbedPane.scala index b8633456..4dc98438 100644 --- a/scala/src/main/scala/quanto/gui/ClosableTabbedPane.scala +++ b/scala/src/main/scala/quanto/gui/ClosableTabbedPane.scala @@ -1,13 +1,22 @@ package quanto.gui +import java.awt.event.{MouseAdapter, MouseEvent} + import scala.swing._ import javax.swing.border.EmptyBorder -import javax.swing.{ImageIcon, Icon} -import java.awt.{Color, BasicStroke, RenderingHints, Graphics} +import javax.swing.{Icon, ImageIcon, SwingUtilities} +import java.awt.{BasicStroke, Color, Graphics, Point, RenderingHints} + +import quanto.gui.QuantoDerive.popup +import quanto.util.{UserAlerts, UserOptions} + +import scala.swing.event._ class ClosablePage(title0: String, component0: Component, val closeAction: () => Boolean) -extends TabbedPane.Page(title0, component0) { - lazy val tabComponent : ClosablePage.TabComponent = { new ClosablePage.TabComponent(this) } + extends TabbedPane.Page(title0, component0) { + lazy val tabComponent: ClosablePage.TabComponent = { + new ClosablePage.TabComponent(this) + } override def title_=(t: String) { super.title_=(t) @@ -15,6 +24,9 @@ extends TabbedPane.Page(title0, component0) { } } +case class PageClosed(p: ClosablePage) extends DocumentEvent + + object ClosablePage { def apply(title: String, component: Component)(closeAction: => Boolean) = new ClosablePage(title, component, () => closeAction) @@ -28,38 +40,70 @@ object ClosablePage { case _: PythonDocumentPage => new ImageIcon(getClass.getResource("text-x-script.png"), "Python Script") case _ => new ImageIcon(getClass.getResource("text-x-generic.png"), "Document") } - - val titleLabel = new Label(p.title,icon,Alignment.Left) - titleLabel.border = new EmptyBorder(new Insets(5,5,5,10)) + val titleLabel = new Label(p.title, icon, Alignment.Left) + val closeButton = new Button(Action("") { +closePage() + }) + titleLabel.border = new EmptyBorder(new Insets(5, 5, 5, 10)) contents += titleLabel + def closePage(): Unit = { + if (p.closeAction()) { + publish(PageClosed(p)) + } else { + } + } + + def TabContextMenu(): PopupMenu = new PopupMenu { + menu => + + val CloseJustThis: Action = new Action("Close this tab") { + + menu.contents += new MenuItem(this) + + def apply(): Unit = closePage() + } + + + val CloseOtherTabs: Action = new Action("Close all other tabs") { + menu.contents += new MenuItem(this) + + def apply(): Unit = { + QuantoDerive.closeAllOrListOfDocuments(Some(QuantoDerive.MainDocumentTabs.documents.toList.filter(q => q != p))) + } + } + } + def title = titleLabel.text + def title_=(t: String) { titleLabel.text = t } - val closeButton = new Button(Action("") { - if (p.closeAction()) { - if (p.parent != null) p.parent.pages -= p - printf("got successful close") - } else { - printf("tried to close") - } - }) - closeButton.border = new EmptyBorder(new Insets(0,0,0,0)) + closeButton.border = new EmptyBorder(new Insets(0, 0, 0, 0)) closeButton.contentAreaFilled = false closeButton.rolloverEnabled = true closeButton.icon = new CloseIcon(false) closeButton.rolloverIcon = new CloseIcon(true) contents += closeButton + + //listenTo(titleLabel.mouse.clicks) + // Once we know how to dispatch all events further down the hierarchy we can use this + // Until then, any attempt to pick up mouse events blocks native behaviour + reactions += { + case e : MouseEvent => + //QuantoDerive.MainDocumentTabs.focus(p.asInstanceOf[DocumentPage]) + UserAlerts.alert("Should have popped up") + if(e.isPopupTrigger) popup(TabContextMenu(), Some(e)) + peer.dispatchEvent(SwingUtilities.convertMouseEvent(e.getComponent, e, peer)) + } } + } // draw a little X -private class CloseIcon(rollover : Boolean) extends Icon { - def getIconWidth = 9 - def getIconHeight = 9 - def paintIcon(c : java.awt.Component, g : Graphics, x : Int, y : Int) : Unit = { +private class CloseIcon(rollover: Boolean) extends Icon { + def paintIcon(c: java.awt.Component, g: Graphics, x: Int, y: Int): Unit = { val g2 = g.asInstanceOf[Graphics2D] g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON) val savedStroke = g2.getStroke @@ -70,17 +114,131 @@ private class CloseIcon(rollover : Boolean) extends Icon { g2.drawLine(getIconWidth - 3, 2, 2, getIconHeight - 3) g2.setStroke(savedStroke) } + + def getIconWidth : Int = UserOptions.scaleInt(9) + + def getIconHeight : Int = UserOptions.scaleInt(9) } -class ClosableTabbedPane extends TabbedPane { tabbedPane => - def +=(p: ClosablePage) { - pages += p - peer.setTabComponentAt(pages.length-1, p.tabComponent.peer) +// Handler for the TabbedPane object, to distance the swing from the Java +// TabbedPanes hold Java Components, but we want to interact with Documents +class DocumentTabs { + + val tabbedPane = new TabbedPane() + private var pageIndex: Map[DocumentPage, Int] = Map() + + def size : Int = pageIndex.size + + def component: TabbedPane = tabbedPane + + def +=(p: DocumentPage) { + if (!pageIndex.contains(p)) { + pages += p + tabbedPane.peer.setTabComponentAt(pages.length - 1, p.tabComponent.peer) + pageIndex += (p -> (pages.length - 1)) + } + reorderPages() + focus(p) + } + + def cycle(forward: Boolean = true) : Unit = { + if(pageIndex.nonEmpty) { + focus((selection.index + (if (forward) 1 else -1) + size) % size) + } + } + + def focus(index: Int): Unit = { + selection.index = index + currentFocus match { + case Some(p) => + p.document.focusOnNaturalComponent() + case None => + } + } + + def remove(page: DocumentPage): Unit = { + val removedIndex = pageIndex(page) + val preFocus = currentFocus + remove(removedIndex) + pageIndex -= page + reorderPages() + if (preFocus.nonEmpty && preFocus.get != page) { + focus(preFocus.get) + } else { + // Handled by peer + //selection.index -=1 + } + ensureFocusIsValid() + } + + private def reorderPages(): Unit = { + val indexPages = pageIndex.map(pi => (pi._2, pi._1)).toList.sortBy(_._1) + var newPageIndex: Map[DocumentPage, Int] = Map() + for (iip <- indexPages.zipWithIndex) { + newPageIndex += (iip._1._2 -> iip._2) + } + pageIndex = newPageIndex + + } + + private def remove(index: Int): Unit = { + pages.remove(index) + for (page <- pageIndex.keys) { + if (pageIndex(page) > index) { + pageIndex += (page -> (pageIndex(page) - 1)) + } + } + } + + private def ensureFocusIsValid(): Unit = { + if (pages.length > 0) { + // There is something to focus on + if (selection.index >= pages.length) { + selection.index = pages.length - 1 + } + if (selection.index < 0) { + selection.index = 0 + } + } + } + + def selection = tabbedPane.selection + + private def pages = tabbedPane.pages + + def focus(p: DocumentPage): Unit = { + try { + focus(pageIndex(p)) + } catch { + case e: Exception => selection.index = -1 + } } - def currentContent: Option[Component] = - if (selection.index == -1) None + def currentFocus: Option[DocumentPage] = { + pageIndex.find(kv => kv._2 == selection.index).map(_._1) + } + + def currentContent: Option[Component] = { + if (selection.index == -1 || selection.page == null || selection.page.content == null) None else Some(selection.page.content) + } + + // Add this reaction to the peer, then handle changes at this level. + tabbedPane.selection.reactions += { + case SelectionChanged(`tabbedPane`) => + publishChanged() + } + + def publishChanged(): Unit = { + currentFocus.foreach(_.document.focusOnNaturalComponent()) + } + + def documents: Iterable[DocumentPage] = pageIndex.keys + + def clear(): Unit = { + pages.clear() + pageIndex = Map() + } } diff --git a/scala/src/main/scala/quanto/gui/CodeTextArea.java b/scala/src/main/scala/quanto/gui/CodeTextArea.java new file mode 100644 index 00000000..075fa678 --- /dev/null +++ b/scala/src/main/scala/quanto/gui/CodeTextArea.java @@ -0,0 +1,20 @@ +package quanto.gui; + +import org.gjt.sp.jedit.*; +import org.gjt.sp.jedit.textarea.*; + +import java.util.Properties; + + +public class CodeTextArea extends StandaloneTextArea { + static private IPropertyManager propertyManager; + static { + final Properties props = new Properties(); + + propertyManager = new IPropertyManager() { + public String getProperty(String prop) { return props.getProperty(prop); } + }; + } + + public CodeTextArea() { super(propertyManager); } +} diff --git a/scala/src/main/scala/quanto/gui/ColourSwapDialog.scala b/scala/src/main/scala/quanto/gui/ColourSwapDialog.scala new file mode 100644 index 00000000..bceb2212 --- /dev/null +++ b/scala/src/main/scala/quanto/gui/ColourSwapDialog.scala @@ -0,0 +1,87 @@ +package quanto.gui + +import quanto.data.Theory.ValueType + +import scala.swing._ +import scala.swing.event.{ButtonClicked, Key, KeyPressed, ValueChanged} +import quanto.data._ +import quanto.util.{Globals, UserOptions} + +import scala.util.matching +import scala.util.matching.Regex +import quanto.util.UserOptions.{scale, scaleInt} + +class ColourSwapDialog(theory: Theory) extends Dialog { + modal = true + title = "Specify new colours" + val smallGap: Int = UserOptions.scaleInt(5) + + val types : List[String] = theory.vertexTypes.keys.toList + + val SwapButton = new Button("Swap") + val CancelButton = new Button("Cancel") + defaultButton = Some(SwapButton) + + def result: Map[String, String] = { + if(!cancelled) { + types.map(t => t -> comboFields(t).selection.item).toMap + } else { + types.map(t => t -> t).toMap + } + } + + var cancelled = false + + val comboFields : Map[String, ComboBox[String]] = types.map(t => { + t -> new ComboBox(types) + }).toMap + + val MainPanel = new BoxPanel(Orientation.Vertical) { + + contents += Swing.VStrut(smallGap) + + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(smallGap), + new Label("Specify which types are sent to which"), + Swing.HStrut(smallGap)) + } + contents += Swing.VStrut(smallGap) + + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(smallGap), + new GridPanel(types.length, 2) { + types.foreach(t => { + val label = new Label(t) + label.horizontalAlignment = Alignment.Left + contents += label + val tf = comboFields(t) + contents += tf + tf.selection.item = t + }) + }, + Swing.HStrut(smallGap)) + } + contents += Swing.VStrut(smallGap) + + contents += new BoxPanel(Orientation.Horizontal) { + contents += (SwapButton, Swing.HStrut(smallGap), CancelButton) + } + + contents += Swing.VStrut(2*smallGap) + + } + + + contents = MainPanel + + listenTo(SwapButton, CancelButton) + + reactions += { + case ButtonClicked(SwapButton) => + cancelled = false + close() + case ButtonClicked(CancelButton) => + cancelled = true + close() + } +} diff --git a/scala/src/main/scala/quanto/gui/DerivationController.scala b/scala/src/main/scala/quanto/gui/DerivationController.scala index 4750368a..7dd5e1e5 100644 --- a/scala/src/main/scala/quanto/gui/DerivationController.scala +++ b/scala/src/main/scala/quanto/gui/DerivationController.scala @@ -1,12 +1,15 @@ package quanto.gui import quanto.data._ + import scala.swing._ -import scala.swing.event.{SelectionChanged, ButtonClicked, Event} +import scala.swing.event.{ButtonClicked, Event, SelectionChanged} import quanto.gui.histview._ import java.awt.Color + import quanto.gui.graphview.Highlight import quanto.layout.DeriveLayout +import quanto.util.UserAlerts sealed abstract class DeriveState extends HistNode { def step: Option[DSName] } case class StepState(s: DSName) extends DeriveState { @@ -193,7 +196,6 @@ class DerivationController(panel: DerivationPanel) extends Publisher { replaceDerivation(derivation.deleteHead(s), "") panel.document.undoStack.commit() case StepState(s) => - // TODO: make deletion undo-able? if (Dialog.showConfirmation( title = "Confirm deletion", message = "This will delete " + derivation.allChildren(s).size + @@ -205,34 +207,40 @@ class DerivationController(panel: DerivationPanel) extends Publisher { panel.document.undoStack.start("Delete proof step") state = HeadState(parentOpt) replaceDerivation(derivation.deleteStep(s), "") + if(parentOpt.nonEmpty && !derivation.isHead(parentOpt.get)){ + val p = parentOpt.get + replaceDerivation(derivation.addHead(p), "") + state = HeadState(Some(p)) + } + panel.document.undoStack.commit() } case _ => // do nothing on root } case ButtonClicked(panel.ExportTheoremButton) => - panel.document.file match { - case Some(f) => - val rf = panel.project.rootFolder - var dname = f.getAbsolutePath - dname = if (dname.startsWith(rf) && dname.length > rf.length) dname.substring(rf.length + 1, dname.length) else dname - dname = if (dname.endsWith(".qderive")) dname.substring(0, dname.length - 8) else dname - - state.step.map { s => - val ruleDoc = new RuleDocument(panel, panel.theory) - ruleDoc.rule = new Rule(panel.document.root, derivation.steps(s).graph, Some(dname)) - ruleDoc.showSaveAsDialog(Some(panel.project.rootFolder)) - } + if(!panel.document.unsavedChanges && panel.document.file.nonEmpty) { + val f = panel.document.file.get + val rf = panel.project.rootFolder + var dname = panel.project.relativePath(f) + dname = if (dname.endsWith(".qderive")) dname.substring(0, dname.length - 8) else dname + + state.step.foreach { s => + val ruleDoc = new RuleDocument(panel, panel.theory) + val newRule = new Rule(panel.document.root, derivation.steps(s).graph, Some(dname)) + val page = new RuleDocumentPage(panel.theory) + page.document.asInstanceOf[RuleDocument].rule = newRule + QuantoDerive.addAndFocusPage(page) + } - case None => - Dialog.showMessage( - title = "Error", - message = "You must first save this derivation before exporting a theorem", - messageType = Dialog.Message.Error) + } else { + UserAlerts.errorBox("You must first save this derivation before exporting a theorem") } case SelectionChanged(_) => - if (panel.histView.selectedNode != Some(state)) + if (panel.histView.selectedNode != Some(state)) { panel.histView.selectedNode.map { st => state = st } + panel.histView.ensureIndexIsVisible(panel.histView.selectedIndex()) + } } } diff --git a/scala/src/main/scala/quanto/gui/DerivationDocument.scala b/scala/src/main/scala/quanto/gui/DerivationDocument.scala index ea50c2e8..65914fc1 100644 --- a/scala/src/main/scala/quanto/gui/DerivationDocument.scala +++ b/scala/src/main/scala/quanto/gui/DerivationDocument.scala @@ -11,7 +11,7 @@ class DerivationDocument(panel: DerivationPanel) extends Document { val fileExtension = "qderive" protected def parent = panel - private var storedDerivation: Derivation = Derivation(panel.theory, Graph(panel.theory)) + private var storedDerivation: Derivation = Derivation(Graph()) private var _derivation: Derivation = storedDerivation def derivation = _derivation def derivation_=(d: Derivation) = { @@ -49,8 +49,8 @@ class DerivationDocument(panel: DerivationPanel) extends Document { storedDerivation = _derivation } - protected def clearDocument() = { - _derivation = Derivation(panel.theory, root = Graph(panel.theory)) + protected def clearDocument() { + _derivation = Derivation(root = Graph(panel.theory)) } def root_=(g: Graph) { @@ -60,11 +60,12 @@ class DerivationDocument(panel: DerivationPanel) extends Document { publish(DocumentChanged(this)) } - def root = rootRef.graph + def root: Graph = rootRef.graph - override protected def exportDocument(f: File) = { + override protected def exportDocument(f: File) { + previousDir = f val old_state : DeriveState = parent.controller.state - var state_opt : Option[DeriveState] = derivation.parent(old_state) + var state_opt : Option[DeriveState] = Some(old_state) //derivation.parent(old_state) /* sequence of states from root leading to current state's parent */ var state_list : List[DeriveState] = List() @@ -77,29 +78,46 @@ class DerivationDocument(panel: DerivationPanel) extends Document { state_list = state_list.tail /* enclose everything into a quote environment */ - printToFile(f, false)(p => { + printToFile(f, append=false)(p => { p.println("\\begin{quote}\\raggedright") }) + val baseName = f.getName.substring(0, f.getName.lastIndexOf('.')) + val fullBaseName = f.getAbsolutePath.substring(0, f.getAbsolutePath.lastIndexOf('.')) + + var i = 0 + /* output derivation until current state parent */ state_list.foreach { s => parent.controller.state = s + val stepFileName = fullBaseName + "-" + i + ".tikz" + println(stepFileName) + val stepFile = new File(stepFileName) /* must compute view data to prevent race condition */ parent.LhsView.computeDisplayData() - parent.LhsView.exportView(f, true) + parent.LhsView.exportView(stepFile, append=false) + + printToFile(f)( p => { + p.print("\\tikzfig{" + baseName + "-" + i + "}") - printToFile(f, true)( p => { - p.println("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") - p.println("=") - p.println("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") + // print an equals sign, unless this is the last step + if (i < state_list.length - 1) { + p.println(" $=$") + } else { + p.println() + } }) + + i += 1 } + /* finally output current state and close quote environment */ - parent.controller.state = old_state - parent.LhsView.computeDisplayData() - parent.LhsView.exportView(f, true) +// parent.controller.state = old_state +// parent.LhsView.computeDisplayData() +// parent.LhsView.exportView(f, true) + printToFile(f, true)( p => { p.println("\\end{quote}") }) diff --git a/scala/src/main/scala/quanto/gui/DerivationPanel.scala b/scala/src/main/scala/quanto/gui/DerivationPanel.scala index a04bc7a4..e70a980f 100644 --- a/scala/src/main/scala/quanto/gui/DerivationPanel.scala +++ b/scala/src/main/scala/quanto/gui/DerivationPanel.scala @@ -8,10 +8,13 @@ import scala.swing.event._ import javax.swing.ImageIcon import quanto.gui.histview.HistView +case class RequestReRunSimproc() extends Event +case class SuggestRewriteRule(relativePath: RuleDesc) extends Event class DerivationPanel(val project: Project) extends BorderPanel with HasDocument + with Publisher { def theory = project.theory val document = new DerivationDocument(this) @@ -38,6 +41,8 @@ class DerivationPanel(val project: Project) val lhsController = new GraphEditController(LhsView, document.undoStack, readOnly = true) val rhsController = new GraphEditController(RhsView, document.undoStack, readOnly = true) + def ReRunLastSimproc() : Unit = publish(RequestReRunSimproc()) + val RewindButton = new Button() { icon = new ImageIcon(getClass.getResource("go-first.png"), "First step") tooltip = "First Step" @@ -174,42 +179,11 @@ class DerivationPanel(val project: Project) add(new SplitPane(Orientation.Horizontal, topPane, PreviewScrollPane), BorderPanel.Position.Center) } - val SimplifyBuiltInPane = new BorderPanel { - val Simprocs = new ListView[String] - val SimprocsScrollPane = new ScrollPane(Simprocs) - SimprocsScrollPane.preferredSize = new Dimension(400,200) - val Preview = new GraphView(theory, DummyRef) - val PreviewScrollPane = new ScrollPane(Preview) - Preview.zoom = 0.6 - - val SimplifyButton = new Button { - icon = new ImageIcon(GraphEditor.getClass.getResource("start.png")) - preferredSize = toolbarDim - tooltip = "Start" - } - - val StopButton = new Button { - icon = new ImageIcon(GraphEditor.getClass.getResource("stop.png")) - preferredSize = toolbarDim - tooltip = "Stop" - } - - - val topPane = new BorderPanel { - add(SimprocsScrollPane, BorderPanel.Position.Center) - add(new FlowPanel(FlowPanel.Alignment.Center)( - SimplifyButton - ), BorderPanel.Position.South) - } - - add(new SplitPane(Orientation.Horizontal, topPane, PreviewScrollPane), BorderPanel.Position.Center) - } val RhsRewritePane = new TabbedPane RhsRewritePane.pages += new TabbedPane.Page("Rewrite", ManualRewritePane) RhsRewritePane.pages += new TabbedPane.Page("Simplify", SimplifyPane) - RhsRewritePane.pages += new TabbedPane.Page("Simplify (built-in)", SimplifyBuiltInPane) val LhsLabel = new Label("(root)") val RhsLabel = new Label("(head)") @@ -246,6 +220,7 @@ class DerivationPanel(val project: Project) add(GraphViewPanel, BorderPanel.Position.Center) + listenTo(document) listenTo(LhsGraphPane, RhsGraphPane) listenTo(ManualRewritePane.PreviewScrollPane, SimplifyPane.PreviewScrollPane) @@ -262,6 +237,8 @@ class DerivationPanel(val project: Project) case UIElementResized(SimplifyPane.PreviewScrollPane) => SimplifyPane.Preview.resizeViewToFit() SimplifyPane.Preview.repaint() + case DocumentRequestingNaturalFocus(_) => + LhsView.requestFocus() } // construct the controller last, as it depends on the panel elements already being initialised @@ -269,6 +246,5 @@ class DerivationPanel(val project: Project) val rewriteController = new RewriteController(this) val simplifyController = new SimplifyController(this) - val simplifyBuiltInController = new SimplifyBuiltInController(this) // rewriteController.rules = Vector(RuleDesc("axioms/test1", inverse = false), RuleDesc("axioms/test2", inverse = true)) } diff --git a/scala/src/main/scala/quanto/gui/Document.scala b/scala/src/main/scala/quanto/gui/Document.scala index d1c53309..f59dc77a 100644 --- a/scala/src/main/scala/quanto/gui/Document.scala +++ b/scala/src/main/scala/quanto/gui/Document.scala @@ -1,17 +1,21 @@ package quanto.gui -import scala.swing.{Component, FileChooser, Dialog, Publisher} -import java.io.{FileNotFoundException, IOException, File} +import scala.swing.{Component, Dialog, FileChooser, Publisher} +import java.io.{File, FileNotFoundException, IOException} + import scala.swing.event.Event import quanto.data._ import quanto.util.json.JsonParseException import javax.swing.filechooser.FileNameExtensionFilter import java.util.prefs.Preferences +import javax.swing.JOptionPane + abstract class DocumentEvent extends Event case class DocumentChanged(sender: Document) extends DocumentEvent case class DocumentSaved(sender: Document) extends DocumentEvent case class DocumentReplaced(sender: Document) extends DocumentEvent +case class DocumentRequestingNaturalFocus(sender: Document) extends DocumentEvent /** * For an object connected to a single file. Provides an undo stack, tracks changes, and gives @@ -87,7 +91,12 @@ abstract class Document extends Publisher { } def titleDescription : String = { - file.map(f => f.getName).getOrElse("Untitled").replaceAll("\\.[^.]*$", "") + val base = file.map(f => f.getName).getOrElse("Untitled").replaceAll("\\.[^.]*$", "") + if (unsavedChanges) { + base + "*" + } else { + base + } // val name : String = file.map(f => f.getName).getOrElse("Untitled") // // If there is a description then use in instead of the file extension // val nameDescription = if (description.length > 0) name.replaceAll("\\.[^.]*$", "") + " " + description else name @@ -121,36 +130,53 @@ abstract class Document extends Publisher { * @return true if the document can be closed, false otherwise * (as per user decission) */ - def promptUnsaved() = { + def promptUnsaved(): Boolean = { if (unsavedChanges) { - val choice = Dialog.showOptions( - title = "Unsaved changes", - message = "Do you want to save your changes or discard them?", - entries = "Save" :: "Discard" :: "Cancel" :: Nil, - initial = 0 - ) +// val choice = Dialog.showOptions( +// title = "Unsaved changes", +// message = "Do you want to save your changes or discard them?", +// entries = "Save" :: "Discard" :: "Cancel" :: Nil, +// initial = 0 +// ) + + val choice = JOptionPane.showOptionDialog(null, + "Do you want to save your changes or discard them?", + "Unsaved changes in "+titleDescription, + JOptionPane.DEFAULT_OPTION, + JOptionPane.WARNING_MESSAGE, null, + List("Save", "Discard", "Cancel").toArray, + "Save") // scala swing dialogs implementation is dumb, here's what I found : // Result(0) = Save, Result(1) = Discard, Result(2) = Cancel - if (choice == Dialog.Result(0)) trySave() - else choice == Dialog.Result(1) + if (choice == 0) trySave() + else choice == 1 } else true } - def promptExists(f: File) = { + def promptExists(f: File): Boolean = { if (f.exists()) { - Dialog.showConfirmation( - title = "File exists", - message = "File exists, do you wish to overwrite?") == Dialog.Result.Yes + JOptionPane.showConfirmDialog(null, + "File exists, do you wish to overwrite?", + "File exists", + JOptionPane.YES_NO_OPTION) == JOptionPane.YES_OPTION +// Dialog.showConfirmation( +// title = "File exists", +// message = "File exists, do you wish to overwrite?") == Dialog.Result.Yes } else true } + def errorDialog(action: String, reason: String) { - Dialog.showMessage( - title = "Error", - message = "Cannot " + action + " file (" + reason + ")", - messageType = Dialog.Message.Error) +// Dialog.showMessage( +// title = "Error", +// message = "Cannot " + action + " file (" + reason + ")", +// messageType = Dialog.Message.Error) + JOptionPane.showMessageDialog(null, + "Cannot " + action + " file (" + reason + ")", + "Error", + JOptionPane.ERROR_MESSAGE) } def previousDir_=(f: File) { @@ -227,6 +253,12 @@ abstract class Document extends Publisher { case UndoRegistered(_) => publish(DocumentChanged(this)) } + + // Publishes a request for the view to focus on whatever is "correct" for the document. + // E.g. give focus to the text area if you open a .py file + def focusOnNaturalComponent() : Unit = { + publish(DocumentRequestingNaturalFocus(this)) + } } trait HasDocument { diff --git a/scala/src/main/scala/quanto/gui/DocumentPage.scala b/scala/src/main/scala/quanto/gui/DocumentPage.scala index 7e41ed8c..f31c91cb 100644 --- a/scala/src/main/scala/quanto/gui/DocumentPage.scala +++ b/scala/src/main/scala/quanto/gui/DocumentPage.scala @@ -7,7 +7,9 @@ abstract class DocumentPage(component0: Component with HasDocument) extends ClosablePage( component0.document.titleDescription, component0, - closeAction = () => { component0.document.promptUnsaved() } ) + closeAction = () => { + component0.document.promptUnsaved()} +) with Reactor { val document = component0.document @@ -49,3 +51,18 @@ class DerivationDocumentPage(val project: Project) extends DocumentPage(new DerivationPanel(project)) { val documentType = "Derivation" } + +class TheoryPage + extends DocumentPage(new TheoryEditPanel) { + val documentType = "Theory Editor" +} + +class BatchDerivationPage + extends DocumentPage(new BatchDerivationCreatorPanel) { + val documentType = "Batch Derivation" +} + +class BatchDerivationResultsPage + extends DocumentPage(new BatchDerivationResultsPanel) { + val documentType = "Batch Derivation Results" +} \ No newline at end of file diff --git a/scala/src/main/scala/quanto/gui/FileTree.scala b/scala/src/main/scala/quanto/gui/FileTree.scala index 2f6f4e91..f555e673 100644 --- a/scala/src/main/scala/quanto/gui/FileTree.scala +++ b/scala/src/main/scala/quanto/gui/FileTree.scala @@ -15,6 +15,7 @@ import scala.swing.event.Event abstract class FileTreeEvent extends Event case class FileOpened(file: File) extends FileTreeEvent +case class FileContextRequested(file: File, e: Option[MouseEvent]) extends FileTreeEvent class FileTree extends BorderPanel { val fileTreeModel = new FileTreeModel @@ -41,12 +42,10 @@ class FileTree extends BorderPanel { } case _ => if (e.isPopupTrigger) { - // TODO: Make actual popup menu - // cf https://stackoverflow.com/questions/938753/scala-popup-menu - // Needs access to a swing component to have as parent + val row = fileTree.getClosestRowForLocation(e.getX, e.getY) + fileTree.setSelectionRow(row) fileTree.getLastSelectedPathComponent match { - case FileNode(file) => Desktop.getDesktop.browse(file.toURI) - case _ => + case FileNode(file) => publish(FileContextRequested(file, Some(e))) } } } diff --git a/scala/src/main/scala/quanto/gui/FilteredList.scala b/scala/src/main/scala/quanto/gui/FilteredList.scala new file mode 100644 index 00000000..a748fab3 --- /dev/null +++ b/scala/src/main/scala/quanto/gui/FilteredList.scala @@ -0,0 +1,55 @@ +package quanto.gui + +import quanto.util.UserOptions.scaleInt + +import scala.swing.event.ValueChanged +import scala.swing.{BoxPanel, Component, Dimension, Label, ListView, Orientation, ScrollPane, Swing, TextField} + +// Create a list of strings to select from, with a regex filter box at the top +// Regex: A text box that contains the regex to filter by +// listItems: The list to filter on +// ListComponent: The UI Component that the user selects elements from +class FilteredList(val options: List[String], + baseWidth: Int = 400, + baseHeight: Int = 200) extends BoxPanel(Orientation.Vertical) { + val Regex = new TextField + val listItems: List[String] = options.sorted + val ListComponent: ListView[String] = new ListView[String](listItems) + private val ScrollContainer = new ScrollPane(ListComponent) + ScrollContainer.maximumSize = new Dimension(scaleInt(baseWidth), scaleInt(baseHeight)) + + private def VSpace: Component = Swing.VStrut(scaleInt(10)) + + private def HSpace: Component = Swing.HStrut(scaleInt(10)) + + contents += VSpace + val FilterPanel : Component = new BoxPanel(Orientation.Horizontal) { + contents += (HSpace, new Label("Filter:"), HSpace, Regex, HSpace) + maximumSize = new Dimension(scaleInt(baseWidth), scaleInt(20)) + } + + contents += FilterPanel + contents += VSpace + contents += new BoxPanel(Orientation.Horizontal) { + contents += (HSpace, ScrollContainer, HSpace) + } + + listenTo(Regex) + reactions += { + case ValueChanged(Regex) => + try { + val filtered = listItems.filter( + s => s.matches("(?i).*" + Regex.text + ".*")) + ListComponent.listData = filtered + if(filtered.nonEmpty) { + ListComponent.peer.setSelectionInterval(0, filtered.length - 1) + } + if(Regex.text.isEmpty){ + ListComponent.peer.clearSelection() + } + } catch { + case e: Exception => + //Exceptions here are thrown by inelligable regex from the user + } + } +} diff --git a/scala/src/main/scala/quanto/gui/GraphDocument.scala b/scala/src/main/scala/quanto/gui/GraphDocument.scala index 6ebce8d3..1d20820f 100644 --- a/scala/src/main/scala/quanto/gui/GraphDocument.scala +++ b/scala/src/main/scala/quanto/gui/GraphDocument.scala @@ -23,6 +23,13 @@ class GraphDocument(val parent: Component, theory: Theory) extends Document with // resetDocumentInfo() // } + def replaceJson(json: Json) { + graph = Graph.fromJson(json, theory) + publish(GraphReplaced(this, clearSelection = true)) + publish(DocumentReplaced(this)) + publish(DocumentChanged(this)) + } + protected def loadDocument(f: File) { val json = Json.parse(f) storedGraph = Graph.fromJson(json, theory) diff --git a/scala/src/main/scala/quanto/gui/GraphEditController.scala b/scala/src/main/scala/quanto/gui/GraphEditController.scala index ae00b3c3..772cfb39 100644 --- a/scala/src/main/scala/quanto/gui/GraphEditController.scala +++ b/scala/src/main/scala/quanto/gui/GraphEditController.scala @@ -1,19 +1,21 @@ package quanto.gui -import graphview._ -import swing._ -import swing.event._ -import Key.Modifier +import java.awt.Toolkit +import java.awt.datatransfer._ +import java.awt.event.{ActionEvent, ActionListener} +import java.util.Calendar + +import quanto.data.Names._ import quanto.data._ -import Names._ +import quanto.gui.graphview.{BBoxOverlay, EdgeOverlay, _} import quanto.layout.ForceLayout -import quanto.util.json._ import quanto.layout.constraint._ -import java.awt.event.{ActionEvent, ActionListener} -import java.awt.datatransfer._ -import java.awt.Toolkit -import quanto.gui.graphview.{EdgeOverlay,BBoxOverlay} -import quanto.util.Globals +import quanto.util.json._ +import quanto.util.{Globals, UserOptions} + +import scala.swing._ +import scala.swing.event.Key.Modifier +import scala.swing.event._ case class VertexSelectionChanged(graph: Graph, selectedVerts: Set[VName]) extends GraphEvent @@ -67,6 +69,7 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B val layoutTimer = new javax.swing.Timer(5, new ActionListener { def actionPerformed(e: ActionEvent) { if (qLayout.graph != null) { + view.requestFocusInWindow() qLayout.step() qLayout.updateGraph() graph = qLayout.graph @@ -80,6 +83,7 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B val layoutTimer1 = new javax.swing.Timer(5, new ActionListener { def actionPerformed(e: ActionEvent) { if (q1Layout.graph != null) { + view.requestFocusInWindow() q1Layout.step() q1Layout.updateGraph() graph = q1Layout.graph @@ -159,6 +163,21 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B undoStack.register("Add Edge") { deleteEdge(e) } } + // Uses the current state of the controller to add an edge + private def addEdgeFromController(v1: VName, v2: VName): Unit = { + controlsOpt.foreach { c => + val edgeType = theory.edgeTypes(c.EdgeTypeSelect.selection.item).defaultData + val theoryJSON : JsonObject = JsonObject( + "data" -> edgeType + ) + + val eData = if (c.EdgeDirected.selected) DirEdge.fromJson(theoryJSON, theory) + else UndirEdge.fromJson(theoryJSON, theory) + + addEdge(graph.edges.fresh, eData, (v1, v2)) + } + } + private def deleteEdge(e: EName) { val d = graph.edata(e) val vs = (graph.source(e), graph.target(e)) @@ -367,6 +386,14 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B view.repaint() } + def selectAll() : Unit = { + selectedBBoxes = graph.bboxes + selectedVerts = graph.verts + selectedEdges = graph.edges + view.publish(VertexSelectionChanged(graph, graph.verts)) + view.repaint() + } + private def roundIfSnapped(d : Double) = { if (keepSnapped) math.rint(d / 0.25) * 0.25 else d // rounds to .25 } @@ -384,6 +411,84 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B replaceGraph(newGraph, "Layout Graph") } + var rDown = false + + def startRelaxGraph(expandNodes : Boolean) { + view.requestFocusInWindow() + val layout = if (!expandNodes) q1Layout else qLayout + if (!rDown) { + rDown = true + layout.initialize(graph, randomCoords = false) + layout.clearLockedVertices() + if (!selectedVerts.isEmpty) { + graph.verts.foreach { v => if (!selectedVerts.contains(v)) layout.lockVertex(v) } + } + + undoStack.start("Relax layout") + replaceGraph(graph, "") + if (!expandNodes) layoutTimer1.start() else layoutTimer.start() + } + } + + def endRelaxGraph() { + if (rDown) { + rDown = false + layoutTimer.stop() + layoutTimer1.stop() + + replaceGraph(graph, "") + undoStack.commit() + } + } + + def cycleVertexType(vertex: VName, shift: Int = 1, includeWire: Boolean = false): Unit = { + + val currentData = graph.vdata(vertex) + val options = theory.vertexTypes.keys.toSeq :+ "" + val current = options.indexOf(currentData.typ) + + val next = options((current + shift + options.length) % options.length) + val newTyp = if (next == "" & !includeWire) + options((current + shift + 1 + options.length) % options.length) + else next + + val coords = (currentData.coord._1, currentData.coord._2) + // Using replaceGraph because of data lost on cycling + replaceGraph(graph.updateVData(vertex)(d => { + newTyp match { + case "" => + WireV.apply(coords) + case _ => + NodeV(theory.vertexTypes(newTyp).defaultData, + annotation = d.annotation, + theory = theory).withCoord(coords) + } + } + ), "Cycled vertex") + + + } + + + def normaliseGraph(): Unit = { + replaceGraph(graph.normalise.coerceWiresAndBoundaries, "Normalised graph") + } + + def minimiseGraph(): Unit = { + view.requestFocusInWindow() + replaceGraph(graph.minimise.coerceWiresAndBoundaries, "Minimised graph") + } + + def focusOnGraph(): Unit = { + view.requestFocusInWindow() + view.focusOnGraph() + } + + def vertexAt(point: Point) : Option[VName] = view.vertexDisplay find { + _._2.pointHit(point) + } map { + _._1 + } view.listenTo(view.mouse.clicks, view.mouse.moves) view.reactions += { case MousePressed(_, pt, modifiers, clicks, _) => @@ -475,7 +580,25 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B mouseState = box view.selectionBox = Some(box.rect) } - + case FreehandTool(_, _) => + undoStack.start("Freehand drag") + var vname = graph.verts.freshWithSuggestion(VName("v0")) + vertexAt(pt) match { + case Some(vertex) => // Already a vertex here, start a path + vname = vertex + mouseState = FreehandTool(Some(vname), startedWithNew = false) + case None => // No vertex here! Create a wire node if enough time and distance have passed + val coord = view.trans fromScreen(pt.getX, pt.getY) + controlsOpt.foreach { c => + val vertexData = WireV(theory = theory) + addVertex(vname, vertexData.withCoord(coord)) + } + mouseState = FreehandTool(Some(vname), startedWithNew = true) + } + view.edgeOverlay = Some(EdgeOverlay(pt, src = vname, tgt = Some(vname))) + view.repaint() + case RequestMinimiseGraph() => + case RequestFocusOnGraph() => case state => throw new InvalidMouseStateException("MousePressed", state) } @@ -486,6 +609,17 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B case AddBoundaryTool() => // do nothing case AddEdgeTool() => // do nothing case AddBangBoxTool() => // do nothing + case FreehandTool(maybeVName, _) => + val start = maybeVName.get + vertexAt(pt) match { + case Some(vertex) => + view.edgeOverlay = Some(EdgeOverlay(pt, src = start, tgt = Some(vertex))) + case None => + view.edgeOverlay = Some(EdgeOverlay(pt, src = start, tgt = None)) + } + view.repaint() + case RequestMinimiseGraph() => + case RequestFocusOnGraph() => case SelectionBox(start,_) => val box = SelectionBox(start, pt) mouseState = box @@ -495,7 +629,7 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B val box = BangSelectionBox(start, pt) mouseState = box view.selectionBox = Some(box.rect) - view.repaint() + view.repaint() case DragVertex(start, prev) => shiftVertsNoRegister(selectedVerts, start, prev, pt) view.repaint() @@ -510,13 +644,38 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B else None view.bboxOverlay = Some(BBoxOverlay(pt, startBB, vertexHit, bboxHit)) view.repaint() + case state => throw new InvalidMouseStateException("MouseReleased", state) } case MouseReleased(_, pt, modifiers, _, _) => mouseState match { - case SelectTool() => // do nothing - case AddEdgeTool() => // do nothing - case AddBangBoxTool () => // do nothing + case SelectTool() => // do nothing + case AddEdgeTool() => // do nothing + case AddBangBoxTool() => // do nothing + case FreehandTool(maybeVName, startedWithNew) => + view.edgeOverlay = None + vertexAt(pt) match { + case Some(vertex) => // Vertex here + + if (maybeVName.get == vertex) { + // dragged to itself + if (!startedWithNew) cycleVertexType(vertex) + } else { + // dragged to another vertex + addEdgeFromController(maybeVName.get, vertex) + } + case None => + val coord = view.trans fromScreen(pt.getX, pt.getY) + val vname = graph.verts.freshWithSuggestion(VName("v0")) + controlsOpt.foreach { c => + val vertexData = WireV(theory = theory) + addVertex(vname, vertexData.withCoord(coord)) + } // dragged to another vertex + addEdgeFromController(maybeVName.get, vname) + } + undoStack.commit() + mouseState = FreehandTool(None, startedWithNew = false) + case SelectionBox(start,_) => val oldSelectedVerts = selectedVerts @@ -615,11 +774,7 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B case DragEdge(startV) => val vertexHit = view.vertexDisplay find { _._2.pointHit(pt) } map { _._1 } vertexHit.map { endV => - controlsOpt.map { c => - val defaultData = if (c.EdgeDirected.selected) DirEdge.fromJson(theory.defaultEdgeData, theory) - else UndirEdge.fromJson(theory.defaultEdgeData, theory) - addEdge(graph.edges.fresh, defaultData, (startV, endV)) - } + addEdgeFromController(startV, endV) } mouseState = AddEdgeTool() view.edgeOverlay = None @@ -647,6 +802,8 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B view.bboxOverlay = None view.repaint() + case RequestMinimiseGraph() => + case RequestFocusOnGraph() => case state => throw new InvalidMouseStateException("MouseReleased", state) } @@ -657,7 +814,7 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B view.listenTo(view.keys) view.listenTo(view.mouse.wheel) - var rDown = false + val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask @@ -671,27 +828,12 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B undoStack.commit() view.repaint() } + case KeyPressed(_, Key.A, m, _) => + if (Modifier.Control == (m & Modifier.Control)) selectAll() case KeyPressed(_, Key.R, m, _) => - if (!rDown) { - val layout = if ((m & Modifier.Shift) == Modifier.Shift) q1Layout else qLayout - rDown = true - layout.initialize(graph, randomCoords = false) - layout.clearLockedVertices() - if (!selectedVerts.isEmpty) { - graph.verts.foreach { v => if (!selectedVerts.contains(v)) layout.lockVertex(v) } - } - - undoStack.start("Relax layout") - replaceGraph(graph, "") - if ((m & Modifier.Shift) == Modifier.Shift) layoutTimer1.start() else layoutTimer.start() - } + startRelaxGraph((m & Modifier.Shift) != Modifier.Shift) case KeyReleased(_, Key.R, _, _) => - rDown = false - layoutTimer.stop() - layoutTimer1.stop() - - replaceGraph(graph, "") - undoStack.commit() + endRelaxGraph() case KeyReleased(_, Key.G, _, _) => snapToGrid() //replaceGraph(graph, "") @@ -705,8 +847,14 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B if ((modifiers & Globals.CommandDownMask) == Globals.CommandDownMask) { cutSubgraph() } case KeyPressed(_, Key.V, modifiers, _) => if (modifiers == 0) { + mouseState = AddVertexTool() - controlsOpt.map { c => c.setMouseState(mouseState) } + controlsOpt.foreach { c => + if (c.GraphToolGroup.selected.contains(c.AddVertexButton)) { + c.VertexTypeSelect.selection.index = (c.VertexTypeSelect.selection.index + 1) % (theory.vertexTypes.size + 1) + } + c.setMouseState(mouseState) + } } else if ((modifiers & Globals.CommandDownMask) == Globals.CommandDownMask) { pasteSubgraph() } case KeyPressed(_, Key.S, modifiers, _) => @@ -717,7 +865,12 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B case KeyPressed(_, Key.E, modifiers, _) => if (modifiers == 0) { mouseState = AddEdgeTool() - controlsOpt.map { c => c.setMouseState(mouseState) } + controlsOpt.foreach { c => + if (c.GraphToolGroup.selected.contains(c.AddEdgeButton)) { + c.EdgeTypeSelect.selection.index = (c.EdgeTypeSelect.selection.index + 1) % theory.edgeTypes.size + } + c.setMouseState(mouseState) + } } case KeyPressed(_, Key.B, modifiers, _) => if (modifiers == 0) { @@ -738,15 +891,45 @@ class GraphEditController(view: GraphView, undoStack: UndoStack, val readOnly: B if ((modifiers & Globals.CommandDownMask) == Globals.CommandDownMask) { snapToGrid() } - case KeyPressed(_, Key.F, _, _) => - undoStack.start("Flip edge direction") - selectedEdges.foreach { flipEdge } - undoStack.commit() - view.repaint() - case KeyPressed(_, Key.D, _, _) => - undoStack.start("Toggle edge directed") - selectedEdges.foreach { toggleDirected } - undoStack.commit() - view.repaint() + case KeyPressed(_, Key.M, _, _) => + minimiseGraph() + case KeyPressed(_, Key.F, modifiers, _) => + if ((modifiers & Modifier.Shift) != Modifier.Shift) { + mouseState = FreehandTool(None, startedWithNew = false) + controlsOpt.foreach { c => c.setMouseState(mouseState) } + } else { + undoStack.start("Flip edge direction") + selectedEdges.foreach { + flipEdge + } + undoStack.commit() + view.repaint() + } + case KeyPressed(_, Key.D, modifiers, _) => + // DOn't do this if pressing ctrl: ctrl+d is "Start new derivation" + if((modifiers & Modifier.Control) != Modifier.Control) { + if (selectedEdges.nonEmpty) { + if((modifiers & Modifier.Shift) != Modifier.Shift){ + + undoStack.start("Toggle edge directed") + selectedEdges.foreach { + toggleDirected + } + }else{ + undoStack.start("Flip edge direction") + selectedEdges.foreach { + flipEdge + } + } + undoStack.commit() + view.repaint() + } else { + controlsOpt.foreach { c => + if (c.GraphToolGroup.selected.contains(c.AddEdgeButton)) { + c.EdgeDirected.selected = true + } + } + } + } } } diff --git a/scala/src/main/scala/quanto/gui/GraphEditPanel.scala b/scala/src/main/scala/quanto/gui/GraphEditPanel.scala index 3ce570c2..071dbf38 100644 --- a/scala/src/main/scala/quanto/gui/GraphEditPanel.scala +++ b/scala/src/main/scala/quanto/gui/GraphEditPanel.scala @@ -2,9 +2,12 @@ package quanto.gui import graphview.GraphView import quanto.data._ + import swing._ import swing.event._ import javax.swing.ImageIcon + +import quanto.util.UserAlerts import quanto.util.swing.ToolBar case class MouseStateChanged(m : MouseState) extends Event @@ -51,7 +54,7 @@ class GraphEditControls(theory: Theory) extends Publisher { } val AddEdgeButton = new ToggleButton() with ToolButton { - icon = new ImageIcon(GraphEditor.getClass.getResource("draw-path.png"), "Add Edge") + icon = new ImageIcon(GraphEditor.getClass.getResource("add-edge.png"), "Add Edge") tool = AddEdgeTool() tooltip = "Add Edge (E)" } @@ -62,16 +65,53 @@ class GraphEditControls(theory: Theory) extends Publisher { tooltip = "Add Bang Box (B)" } + val RelaxButton = new ToggleButton() with ToolButton { + icon = new ImageIcon(GraphEditor.getClass.getResource("expand.png"), "Relax graph") + tool = RelaxToolDown() + tooltip = "Relax graph (R/shift-R)" + } + + + val FocusGraphButton = new ToggleButton() with ToolButton { + icon = new ImageIcon(GraphEditor.getClass.getResource("focus.png"), "Resize Viewport") + tool = RequestFocusOnGraph() + tooltip = "Focus the viewport on the whole graph" + } + + + + val FreehandButton = new ToggleButton() with ToolButton { + icon = new ImageIcon(GraphEditor.getClass.getResource("draw-path.png"), "Freehand drawing") + tool = FreehandTool(None, startedWithNew = false) + tooltip = "Freehand draw (F)" + } + + + val MinimiseButton = new ToggleButton() with ToolButton { + icon = new ImageIcon(GraphEditor.getClass.getResource("normalise.png"), "Minimise") + tooltip = "Straighten edges, and convert leaves to boundaries (M)" + tool = RequestMinimiseGraph() + } + val GraphToolGroup = new ButtonGroup(SelectButton, AddVertexButton, AddBoundaryButton, AddEdgeButton, - AddBangBoxButton) + AddBangBoxButton, + FreehandButton + ) def setMouseState(m : MouseState) { - val previousTool = GraphToolGroup.selected + val previousToolButton = GraphToolGroup.selected publish(MouseStateChanged(m)) m match { + case FreehandTool(_,_) => + VertexTypeLabel.enabled = false + VertexTypeSelect.enabled = false + EdgeTypeLabel.enabled = false + EdgeTypeSelect.enabled = false + EdgeDirected.enabled = false + GraphToolGroup.select(FreehandButton) case SelectTool() => VertexTypeLabel.enabled = false VertexTypeSelect.enabled = false @@ -80,7 +120,7 @@ class GraphEditControls(theory: Theory) extends Publisher { EdgeDirected.enabled = false GraphToolGroup.select(SelectButton) case AddVertexTool() => - if(previousTool.nonEmpty && previousTool.get == AddVertexButton){ + if(previousToolButton.nonEmpty && previousToolButton.get == AddVertexButton){ //VertexTypeSelect.selection.index = (VertexTypeSelect.selection.index + 1) % vertexOptions.size } VertexTypeLabel.enabled = true @@ -90,7 +130,7 @@ class GraphEditControls(theory: Theory) extends Publisher { EdgeDirected.enabled = false GraphToolGroup.select(AddVertexButton) case AddEdgeTool() => - if(previousTool.nonEmpty && previousTool.get == AddEdgeButton){ + if(previousToolButton.nonEmpty && previousToolButton.get == AddEdgeButton){ //EdgeTypeSelect.selection.index = (EdgeTypeSelect.selection.index + 1) % edgeOptions.size } VertexTypeLabel.enabled = false @@ -123,9 +163,38 @@ class GraphEditControls(theory: Theory) extends Publisher { setMouseState(t.tool) } + // These are all different, because they need to not take focus away from the view + listenTo(RelaxButton.mouse.clicks, MinimiseButton.mouse.clicks, FocusGraphButton.mouse.clicks) + + reactions += { + case MousePressed(RelaxButton,_,_,_,_) => + RelaxButton.selected = false + publish(MouseStateChanged(RelaxToolDown())) + case MouseReleased(RelaxButton,_,_,_,_) => + RelaxButton.selected= false + publish(MouseStateChanged(RelaxToolUp())) + case MousePressed(FocusGraphButton,_,_,_,_) => + FocusGraphButton.selected = false + publish(MouseStateChanged(RequestFocusOnGraph())) + case MouseReleased(FocusGraphButton,_,_,_,_) => + FocusGraphButton.selected= false + publish(MouseStateChanged(RequestFocusOnGraph())) + case MousePressed(MinimiseButton,_,_,_,_) => + MinimiseButton.selected = false + publish(MouseStateChanged(RequestMinimiseGraph())) + case MouseReleased(MinimiseButton,_,_,_,_) => + MinimiseButton.selected= false + publish(MouseStateChanged(RequestMinimiseGraph())) + } + val MainToolBar = new ToolBar { - contents += (SelectButton, AddVertexButton, AddBoundaryButton, AddEdgeButton, AddBangBoxButton) + contents += (SelectButton, AddVertexButton, AddBoundaryButton, AddEdgeButton, AddBangBoxButton, FreehandButton) } + MainToolBar.peer.addSeparator() + MainToolBar.contents += FocusGraphButton + MainToolBar.contents += RelaxButton + MainToolBar.contents += MinimiseButton + } @@ -168,7 +237,15 @@ with HasDocument case UIElementResized(GraphViewScrollPane) => graphView.resizeViewToFit() graphView.repaint() + case MouseStateChanged(RelaxToolDown()) => graphEditController.startRelaxGraph(true) + case MouseStateChanged(RelaxToolUp()) => graphEditController.endRelaxGraph() + case MouseStateChanged(RequestMinimiseGraph()) => graphEditController.minimiseGraph() + case MouseStateChanged(RequestFocusOnGraph()) => graphEditController.focusOnGraph() case MouseStateChanged(m) => + if (graphEditController.rDown) graphEditController.endRelaxGraph() graphEditController.mouseState = m - } + case DocumentRequestingNaturalFocus(d) => + graphView.requestFocus() + } + } diff --git a/scala/src/main/scala/quanto/gui/JsonConsole.scala b/scala/src/main/scala/quanto/gui/JsonConsole.scala index d48d651a..2086949a 100644 --- a/scala/src/main/scala/quanto/gui/JsonConsole.scala +++ b/scala/src/main/scala/quanto/gui/JsonConsole.scala @@ -94,7 +94,7 @@ object JsonConsole extends SimpleSwingApplication { border = new LineBorder(Color.WHITE, 1) } - def componentFor(list: ListView[_], isSelected: Boolean, + override def componentFor(list: ListView[_ <: CoreOutputItem], isSelected: Boolean, focused: Boolean, a: CoreOutputItem, index: Int): Component = { @@ -114,6 +114,7 @@ object JsonConsole extends SimpleSwingApplication { panel } + } } diff --git a/scala/src/main/scala/quanto/gui/MouseState.scala b/scala/src/main/scala/quanto/gui/MouseState.scala index 6748cdab..35d53fe9 100644 --- a/scala/src/main/scala/quanto/gui/MouseState.scala +++ b/scala/src/main/scala/quanto/gui/MouseState.scala @@ -61,4 +61,14 @@ case class DragEdge(startVertex: VName) extends MouseState case class AddBangBoxTool() extends MouseState /** A nesting edge is being dragged from the bang box corner */ -case class DragBangBoxNesting(startBBox: BBName) extends MouseState \ No newline at end of file +case class DragBangBoxNesting(startBBox: BBName) extends MouseState + +/** The relax tool is pressed down or released again */ +case class RelaxToolDown() extends MouseState +case class RelaxToolUp() extends MouseState + +// The freehand tool is being used +case class FreehandTool(start: Option[VName], startedWithNew: Boolean) extends MouseState + +case class RequestMinimiseGraph() extends MouseState +case class RequestFocusOnGraph() extends MouseState \ No newline at end of file diff --git a/scala/src/main/scala/quanto/gui/NewProjectDialog.scala b/scala/src/main/scala/quanto/gui/NewProjectDialog.scala index d947f634..c762def1 100644 --- a/scala/src/main/scala/quanto/gui/NewProjectDialog.scala +++ b/scala/src/main/scala/quanto/gui/NewProjectDialog.scala @@ -15,7 +15,12 @@ class NewProjectDialog extends Dialog { val NameField = new TextField() val ProjectLocationField = new TextField(System.getProperty("user.home")) val BrowseProjectButton = new Button("...") - val TheoryChoiceDropdown = new ComboBox(Seq[String]("ZX", "ZW", "From existing project", "From .qtheory file")) + val TheoryChoiceDropdown = new ComboBox(Seq[String]( + "ZX", + "ZW", + "From existing project", + "From .qtheory file", + "plain")) val TheoryLocationField = new TextField("") val BrowseTheoryButton = new Button("...") val theoryName = new TextField("") @@ -102,7 +107,7 @@ class NewProjectDialog extends Dialog { fileChoiceFilter = filter } - disableFileChoosers("red_green") + disableFileChoosers("ZX") reactions += { case ButtonClicked(CreateButton) => @@ -111,11 +116,11 @@ class NewProjectDialog extends Dialog { val path = ProjectLocationField.text val folder = new File(path + "/" + name) if (name.isEmpty) { - UserAlerts.errorbox("Please enter a name for your project.") + UserAlerts.errorBox("Please enter a name for your project.") } else if (folder.exists()) { - UserAlerts.errorbox("That folder is already in use.") + UserAlerts.errorBox("That folder is already in use.") } else if (theory.isEmpty) { - UserAlerts.errorbox("Please choose a theory.") + UserAlerts.errorBox("Please choose a theory.") } else { result = Some((theory, name, path)) close() @@ -143,9 +148,11 @@ class NewProjectDialog extends Dialog { case SelectionChanged(TheoryChoiceDropdown) => TheoryChoiceDropdown.selection.item match { case "ZX" => - disableFileChoosers("red_green") + disableFileChoosers("ZX") case "ZW" => - disableFileChoosers("black_white") + disableFileChoosers("ZW") + case "plain" => + disableFileChoosers("plain") case "From existing project" => enableFileChoosers("Project files", "qproject") case "From .qtheory file" => diff --git a/scala/src/main/scala/quanto/gui/PythonEditPanel.scala b/scala/src/main/scala/quanto/gui/PythonEditPanel.scala index 3208443c..3d4bb3b4 100644 --- a/scala/src/main/scala/quanto/gui/PythonEditPanel.scala +++ b/scala/src/main/scala/quanto/gui/PythonEditPanel.scala @@ -11,50 +11,35 @@ import java.awt.event.{KeyAdapter, KeyEvent} import javax.swing.ImageIcon import quanto.util.swing.ToolBar -import quanto.util.UserAlerts.{alert, Elevation, SelfAlertingProcess} +import quanto.util.UserAlerts.{Elevation, SelfAlertingProcess, alert} -import scala.swing.event.ButtonClicked +import scala.swing.event.{ButtonClicked, Event} import quanto.util._ import java.io.{File, PrintStream} -class PythonEditPanel extends BorderPanel with HasDocument { - val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask - - val pyMode = new Mode("Python") - - //val modeXml = - // if (Globals.isBundle) new File("python.xml").getAbsolutePath - // else getClass.getResource("python.xml").getPath - pyMode.setProperty("file", QuantoDerive.pythonModeFile) - //println(sml.getProperty("file")) - val code = StandaloneTextArea.createTextArea() - code.setFont(UserOptions.font) - //mlCode.setFont(new Font("Menlo", Font.PLAIN, 14)) +import quanto.rewrite.Simproc - val buf = new JEditBuffer1 - buf.setMode(pyMode) +case class SimprocsUpdated() extends Event - var execThread : Thread = null +class PythonEditPanel extends BorderPanel with HasDocument { + val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask - code.setBuffer(buf) + val CodeArea : TextEditor = new TextEditor(TextEditor.Modes.python) - code.addKeyListener(new KeyAdapter { - override def keyPressed(e: KeyEvent) { - if (e.getModifiers == CommandMask) e.getKeyChar match { - case 'x' => Registers.cut(code, '$') - case 'c' => Registers.copy(code, '$') - case 'v' => Registers.paste(code, '$') - case _ => - } - } - }) + // Inject python to expose some relevant variables + def documentName: String = document.file.map( + // Assume the user has a project loaded, otherwise shouldn't be able to access GUI + f => QuantoDerive.CurrentProject.get.relativePath(f) + ).getOrElse("Unsaved File") - val document = new CodeDocument("Python Script", "py", this, code) + // Now run the python along with the header + def code : String = CodeArea.getText + val document = new CodeDocument("Python Script", "py", this, CodeArea.TextArea) + listenTo(document) - val textPanel = new BorderPanel { - peer.add(code, BorderLayout.CENTER) - } + var execThread : Thread = null + val textPanel = CodeArea.Component val RunButton = new Button() { icon = new ImageIcon(GraphEditor.getClass.getResource("start.png"), "Run scala code") @@ -87,20 +72,40 @@ class PythonEditPanel extends BorderPanel with HasDocument { listenTo(RunButton, InterruptButton) + def allSimprocs : Map[String, Simproc] = QuantoDerive.CurrentProject.map(p => p.simprocs).getOrElse(Map()) reactions += { + case DocumentRequestingNaturalFocus(_) => + CodeArea.TextArea.requestFocus() case ButtonClicked(RunButton) => if (execThread == null) { - val processReporting = new SelfAlertingProcess("Python from source") + val processReporting = new SelfAlertingProcess(s"Python $documentName") execThread = new Thread(new Runnable { def run() { try { val python = new PythonInterpreter + + QuantoDerive.CurrentProject.foreach(project => project.lastRunPythonFilePath = Some(documentName)) + def simprocsFromThisFile = allSimprocs.filter(kv => kv._2.sourceFile == documentName).keys + // unregister any simprocs previously linked to this file + simprocsFromThisFile.foreach(simprocName => QuantoDerive.CurrentProject.foreach( + p => p.simprocs -= simprocName + )) + QuantoDerive.CurrentProject.foreach(pr => python.getSystemState.path.add(pr.rootFolder)) + outputTextArea.text = "" python.set("output", output) - - //python.set("output", output) - python.exec(code.getBuffer.getText) + python.exec(code) + + // Tell the user which simprocs are linked to this file + alert(s"Simprocs registered to $documentName: " + + simprocsFromThisFile.mkString(", ") + ) + // Link this python to those simprocs + simprocsFromThisFile.foreach(simprocName => QuantoDerive.CurrentProject.foreach( + p => p.simprocs(simprocName).sourceCode = code + )) + PythonEditPanel.publishUpdate() processReporting.finish() } catch { case e : Throwable => @@ -124,3 +129,9 @@ class PythonEditPanel extends BorderPanel with HasDocument { } } } + +object PythonEditPanel extends Publisher { + def publishUpdate() : Unit = { + publish(SimprocsUpdated()) + } +} \ No newline at end of file diff --git a/scala/src/main/scala/quanto/gui/QuantoDerive.scala b/scala/src/main/scala/quanto/gui/QuantoDerive.scala index 1e4ac6c4..3ccf26e7 100644 --- a/scala/src/main/scala/quanto/gui/QuantoDerive.scala +++ b/scala/src/main/scala/quanto/gui/QuantoDerive.scala @@ -5,13 +5,16 @@ import org.python.util.PythonInterpreter import scala.io.Source import scala.swing._ -import scala.swing.event.{Key, SelectionChanged} -import javax.swing.{KeyStroke, UIManager} +import scala.swing.event.{Key, KeyPressed, SelectionChanged} +import javax.swing.{JOptionPane, KeyStroke, SwingUtilities, UIManager} import java.awt.event.KeyEvent +import java.awt.Frame +import java.awt.event.{KeyEvent, MouseAdapter, MouseEvent} -import quanto.util.json.{Json, JsonString} +import quanto.util.json.{Json, JsonObject, JsonString} import quanto.data._ import java.io.{File, FilenameFilter, IOException, PrintWriter} + import javax.swing.plaf.metal.MetalLookAndFeel import java.util.prefs.Preferences @@ -25,44 +28,69 @@ import akka.actor.PoisonPill import scala.concurrent.duration._ import scala.concurrent.ExecutionContext import ExecutionContext.Implicits.global -import java.awt.{Color, Window} -import javax.swing.SwingUtilities +import java.awt.{Color, Desktop, Window} +import java.lang.NullPointerException + +import javax.imageio.ImageIO +import javax.swing.filechooser.FileNameExtensionFilter +import quanto.gui.QuantoDerive.FileMenu.mnemonic +import quanto.gui.FileOpened +import quanto.util._ -import quanto.util.{Globals, UserAlerts, UserOptions, WebHelper} +class NoProjectException extends Exception("No project open.") object QuantoDerive extends SimpleSwingApplication { + override def main(args: Array[String]): Unit = { + super.main(args) + if (args.length > 0) { + alert("Loading project from commandline") + val arg = args(0) + loadProject(arg) + } + if (args.length > 1) { + alert("Opening file from commandline") + val fname = args(1) + ProjectFileTree.publish(FileOpened(new File(fname))) + } + } + + System.setProperty("apple.eawt.quitStrategy", "CLOSE_ALL_WINDOWS") val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask val actorSystem = ActorSystem("QuantoDerive") //val core = actorSystem.actorOf(Props { new Core }, "core") implicit val timeout = Timeout(1.day) - // copy python mode xml into a temp file, as jEdit component can't handle JAR resources - lazy val pythonModeFile = { - val f = File.createTempFile("python", "xml") - f.deleteOnExit() - val pr = new PrintWriter(f) - Source.fromInputStream(getClass.getResourceAsStream("python.xml")).foreach(pr.print) - pr.close() - f.getCanonicalPath - } - // pre-initialise jython, so its zippy when the user clicks "run" in a script new Thread(new Runnable { def run() { new PythonInterpreter() }}).start() - println(new File(".").getAbsolutePath) + UserAlerts.alert("Working directory: " + new File(".").getAbsolutePath) + + // Dialogs in in scala.swing seem to be broken since updated scala to 2.12, so + // we're using the javax.swing versions instead + def error(msg: String) = + UserAlerts.errorBox(msg) - def error(msg: String) = Dialog.showMessage( - title = "Error", message = msg, messageType = Dialog.Message.Error) + def alert(msg: String) = + UserAlerts.alert(msg) + def warn(msg: String) = + UserAlerts.alert(msg, UserAlerts.Elevation.WARNING) + + def uiScale(i : Int) : Int = UserOptions.scaleInt(i) + + //Dialog.showMessage(title = "Error", message = msg, messageType = Dialog.Message.Error) val prefs = Preferences.userRoot().node(this.getClass.getName) try { UIManager.setLookAndFeel(new MetalLookAndFeel) // tabs in OSX PLAF look bad - UserOptions.uiScale = prefs.getDouble("uiScale", 1.0) + UserOptions.uiScale = UserOptions.uiScale // Initiliases all the UI options } catch { - case e: Exception => e.printStackTrace() + case e: Exception => { + UserAlerts.alert("Could not load UI preferences on startup.") + e.printStackTrace() + } } def unloadProject() { @@ -70,22 +98,44 @@ object QuantoDerive extends SimpleSwingApplication { ProjectFileTree.root = None } - def loadProject(projectLocation: String) : Option[Project] = { - UserAlerts.alert(s"Opening project: $projectLocation") - val projectFile = new File(projectLocation + "/main.qproject") + def updateProjectFile(projectFile: File): Unit = { + if (CurrentProject.nonEmpty) { + val project = CurrentProject.get + try { + if (projectFile.exists) { + val parsedInput = Json.parse(projectFile) + if (Project.toJson(project).toString != parsedInput.toString) { + Project.toJson(project).writeTo(project.projectFile) + UserAlerts.alert(s"Updated project file", UserAlerts.Elevation.DEBUG) + } + } + } catch { + case e: Exception => + throw new ProjectLoadException("Error loading project", e) + } + } + } + + def loadProject(projectFileLocation: String) : Option[Project] = { + alert(s"Opening project: $projectFileLocation") + + val projectFile = if(new File(projectFileLocation).isDirectory){ + new File(projectFileLocation + "/main.qproject") + } else { + new File(projectFileLocation) + } try { if (projectFile.exists) { val parsedInput = Json.parse(projectFile) - val project = Project.fromJson(parsedInput, projectLocation) + val project = Project.fromJson(parsedInput, new File(projectFileLocation)) // Old .qproject files had links rather than embedded theories - if (Project.toJson(project).toString != parsedInput.toString) { - UserAlerts.alert("Updating out of date .qproject file") - Project.toJson(project).writeTo(new File(projectLocation + "/main.qproject")) - } + // So update when loading in CurrentProject = Some(project) - ProjectFileTree.root = Some(projectLocation) - prefs.put("lastProjectFolder", projectLocation) - UserAlerts.alert(s"Successfully loaded project: $projectLocation") + updateProjectFile(projectFile) + ProjectFileTree.root = Some(project.rootFolder) + prefs.put("lastProjectFile", projectFileLocation) + UserAlerts.registerLogFile(Some(new File(project.rootFolder + s"/${project.name}_log.txt"))) + alert(s"Successfully loaded project: $projectFileLocation") Some(project) } else { UserAlerts.alert("Selected project file does not exist", UserAlerts.Elevation.ERROR) @@ -98,16 +148,16 @@ object QuantoDerive extends SimpleSwingApplication { unloadProject() None } finally { - FileMenu.updateNewEnabled() + refreshAllMenusAndTitle() } } //CurrentProject.map { pr => core ! SetMLWorkingDir(pr.rootFolder) } val ProjectFileTree = new FileTree - ProjectFileTree.preferredSize = new Dimension(250,360) + ProjectFileTree.preferredSize = new Dimension(uiScale(250), uiScale(360)) ProjectFileTree.filenameFilter = Some(new FilenameFilter { - val extns = Set("qgraph", "qrule", "qderive", "ML", "py") + val extns = Set("qgraph", "qrule", "qderive", "ML", "py", "qsbr") def accept(parent: File, name: String) = { val extn = name.lastIndexOf('.') match { case i if i > 0 => name.substring(i+1) ; case _ => ""} @@ -120,7 +170,7 @@ object QuantoDerive extends SimpleSwingApplication { }) - var CurrentProject : Option[Project] = prefs.get("lastProjectFolder", null) match { + var CurrentProject : Option[Project] = prefs.get("lastProjectFile", null) match { case path : String => try { loadProject(path) @@ -132,6 +182,28 @@ object QuantoDerive extends SimpleSwingApplication { case _ => None } + // Access via even is preferred, as then we can pinpoint where to put the popup + def popup(menu: PopupMenu, e: Option[MouseEvent]) : Unit = { + if (e.nonEmpty){ + val componentBounds = e.get.getComponent.getBounds + val shift : Int = UserOptions.scaleInt(5) + popup(menu, e.get.getX + componentBounds.x + shift, e.get.getY + componentBounds.y+ shift) + } else { + popup(menu, 0, 0) + } + } + + def popup(menu: PopupMenu, x: Int, y: Int) : Unit = { + menu.show(Main, x, y) + } + + def addAndFocusPage(d : DocumentPage): Unit = { + MainDocumentTabs += d + listenTo(d.tabComponent) + MainDocumentTabs.focus(d) + d.document.publish(DocumentChanged(d.document)) + d.document.focusOnNaturalComponent() + } listenTo(quanto.util.UserOptions.OptionsChanged) reactions += { @@ -144,16 +216,12 @@ object QuantoDerive extends SimpleSwingApplication { } } - val MainTabbedPane = new ClosableTabbedPane + val MainDocumentTabs = new DocumentTabs - def currentDocument: Option[HasDocument] = - MainTabbedPane.currentContent match { - case Some(doc: HasDocument) => Some(doc) - case _ => None - } + def currentDocument: Option[DocumentPage] = MainDocumentTabs.currentFocus def currentGraphController: Option[GraphEditController] = - MainTabbedPane.currentContent match { + MainDocumentTabs.currentContent match { case Some(p: GraphEditPanel) => Some(p.graphEditController) case Some(p: RuleEditPanel) => Some(p.focusedController) case _ => None @@ -185,17 +253,18 @@ object QuantoDerive extends SimpleSwingApplication { def histView = _histView object LeftSplit extends SplitPane { + resizeWeight = 0.5 orientation = Orientation.Horizontal contents_=(ProjectFileTree, HistViewSlot) } object Split extends SplitPane { orientation = Orientation.Vertical - contents_=(LeftSplit, MainTabbedPane) + contents_=(LeftSplit, MainDocumentTabs.component) } def hasUnsaved = - MainTabbedPane.pages.exists { p => p.content match { + MainDocumentTabs.documents.exists { p => p.content match { case c : HasDocument => c.document.unsavedChanges case _ => false }} @@ -207,8 +276,8 @@ object QuantoDerive extends SimpleSwingApplication { * (depends on user choice) */ def trySaveAll() = { - MainTabbedPane.pages.forall { p => - MainTabbedPane.selection.index = p.index // focus a pane before saving + MainDocumentTabs.documents.forall { p => + MainDocumentTabs.focus(p) // focus a pane before saving p.content match { case c : HasDocument => c.document.trySave() case _ => false @@ -218,43 +287,79 @@ object QuantoDerive extends SimpleSwingApplication { /** * Show a dialog (when necessary) asking the user if the program should quit + *@param specific : Specify a list to close, or None to close all * @return true if the program should quit, false otherwise */ - def closeAllDocuments() = { + def closeAllOrListOfDocuments(specific: Option[List[DocumentPage]] = None) : Boolean = { if (hasUnsaved) { - val choice = Dialog.showOptions( - title = "Confirm quit", - message = "Some documents have unsaved changes.\nDo you want to save your changes or discard them?", - entries = "Save" :: "Discard" :: "Cancel" :: Nil, - initial = 0 - ) +// val choice = Dialog.showOptions( +// title = "Confirm quit", +// message = "Some documents have unsaved changes.\nDo you want to save your changes or discard them?", +// entries = "Save" :: "Discard" :: "Cancel" :: Nil, +// initial = 0 +// ) + + val choice = JOptionPane.showOptionDialog(null, + "Do you want to save your changes or discard them?", + "Unsaved changes", + JOptionPane.DEFAULT_OPTION, + JOptionPane.WARNING_MESSAGE, null, + List("Save", "Discard", "Cancel").toArray, + "Save") + // scala swing dialogs implementation is dumb, here's what I found : // Result(0) = Save, Result(1) = Discard, Result(2) = Cancel - if (choice == Dialog.Result(2)) false - else if (choice == Dialog.Result(1)) { - MainTabbedPane.pages.clear() + if (choice == 2) false + else if (choice == 1) { + if(specific.nonEmpty){ + for(page <- specific.get) {MainDocumentTabs.remove(page)} + } else { + MainDocumentTabs.clear() + } true } else { val b = trySaveAll() - if (b) MainTabbedPane.pages.clear() + if (b) { + if (specific.nonEmpty) { + for (page <- specific.get) { + MainDocumentTabs.remove(page) + } + } else { + MainDocumentTabs.clear() + } + } b } } else { - MainTabbedPane.pages.clear() + if(specific.nonEmpty){ + for(page <- specific.get) {MainDocumentTabs.remove(page)} + } else { + MainDocumentTabs.clear() + } true } } - def quitQuanto() = { - if (closeAllDocuments()) { + def quitQuanto(): Boolean = { + val close = closeAllOrListOfDocuments() + if (close) { try { //core ! StopCore //core ! PoisonPill } catch { case e : Exception => e.printStackTrace() } + val rect = _mainframe.peer.getBounds() + val isFullScreen : Boolean = _mainframe.peer.getExtendedState() == Frame.MAXIMIZED_BOTH + prefs.putBoolean("fullscreen", isFullScreen) + if (!isFullScreen) { + prefs.putInt("locationx",rect.x) + prefs.putInt("locationy",rect.y) + prefs.putInt("screenwidth",rect.width) + prefs.putInt("screenheight",rect.height) + } true } else { false @@ -274,6 +379,79 @@ object QuantoDerive extends SimpleSwingApplication { // } + def FolderContextMenu(folder: File) : PopupMenu = new PopupMenu { //Context menu for project folders + menu => + + val OpenLocationAction: Action = new Action("Open Folder") { + menu.contents += new MenuItem(this) { + mnemonic = Key.L + } + + def apply() { + Desktop.getDesktop.browse(folder.toURI) + } + } + + } + + def FileContextMenu(file: File): PopupMenu = new PopupMenu { //Context menu for project files + menu => + + val OpenLocationAction: Action = new Action("Open File Location") { + menu.contents += new MenuItem(this) { + mnemonic = Key.L + } + + def apply() { + Desktop.getDesktop.browse(file.getParentFile.toURI) + } + } + + + val DeleteFile: Action = new Action("Delete File") { + menu.contents += new MenuItem(this) { + mnemonic = Key.D + } + + def apply() { + file.delete() + } + } + + (FileHelper.extension(file), MainDocumentTabs.currentContent) match { + case ("qrule", Some(dp: DerivationPanel)) => + val AddToRewrites : Action = new Action("Add to current derivation") { + menu.contents += new MenuItem(this) { + mnemonic = Key.R + } + + def apply() { + alert(s"Publishing request for rule") + if(CurrentProject.nonEmpty){ + val project = CurrentProject.get + val relativePath = project.relativePath(file) + val ruleDesc = RuleDesc(relativePath.substring(0, relativePath.length-".qrule".length)) + dp.publish(SuggestRewriteRule(ruleDesc)) + } + } + } + case _ => + } + } + + + def newGraph(): GraphDocument = { + CurrentProject match { + case Some(project) => + val page = new GraphDocumentPage(project.theory) + addAndFocusPage(page) + page.document.asInstanceOf[GraphDocument] + case None => + throw new NoProjectException + } + } + + object FileMenu extends Menu("File") { menu => mnemonic = Key.F @@ -281,11 +459,7 @@ object QuantoDerive extends SimpleSwingApplication { accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_N, CommandMask)) menu.contents += new MenuItem(this) { mnemonic = Key.G } def apply() { - CurrentProject.foreach{ project => - val page = new GraphDocumentPage(project.theory) - MainTabbedPane += page - MainTabbedPane.selection.index = page.index - } + newGraph() } } @@ -295,8 +469,7 @@ object QuantoDerive extends SimpleSwingApplication { def apply() { CurrentProject.foreach{ project => val page = new RuleDocumentPage(project.theory) - MainTabbedPane += page - MainTabbedPane.selection.index = page.index + addAndFocusPage(page) } } } @@ -319,33 +492,18 @@ object QuantoDerive extends SimpleSwingApplication { def apply() { CurrentProject.foreach{ project => val page = new PythonDocumentPage - MainTabbedPane += page - MainTabbedPane.selection.index = page.index + addAndFocusPage(page) } } } - def updateNewEnabled() { - CurrentProject match { - case Some(_) => - NewGraphAction.enabled = true - NewAxiomAction.enabled = true - //NewMLAction.enabled = true - case None => - NewGraphAction.enabled = false - NewAxiomAction.enabled = false - //NewMLAction.enabled = false - } - } - - updateNewEnabled() val SaveAction = new Action("Save") { accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_S, CommandMask)) enabled = false menu.contents += new MenuItem(this) { mnemonic = Key.S } def apply() { - MainTabbedPane.currentContent match { + MainDocumentTabs.currentContent match { case Some(doc: HasDocument) => doc.document.file match { case Some(_) => doc.document.save() @@ -361,7 +519,7 @@ object QuantoDerive extends SimpleSwingApplication { enabled = false menu.contents += new MenuItem(this) { mnemonic = Key.A } def apply() { - MainTabbedPane.currentContent match { + MainDocumentTabs.currentContent match { case Some(doc: HasDocument) => doc.document.showSaveAsDialog(CurrentProject.map(_.rootFolder)) case _ => @@ -374,9 +532,9 @@ object QuantoDerive extends SimpleSwingApplication { enabled = false menu.contents += new MenuItem(this) { mnemonic = Key.V } def apply() { - val selection = MainTabbedPane.selection.index + val selection = MainDocumentTabs.selection.index trySaveAll() - MainTabbedPane.selection.index = selection + MainDocumentTabs.selection.index = selection } } @@ -386,7 +544,7 @@ object QuantoDerive extends SimpleSwingApplication { menu.contents += new MenuItem(this) { mnemonic = Key.N } def apply() { - if (closeAllDocuments()) { + if (closeAllOrListOfDocuments()) { val d = new NewProjectDialog() d.centerOnScreen() d.open() @@ -395,11 +553,11 @@ object QuantoDerive extends SimpleSwingApplication { println("got: " + (theoryFile, name, path)) val folder = new File(path + "/" + name) if (name.isEmpty) { - UserAlerts.errorbox("Please enter a name for your project.") + error("Please enter a name for your project.") } else if (folder.exists()) { - UserAlerts.errorbox("That folder is already in use.") + error("That folder is already in use.") } else if (theoryFile.isEmpty) { - UserAlerts.errorbox("Please enter a theory file.") + error("Please enter a theory file.") } else { folder.mkdirs() new File(folder.getPath + "/graphs").mkdir() @@ -407,12 +565,13 @@ object QuantoDerive extends SimpleSwingApplication { new File(folder.getPath + "/theorems").mkdir() new File(folder.getPath + "/derivations").mkdir() new File(folder.getPath + "/simprocs").mkdir() + val projectFile = new File(folder.getPath + "/" + name + ".qproject") val rootFolder = folder.getAbsolutePath - val proj = Project.fromTheoryOrProjectFile(theoryFile, rootFolder, name) - Project.toJson(proj).writeTo(new File(folder.getPath + "/main.qproject")) - loadProject(folder.getPath) + val proj = Project.fromTheoryOrProjectFile(new File(theoryFile), new File(rootFolder), name) + Project.toJson(proj).writeTo(projectFile) + loadProject(projectFile.getAbsolutePath) //core ! SetMLWorkingDir(rootFolder) - updateNewEnabled() + refreshAllMenusAndTitle() } case None => } @@ -423,16 +582,16 @@ object QuantoDerive extends SimpleSwingApplication { val OpenProjectAction = new Action("Open Project...") { menu.contents += new MenuItem(this) { mnemonic = Key.O } def apply() { - if (closeAllDocuments()) { + if (closeAllOrListOfDocuments()) { val chooser = new FileChooser() - chooser.fileSelectionMode = FileChooser.SelectionMode.DirectoriesOnly + chooser.fileFilter = new FileNameExtensionFilter("Quantomatic Project File (*.qproject)", "qproject") + chooser.fileSelectionMode = FileChooser.SelectionMode.FilesOnly chooser.showOpenDialog(Split) match { case FileChooser.Result.Approve => - val rootFolder = chooser.selectedFile.toString - val projectFile = new File(rootFolder + "/main.qproject") + val projectFile = new File(chooser.selectedFile.toString) if (projectFile.exists) { try { - loadProject(rootFolder) + loadProject(chooser.selectedFile.toString) //core ! SetMLWorkingDir(rootFolder) } catch { case _: ProjectLoadException => @@ -441,10 +600,10 @@ object QuantoDerive extends SimpleSwingApplication { error("Unexpected error when opening project") e.printStackTrace() } finally { - updateNewEnabled() + refreshAllMenusAndTitle() } } else { - error("Folder does not contain a Quantomatic project") + error(s"Folder does not contain a Quantomatic project: $projectFile") } case _ => } @@ -455,14 +614,15 @@ object QuantoDerive extends SimpleSwingApplication { val CloseProjectAction = new Action("Close Project") { menu.contents += new MenuItem(this) { mnemonic = Key.C } def apply() { - if (closeAllDocuments()) { + if (closeAllOrListOfDocuments()) { ProjectFileTree.root = None CurrentProject = None - updateNewEnabled() + refreshAllMenusAndTitle() } } } + menu.contents += new Separator() val QuitAction = new Action("Quit") { @@ -489,15 +649,17 @@ object QuantoDerive extends SimpleSwingApplication { case Some(doc) => enabled = doc.document.undoStack.canUndo title = "Undo " + doc.document.undoStack.undoActionName.getOrElse("") + //listenTo(doc.document) case None => enabled = false title = "Undo" } - listenTo(MainTabbedPane.selection) + listenTo(MainDocumentTabs.selection) reactions += { - case DocumentChanged(_) => updateUndoCommand() + case DocumentChanged(_) => + updateUndoCommand() case SelectionChanged(_) => currentDocument.foreach { doc => listenTo(doc.document) } updateUndoCommand() @@ -524,7 +686,7 @@ object QuantoDerive extends SimpleSwingApplication { title = "Redo" } - listenTo(MainTabbedPane.selection) + listenTo(MainDocumentTabs.selection) reactions += { case DocumentChanged(_) => updateRedoCommand() @@ -554,13 +716,6 @@ object QuantoDerive extends SimpleSwingApplication { def apply() { currentGraphController.foreach(_.pasteSubgraph()) } } - contents += new Separator - - val SnapToGridAction = new Action("Snap to grid") { - menu.contents += new MenuItem(this) { mnemonic = Key.S } - //accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_G, CommandMask)) - def apply() { currentGraphController.foreach(_.snapToGrid()) } - } // val LayoutAction = new Action("Layout Graph") with Reactor { // accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_L, CommandMask)) @@ -575,53 +730,346 @@ object QuantoDerive extends SimpleSwingApplication { // contents += new MenuItem(LayoutAction) { mnemonic = Key.L } } - val DeriveMenu = new Menu("Derive") { menu => + val RuleMenu = new Menu("Rule") { + menu => + mnemonic = Key.R + + val ColourSwapRule = new Action("Colour Swap") { + accelerator = None + enabled = true + menu.contents += new MenuItem(this) { + mnemonic = Key.C + } + + def apply(): Unit = (CurrentProject, MainDocumentTabs.currentContent) match { + case (Some(project), Some(doc: HasDocument)) => + doc.document match { + case ruleDoc: RuleDocument => + val dialog = new ColourSwapDialog(project.theory) + dialog.centerOnScreen() + dialog.open() + dialog.result + if (dialog.result != project.theory.vertexTypes.keys.map(k => k -> k).toMap) { + val map = dialog.result + warn(map.mkString("Mapping types: ",", ","")) + val page = new RuleDocumentPage(project.theory) + page.document.asInstanceOf[RuleDocument].rule = ruleDoc.rule.colourSwap(map) + addAndFocusPage(page) + } + case _ => + warn("Trying to colour swap a rule but no rule is active") + + } + case _ => // no project and/or document open, do nothing + } + } + + val InvertRule = new Action("Invert Rule") { + accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_I, CommandMask)) + enabled = true + menu.contents += new MenuItem(this) { + mnemonic = Key.I + } + + def apply() = (CurrentProject, MainDocumentTabs.currentContent) match { + case (Some(project), Some(doc: HasDocument)) => + doc.document match { + case (ruleDoc: RuleDocument) => + ruleDoc.rule = ruleDoc.rule.inverse + + case _ => + warn("WARNING: Invert rule called with no rule active") + } + case _ => // no project and/or document open, do nothing + } + } + visible = false + } + + val TheoryMenu = new Menu("Theory") { + menu => + + mnemonic = Key.P + + + val EditTheoryAction = new Action("Alter Theory") { + menu.contents += new MenuItem(this) { + mnemonic = Key.T + } + + def apply() { + CurrentProject.foreach { project => + val page = MainDocumentTabs.documents.find(tp => tp.title == "Theory Editor") match { + case Some(p) => p + case None => + val p = new TheoryPage() + listenTo(p.document) + p.title = "Theory Editor" + addAndFocusPage(p) + p + } + MainDocumentTabs.focus(page) + } + } + } + + + val BatchDerivationAction = new Action("Batch Derivation") { + menu.contents += new MenuItem(this) { + mnemonic = Key.B + } + + def apply() { + CurrentProject.foreach { project => + val page = MainDocumentTabs.documents.find(tp => tp.title == "Batch Derivation") match { + case Some(p) => p + case None => + val p = new BatchDerivationPage() + listenTo(p.document) + p.title = "Batch Derivation" + addAndFocusPage(p) + p + } + MainDocumentTabs.focus(page) + } + } + } + + visible = true + enabled = CurrentProject.nonEmpty + + } + + val GraphMenu = new Menu("Graph") { + menu => + mnemonic = Key.G + val StartDerivation = new Action("Start derivation") { accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_D, CommandMask)) enabled = false - menu.contents += new MenuItem(this) { mnemonic = Key.D } - def apply() = (CurrentProject, MainTabbedPane.currentContent) match { - case (Some(project), Some(doc: HasDocument)) => - doc.document match { - case (graphDoc: GraphDocument) => - val page = new DerivationDocumentPage(project) - MainTabbedPane += page - MainTabbedPane.selection.index = page.index - page.document.asInstanceOf[DerivationDocument].root = graphDoc.graph - - case _ => - System.err.println("WARNING: Start derivation called with no graph active") - } - case _ => // no project and/or document open, do nothing + menu.contents += new MenuItem(this) { + mnemonic = Key.D } + + def apply() = (CurrentProject, MainDocumentTabs.currentContent) match { + case (Some(project), Some(doc: HasDocument)) => + doc.document match { + case (graphDoc: GraphDocument) => + val page = new DerivationDocumentPage(project) + addAndFocusPage(page) + page.document.asInstanceOf[DerivationDocument].root = graphDoc.graph + + case _ => + warn("WARNING: Start derivation called with no graph active") + } + case _ => // no project and/or document open, do nothing + } + } + + val StartRule = new Action("Make into axiom") { + accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_R, CommandMask)) + enabled = false + menu.contents += new MenuItem(this) { + mnemonic = Key.X + } + + def apply() = (CurrentProject, MainDocumentTabs.currentContent) match { + case (Some(project), Some(doc: HasDocument)) => + doc.document match { + case (graphDoc: GraphDocument) => + val page = new RuleDocumentPage(project.theory) + page.document.asInstanceOf[RuleDocument].lhsRef.graph = graphDoc.graph + addAndFocusPage(page) + case _ => + warn("WARNING: Start rule called with no graph active") + } + case _ => // no project and/or document open, do nothing + } + } + + + val ExtractGraph = new Action("Extract selection to new graph") { + enabled = true + menu.contents += new MenuItem(this) { + mnemonic = Key.N + } + + def apply() = (CurrentProject, MainDocumentTabs.currentContent) match { + case (Some(project), Some(gep: GraphEditPanel)) => + gep.document match { + case (graphDoc: GraphDocument) => + val newPage = new GraphDocumentPage(project.theory) + val vertSelection = gep.graphEditController.selectedVerts + if(vertSelection.nonEmpty) { + val inverseSelection = gep.graphEditController.graph.verts -- vertSelection + val snippedGraph = inverseSelection.foldLeft(graphDoc.graph) { + (g, v) => g.cutVertex(v, g.verts.filter(g.isBoundary))._1 + } + newPage.document.asInstanceOf[GraphDocument].graph = snippedGraph + addAndFocusPage(newPage) + } + case _ => + warn("WARNING: Extract selection with no graph active") + } + case (Some(project), Some(rep: RuleEditPanel)) => + rep.document match { + case (ruleDoc: RuleDocument) => + val newPage = new GraphDocumentPage(project.theory) + val vertSelection = rep.focusedController.selectedVerts + if(vertSelection.nonEmpty) { + val inverseSelection = rep.focusedController.graph.verts -- vertSelection + val snippedGraph = inverseSelection.foldLeft(rep.focusedController.graph) { + (g, v) => g.cutVertex(v, g.verts.filter(g.isBoundary))._1 + } + newPage.document.asInstanceOf[GraphDocument].graph = snippedGraph + addAndFocusPage(newPage) + } + case _ => + warn("WARNING: Extract selection with no graph active") + } + case _ => // no project and/or document open, do nothing + } + } + + + val SnapToGrid = new Action("Snap to grid") { + enabled = false + menu.contents += new MenuItem(this) { + mnemonic = Key.S} + + def apply() = { + currentGraphController.foreach(gc => gc.snapToGrid()) + } + + } + + val MinimiseGraph = new Action("Minimise") { + enabled = false + menu.contents += new MenuItem(this) { + mnemonic = Key.M} + + def apply() = { + currentGraphController.foreach(gc => gc.minimiseGraph()) + } + } + visible = false + } + + + val DeriveMenu = new Menu("Derivation") { + menu => + mnemonic = Key.D + val LayoutDerivation = new Action("Layout derivation") { -// accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_L, CommandMask)) + // accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_L, CommandMask)) enabled = false - menu.contents += new MenuItem(this) { mnemonic = Key.L } - def apply() = (CurrentProject, MainTabbedPane.currentContent) match { + menu.contents += new MenuItem(this) { + mnemonic = Key.L + } + + def apply() = (CurrentProject, MainDocumentTabs.currentContent) match { case (Some(project), Some(derivePanel: DerivationPanel)) => derivePanel.controller.layoutDerivation() case _ => // no project and/or derivation open, do nothing } } + + + val ViewGraph = new Action("Extract to new graph") { + enabled = true + menu.contents += new MenuItem(this) { + mnemonic = Key.E + } + + def apply() = (CurrentProject, MainDocumentTabs.currentContent) match { + case (Some(project), Some(dp: DerivationPanel)) => + dp.document match { + case (derivationDoc: DerivationDocument) => + val page = new GraphDocumentPage(project.theory) + val graph = dp.lhsController.graph + page.document.asInstanceOf[GraphDocument].graph = graph + addAndFocusPage(page) + } + case _ => + warn("WARNING: Extract selection with no graph active") + } + } + + + + val ReRunLastSimproc = new Action("Re-run last simproc") { + accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_SPACE, CommandMask)) + enabled = true + menu.contents += new MenuItem(this) { + mnemonic = Key.R + } + + def apply() = (CurrentProject, MainDocumentTabs.currentContent) match { + case (Some(project), Some(dp: DerivationPanel)) => + dp.ReRunLastSimproc() + case _ => + warn("WARNING: Re-run simproc called with no derivation active") + } + } + + + visible = false } + val WindowMenu = new Menu("Window") { menu => + mnemonic = Key.W + val CloseAction = new Action("Close tab") { accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_W, CommandMask)) enabled = false menu.contents += new MenuItem(this) { mnemonic = Key.C } def apply() { - MainTabbedPane.currentContent match { - case Some(doc: HasDocument) => - if (doc.document.promptUnsaved()) MainTabbedPane.pages.remove(MainTabbedPane.selection.index) + MainDocumentTabs.currentFocus match { + case Some(page: DocumentPage) => + if (page.document.promptUnsaved()) MainDocumentTabs.remove(MainDocumentTabs.currentFocus.get) case _ => } } } + val NextTabAction = new Action("Next tab") { + accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_RIGHT, CommandMask)) + enabled = false + menu.contents += new MenuItem(this) {mnemonic = Key.N} + + def apply(): Unit ={ + MainDocumentTabs.cycle() + } + } + + val PreviousTabAction = new Action("Previous tab") { + accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_LEFT, CommandMask)) + enabled = false + menu.contents += new MenuItem(this) {mnemonic = Key.P} + + def apply(): Unit ={ + MainDocumentTabs.cycle(forward = false) + } + } + + val CloseAllAction = new Action("Close all tabs") { + accelerator = None + enabled = false + menu.contents += new MenuItem(this) { + mnemonic = Key.A + } + + def apply() { + MainDocumentTabs.documents.foreach(page => { + MainDocumentTabs.focus(page) + if (page.document.promptUnsaved()) MainDocumentTabs.remove(page) + }) + } + } + contents += new Separator val IncreaseUIScaling = new Action("Increase UI scaling") { @@ -640,7 +1088,9 @@ object QuantoDerive extends SimpleSwingApplication { } val HelpMenu = new Menu("Help") { menu => - val CloseAction = new Action("Quantomatic website") { + mnemonic = Key.H + + val WebsiteAction = new Action("Quantomatic website") { menu.contents += new MenuItem(this) { mnemonic = Key.Q } def apply() { WebHelper.openWebpage("https://quantomatic.github.io/") @@ -650,19 +1100,30 @@ object QuantoDerive extends SimpleSwingApplication { val SimprocAPIAction = new Action("Simproc API") { menu.contents += new MenuItem(this) { mnemonic = Key.S } def apply() { - WebHelper.openWebpage("https://quantomatic.github.io/SimprocAPI.html") + WebHelper.openWebpage("https://quantomatic.github.io#SimprocAPI") } } + //private val project = getClass.getPackage + // val version = project.getImplementationVersion() + // TODO: Implement versioning + val UpdateAction = new Action(s"Get latest version") { + menu.contents += new MenuItem(this) { mnemonic = Key.V } + def apply() { + WebHelper.openWebpage("https://bintray.com/quantomatic/quantomatic/quantomatic/bleeding-edge") + } + } } val ExportMenu = new Menu("Export") { menu => + mnemonic = Key.X + val ExportAction = new Action("Export to LaTeX") { accelerator = Some(KeyStroke.getKeyStroke(KeyEvent.VK_E, CommandMask)) enabled = false menu.contents += new MenuItem(this) { mnemonic = Key.E } def apply() { - MainTabbedPane.currentContent match { + MainDocumentTabs.currentContent match { case Some(doc: HasDocument) => if (doc.document.unsavedChanges) { Dialog.showMessage(title = "Unsaved Changes", @@ -677,55 +1138,25 @@ object QuantoDerive extends SimpleSwingApplication { } } + val StatusBar = new StatusBar() - val UserMessage = new Label(UserAlerts.latestMessage.toString) - val ConsoleProgress = new ProgressBar - val ConsoleProgressLabel = new Label(" ") - val StatusBar = new GridPanel(1, 2) { - contents += new FlowPanel(FlowPanel.Alignment.Left)(UserMessage) - contents += new FlowPanel(FlowPanel.Alignment.Right)(ConsoleProgressLabel, ConsoleProgress) - } - ConsoleProgress.preferredSize = ConsoleProgressSize //Currently doesn't respond to UI scaling - def ConsoleProgressSize: Dimension = new Dimension(UserOptions.scaleInt(100), UserOptions.scaleInt(15)) + listenTo(ProjectFileTree, MainDocumentTabs.selection) - - listenTo(UserAlerts.AlertPublisher) reactions += { - case UserAlerts.UserAlertEvent(alert: UserAlerts.Alert) => - UserMessage.text = alert.toString - UserMessage.foreground = alert.color - case UserAlerts.UserProcessUpdate(_) => - UserAlerts.leastCompleteProcess match { - case Some(process) => if (process.determinate) { - ConsoleProgress.indeterminate = false - ConsoleProgress.value = process.value - } else { - ConsoleProgress.indeterminate = true - } - case _ => ConsoleProgress.value = 100 - } - val ongoing = UserAlerts.ongoingProcesses.filter(op => op.value < 100) - ongoing.count(_ => true) match { - case 0 => ConsoleProgressLabel.text = " " //keep non-empty so the progressbar stays in line with text - case 1 => ConsoleProgressLabel.text = ongoing.head.name - case n => ConsoleProgressLabel.text = n.toString + " processes ongoing" + case PageClosed(p : DocumentPage) => + MainDocumentTabs.remove(p) + case FileContextRequested(file, e) => + if(file.isDirectory){ + popup(FolderContextMenu(file), e) + } else { + popup(FileContextMenu(file), e) } - } - - val Main = new BorderPanel { - add(Split, BorderPanel.Position.Center) - add(StatusBar, BorderPanel.Position.South) - } - - listenTo(ProjectFileTree, MainTabbedPane.selection) - - reactions += { case FileOpened(file) => CurrentProject match { case Some(project) => - val existingPage = MainTabbedPane.pages.find { p => + val existingPage = MainDocumentTabs.documents.find { p => p.content match { case doc : HasDocument => doc.document.file.exists(_.getPath == file.getPath) case _ => false @@ -734,7 +1165,7 @@ object QuantoDerive extends SimpleSwingApplication { existingPage match { case Some(p) => - MainTabbedPane.selection.index = p.index + MainDocumentTabs.focus(p) case None => val extn = file.getName.lastIndexOf('.') match { case i if i > 0 => file.getName.substring(i+1) ; case _ => ""} @@ -745,15 +1176,15 @@ object QuantoDerive extends SimpleSwingApplication { case "qderive" => Some(new DerivationDocumentPage(project)) case "py" => Some(new PythonDocumentPage) case "ML" => Some(new MLDocumentPage) + case "qsbr" => Some(new BatchDerivationResultsPage) case _ => None } - pageOpt.map{ page => - MainTabbedPane += page - MainTabbedPane.selection.index = page.index + pageOpt.foreach{ page => + addAndFocusPage(page) if (!page.document.load(file)) { - MainTabbedPane.pages -= page + MainDocumentTabs.remove(page) } } } @@ -761,25 +1192,67 @@ object QuantoDerive extends SimpleSwingApplication { } case SelectionChanged(_) => + refreshAllMenus() + } + +// val versionResp = core ? Call("!!", "system", "version") +// versionResp.onSuccess { case Success(JsonString(version)) => +// Swing.onEDT { CoreStatus.text = "OK"; CoreStatus.foreground = new Color(0,150,0) } +// } + + private def refreshAllMenusAndTitle(): Unit = { + refreshAllMenus() + refreshTitle() + } + + private def refreshAllMenus(): Unit = { + try { FileMenu.SaveAction.enabled = false FileMenu.SaveAsAction.enabled = false FileMenu.SaveAllAction.enabled = false + CurrentProject match { + case Some(_) => + FileMenu.NewGraphAction.enabled = true + FileMenu.NewAxiomAction.enabled = true + case None => + FileMenu.NewGraphAction.enabled = false + FileMenu.NewAxiomAction.enabled = false + } + TheoryMenu.visible = true + TheoryMenu.enabled = CurrentProject.nonEmpty + TheoryMenu.EditTheoryAction.enabled = CurrentProject.nonEmpty + TheoryMenu.BatchDerivationAction.enabled = CurrentProject.nonEmpty EditMenu.CutAction.enabled = false EditMenu.CopyAction.enabled = false EditMenu.PasteAction.enabled = false - EditMenu.SnapToGridAction.enabled = false - DeriveMenu.StartDerivation.enabled = false + RuleMenu.visible = false + RuleMenu.InvertRule.enabled = false + GraphMenu.visible = false + GraphMenu.StartDerivation.enabled = false + GraphMenu.SnapToGrid.enabled = false + GraphMenu.MinimiseGraph.enabled = false + GraphMenu.StartRule.enabled = false + GraphMenu.ExtractGraph.enabled = false + DeriveMenu.visible = false DeriveMenu.LayoutDerivation.enabled = false + DeriveMenu.ViewGraph.enabled = false + DeriveMenu.ReRunLastSimproc.enabled = false WindowMenu.CloseAction.enabled = false + WindowMenu.PreviousTabAction.enabled = false + WindowMenu.NextTabAction.enabled = false + WindowMenu.CloseAllAction.enabled = false ExportMenu.ExportAction.enabled = false histView = None FileMenu.SaveAction.title = "Save" FileMenu.SaveAsAction.title = "Save As..." - MainTabbedPane.currentContent match { + MainDocumentTabs.currentContent match { case Some(content: HasDocument) => WindowMenu.CloseAction.enabled = true + WindowMenu.CloseAllAction.enabled = true + WindowMenu.NextTabAction.enabled = MainDocumentTabs.size > 1 + WindowMenu.PreviousTabAction.enabled = MainDocumentTabs.size > 1 FileMenu.SaveAction.enabled = true FileMenu.SaveAsAction.enabled = true FileMenu.SaveAllAction.enabled = true @@ -793,46 +1266,76 @@ object QuantoDerive extends SimpleSwingApplication { EditMenu.CutAction.enabled = true EditMenu.CopyAction.enabled = true EditMenu.PasteAction.enabled = true - EditMenu.SnapToGridAction.enabled = true - DeriveMenu.StartDerivation.enabled = true + GraphMenu.visible = true + GraphMenu.StartDerivation.enabled = true + GraphMenu.StartRule.enabled = true + GraphMenu.SnapToGrid.enabled = true + GraphMenu.MinimiseGraph.enabled = true + GraphMenu.ExtractGraph.enabled = true ExportMenu.ExportAction.enabled = true case panel: RuleEditPanel => EditMenu.CutAction.enabled = true EditMenu.CopyAction.enabled = true EditMenu.PasteAction.enabled = true - EditMenu.SnapToGridAction.enabled = true ExportMenu.ExportAction.enabled = true + RuleMenu.visible = true + RuleMenu.InvertRule.enabled = true + GraphMenu.visible = true + GraphMenu.SnapToGrid.enabled = true + GraphMenu.MinimiseGraph.enabled = true + GraphMenu.ExtractGraph.enabled = true case panel: DerivationPanel => - DeriveMenu.LayoutDerivation.enabled = true ExportMenu.ExportAction.enabled = true histView = Some(panel.histView) + DeriveMenu.visible = true + DeriveMenu.LayoutDerivation.enabled = true + DeriveMenu.ViewGraph.enabled = true + DeriveMenu.ReRunLastSimproc.enabled = true case _ => // nothing else enabled for ML } case _ => // leave everything disabled } + } catch { + case _: NullPointerException => + // Null Pointer Exception thrown when accessing GUI too early + } } -// val versionResp = core ? Call("!!", "system", "version") -// versionResp.onSuccess { case Success(JsonString(version)) => -// Swing.onEDT { CoreStatus.text = "OK"; CoreStatus.foreground = new Color(0,150,0) } -// } + // The highest level GUI contents + val Main = new BorderPanel { + add(Split, BorderPanel.Position.Center) + add(StatusBar.Status, BorderPanel.Position.South) + } + + + val _mainframe = new MainFrame { - def top = new MainFrame { - override def title : String = { - if (CurrentProject.isEmpty) {"Quantomatic"} else { + def refreshTitle() : Unit = { + // Setting the iconImage here isn't working on Windows + // iconImage = ImageIO.read(getClass.getResource("quantoderive.ico")) + title = if (CurrentProject.isEmpty) {"Quantomatic"} else { CurrentProject.get.name match { case "" => "Quantomatic" - case s => "Quantomatic - $s" + case s => s"Quantomatic - $s" } } } contents = Main - size = new Dimension(1280,720) + + if (prefs.getBoolean("fullscreen",false)) { + peer.setExtendedState(peer.getExtendedState() | Frame.MAXIMIZED_BOTH) + } + else { + size = new Dimension(prefs.getInt("screenwidth",1280),prefs.getInt("screenheight",720)) + peer.setLocation(prefs.getInt("locationx",300),prefs.getInt("locationy",300)) + } + peer.setVisible(true) + menuBar = new MenuBar { - contents += (FileMenu, EditMenu, DeriveMenu, WindowMenu, ExportMenu, HelpMenu) + contents += (FileMenu, TheoryMenu, EditMenu, DeriveMenu, RuleMenu, GraphMenu, WindowMenu, ExportMenu, HelpMenu) } import javax.swing.WindowConstants.DO_NOTHING_ON_CLOSE @@ -842,4 +1345,17 @@ object QuantoDerive extends SimpleSwingApplication { if (quitQuanto()) scala.sys.exit(0) } } + + def top = _mainframe + + def refreshTitle(): Unit = { + try { + top.refreshTitle() + } catch { + case _: NullPointerException => + // Null Pointer Exception thrown when accessing GUI too early + } + } + + refreshAllMenusAndTitle() } diff --git a/scala/src/main/scala/quanto/gui/RewriteController.scala b/scala/src/main/scala/quanto/gui/RewriteController.scala index 1421e846..95415d74 100644 --- a/scala/src/main/scala/quanto/gui/RewriteController.scala +++ b/scala/src/main/scala/quanto/gui/RewriteController.scala @@ -3,29 +3,35 @@ package quanto.gui import quanto.data._ import quanto.data.Names._ import quanto.rewrite._ -import scala.concurrent.{Future, Lock} +import java.util.concurrent.locks.ReentrantLock + +import scala.concurrent.Future import scala.swing._ import scala.swing.event._ import scala.swing.event.ButtonClicked -import scala.util.{Success,Failure} +import scala.util.{Failure, Success} import quanto.util.json._ + import scala.concurrent.ExecutionContext.Implicits.global import java.io.File -import quanto.layout.ForceLayout + +import quanto.data.Theory.ValueType +import quanto.util.UserAlerts +import quanto.util.UserAlerts.Elevation class RewriteController(panel: DerivationPanel) extends Publisher { implicit val timeout = QuantoDerive.timeout var queryId = 0 - val resultLock = new Lock + val resultLock = new ReentrantLock() var resultSet = ResultSet(Vector()) def theory = panel.project.theory class ResultGraphRef(rule: RuleDesc, i: Int) extends HasGraph { protected def gr_=(g: Graph) { - resultLock.acquire() + resultLock.lock() resultSet = resultSet.replaceGraph(rule, i, g) - resultLock.release() + resultLock.unlock() } protected def gr = resultSet.graph(rule, i) @@ -33,7 +39,7 @@ class RewriteController(panel: DerivationPanel) extends Publisher { def rules = resultSet.rules def rules_=(rules: Vector[RuleDesc]) { - resultLock.acquire() + resultLock.lock() resultSet = ResultSet(rules) queryId += 1 @@ -42,21 +48,18 @@ class RewriteController(panel: DerivationPanel) extends Publisher { else panel.LhsView.graph.verts for (rd <- rules) { - val rule = Rule.fromJson(Json.parse(new File(panel.project.rootFolder + "/" + rd.name + ".qrule")), theory) - val ms = Matcher.initialise(if (rd.inverse) rule.rhs else rule.lhs, panel.LhsView.graph, sel) - pullRewrite(ms, rd, rule) - -// QuantoDerive.core ? Call(theory.coreName, "rewrite", "find_rewrites", -// JsonObject( -// "rule" -> Rule.toJson(if (rd.inverse) rule.inverse else rule, theory), -// "graph" -> Graph.toJson(panel.LhsView.graph, theory), -// "vertices" -> JsonArray(sel.toVector.map(v => JsonString(v.toString))) -// )) -// resp.map { case Success(JsonString(stack)) => pullRewrite(queryId, rd, stack); case _ => } + try { + val rule = Rule.fromJson(Json.parse(new File(panel.project.rootFolder + "/" + rd.name + ".qrule")), theory) + val ms = Matcher.initialise(if (rd.inverse) rule.rhs else rule.lhs, panel.LhsView.graph, sel) + pullRewrite(ms, rd, rule) + } catch { + case RuleLoadException(message, _) => + UserAlerts.alert(s"Could not load ${rd.name}, error: $message", Elevation.WARNING) + } } - resultLock.release() - + resultLock.unlock() + refreshRewriteDisplay(clearSelection = true) } @@ -111,10 +114,9 @@ class RewriteController(panel: DerivationPanel) extends Publisher { name = DSName("s"), ruleName = rd.name, rule = rule1, - variant = if (rd.inverse) RuleInverse else RuleNormal, graph = graph1.minimise).layout - resultLock.acquire() + resultLock.lock() // make sure this rewrite query is still in progress, and the rule hasn't been manually removed by the user if (resultSet.rules.contains(rd)) { @@ -125,7 +127,7 @@ class RewriteController(panel: DerivationPanel) extends Publisher { } } - resultLock.release() + resultLock.unlock() refreshRewriteDisplay() case Success(None) => // out of matches case Failure(t) => println("An error occurred in the matcher: " + t.getMessage) @@ -135,7 +137,7 @@ class RewriteController(panel: DerivationPanel) extends Publisher { def refreshRewriteDisplay(clearSelection: Boolean = false) { Swing.onEDT { - resultLock.acquire() + resultLock.lock() if (clearSelection) { panel.ManualRewritePane.PreviousResultButton.enabled = false @@ -152,10 +154,27 @@ class RewriteController(panel: DerivationPanel) extends Publisher { } } - resultLock.release() + resultLock.unlock() } } + def promptForVariableSpecification(strings: Set[(ValueType, String)]): + Map[(ValueType, String), String] = { + val d = new SpecifyVariablesDialog(strings.toList.sortBy(_._2)) + d.centerOnScreen() + d.open() + d.result + } + + def removeSelectedRulesFromList(): Unit = { + + resultLock.lock() + panel.ManualRewritePane.Rewrites.selection.items.foreach { line => resultSet -= line.rule } + resultLock.unlock() + refreshRewriteDisplay() + } + + def selectedRule = if (panel.ManualRewritePane.Rewrites.selection.items.length == 1) Some(panel.ManualRewritePane.Rewrites.selection.items(0).asInstanceOf[ResultLine].rule) @@ -165,8 +184,20 @@ class RewriteController(panel: DerivationPanel) extends Publisher { listenTo(panel.ManualRewritePane.PreviousResultButton, panel.ManualRewritePane.NextResultButton) listenTo(panel.ManualRewritePane.ApplyButton) listenTo(panel.ManualRewritePane.Rewrites.selection) + listenTo(panel) + listenTo(panel.ManualRewritePane.Rewrites.keys) reactions += { + case KeyPressed(_, Key.Delete | Key.BackSpace, _, _) => + if (panel.ManualRewritePane.Rewrites.hasFocus) { + removeSelectedRulesFromList() + } + case SuggestRewriteRule(ruleDesc) => + val currentRules = rules.toSet + val newRules = Set(ruleDesc).filter(!currentRules.contains(_)) + + if (newRules.nonEmpty) rules ++= newRules + rules = rules.sortBy(r => r.name) case ButtonClicked(panel.ManualRewritePane.AddRuleButton) => val d = new AddRuleDialog(panel.project) d.centerOnScreen() @@ -178,12 +209,9 @@ class RewriteController(panel: DerivationPanel) extends Publisher { if (newRules.nonEmpty) rules ++= newRules rules = rules.sortBy(r => r.name) case ButtonClicked(panel.ManualRewritePane.RemoveRuleButton) => - resultLock.acquire() - panel.ManualRewritePane.Rewrites.selection.items.foreach { line => resultSet -= line.rule } - resultLock.release() - refreshRewriteDisplay() + removeSelectedRulesFromList() case ButtonClicked(panel.ManualRewritePane.PreviousResultButton) => - resultLock.acquire() + resultLock.lock() selectedRule match { case Some(rd) => resultSet = resultSet.previousResult(rd) @@ -191,10 +219,10 @@ class RewriteController(panel: DerivationPanel) extends Publisher { { case 0 => panel.DummyRef ; case i => new ResultGraphRef(rd, i) } case None => } - resultLock.release() + resultLock.unlock() refreshRewriteDisplay() case ButtonClicked(panel.ManualRewritePane.NextResultButton) => - resultLock.acquire() + resultLock.lock() selectedRule match { case Some(rd) => resultSet = resultSet.nextResult(rd) @@ -202,18 +230,51 @@ class RewriteController(panel: DerivationPanel) extends Publisher { { case 0 => panel.DummyRef ; case i => new ResultGraphRef(rd, i) } case None => } - resultLock.release() + resultLock.unlock() refreshRewriteDisplay() case ButtonClicked(panel.ManualRewritePane.ApplyButton) => - selectedRule.foreach { rd => resultSet.currentResult(rd).map { step => + selectedRule.foreach { rd => resultSet.currentResult(rd).foreach { step => val parentOpt = panel.controller.state.step - val stepFr = step.copy(name = panel.derivation.steps.freshWithSuggestion(DSName(rd.name.replaceFirst("^.*\\/", "") + "-0"))) panel.ManualRewritePane.Preview.graphRef = panel.DummyRef + // Check to see if new variables were introduced + val oldHead : Option[Graph] = parentOpt.map(panel.derivation.steps(_).graph) + val newHead : Graph = stepFr.graph + + val oldVariables: Set[(ValueType, String)] = oldHead match { + case Some(head) => Graph.variablesUsedWithType(theory, head) + case None => Set() + } + val newVariables: Set[(ValueType, String)] = Graph.variablesUsedWithType(theory, newHead) -- oldVariables + val subs: Map[(ValueType, String), String] = + if (newVariables.nonEmpty) { + promptForVariableSpecification(newVariables) + } else { + Map() + } + val graphWithReplacements: Graph = if (newVariables.nonEmpty) { + newHead.updateAllVData { + case node: NodeV => + node.newValue(node.phaseData.substSubVariables(subs).toString) + case vd => + vd + } + } else { + newHead + } + + val stepWithReplacements: DStep = stepFr.copy(graph = graphWithReplacements) + + val description: String = if (subs.isEmpty) { + "" + } else { + subs.mkString(", ") + } + panel.document.undoStack.start("Apply rewrite") - panel.controller.replaceDerivation(panel.derivation.addStep(parentOpt, stepFr), "") - panel.controller.state = HeadState(Some(stepFr.name)) + panel.controller.replaceDerivation(panel.derivation.addStep(parentOpt, stepWithReplacements), description) + panel.controller.state = HeadState(Some(stepWithReplacements.name)) panel.document.undoStack.commit() }} diff --git a/scala/src/main/scala/quanto/gui/RuleDocument.scala b/scala/src/main/scala/quanto/gui/RuleDocument.scala index 2ef9e3f0..9da17b81 100644 --- a/scala/src/main/scala/quanto/gui/RuleDocument.scala +++ b/scala/src/main/scala/quanto/gui/RuleDocument.scala @@ -1,10 +1,13 @@ package quanto.gui import java.io.File + import quanto.data._ import quanto.util.json.Json + import scala.swing.Component import quanto.util.FileHelper.printToFile +import quanto.util.UserAlerts class RuleDocument(val parent: Component, theory: Theory) extends Document { val description = "Rule" @@ -20,12 +23,13 @@ class RuleDocument(val parent: Component, theory: Theory) extends Document { } def rule = Rule(lhsRef.graph, rhsRef.graph, derivation) - def rule_=(r: Rule) { - lhsRef.graph = r.lhs - rhsRef.graph = r.rhs - derivation = r.derivation + def rule_=(newRule: Rule) { + lhsRef.graph = newRule.lhs + rhsRef.graph = newRule.rhs + derivation = newRule.derivation lhsRef.publish(GraphReplaced(lhsRef, clearSelection = true)) rhsRef.publish(GraphReplaced(rhsRef, clearSelection = true)) + publish(DocumentChanged(this)) } private var storedRule: Rule = Rule(Graph(theory), Graph(theory)) diff --git a/scala/src/main/scala/quanto/gui/RuleEditPanel.scala b/scala/src/main/scala/quanto/gui/RuleEditPanel.scala index 2bbb09c6..a5793ad9 100644 --- a/scala/src/main/scala/quanto/gui/RuleEditPanel.scala +++ b/scala/src/main/scala/quanto/gui/RuleEditPanel.scala @@ -1,9 +1,10 @@ package quanto.gui import quanto.gui.graphview.GraphView -import quanto.data.{HasGraph, Theory, Graph} -import scala.swing.{GridPanel, BorderPanel, ScrollPane} -import scala.swing.event.UIElementResized +import quanto.data.{Graph, HasGraph, Theory} + +import scala.swing.{BorderPanel, GridPanel, ScrollPane} +import scala.swing.event.{MouseClicked, UIElementResized} class RuleEditPanel(val theory: Theory, val readOnly: Boolean = false) extends BorderPanel @@ -19,11 +20,10 @@ with HasDocument val lhsController = new GraphEditController(lhsView, document.undoStack, readOnly) lhsController.controlsOpt = Some(controls) - val rhsController = new GraphEditController(rhsView, document.undoStack, readOnly) rhsController.controlsOpt = Some(controls) - def focusedController = if (rhsView.hasFocus) rhsController else lhsController + def focusedController: GraphEditController = if (rhsView.hasFocus) rhsController else lhsController val LhsScrollPane = new ScrollPane(lhsView) val RhsScrollPane = new ScrollPane(rhsView) @@ -50,8 +50,19 @@ with HasDocument case UIElementResized(RhsScrollPane) => rhsView.resizeViewToFit() rhsView.repaint() + case MouseStateChanged(RequestFocusOnGraph()) => + focusedController.focusOnGraph() + case MouseStateChanged(RequestMinimiseGraph()) => + focusedController.minimiseGraph() + case MouseStateChanged(RelaxToolDown()) => focusedController.startRelaxGraph(true) + case MouseStateChanged(RelaxToolUp()) => focusedController.endRelaxGraph() case MouseStateChanged(m) => + lhsController.endRelaxGraph() + rhsController.endRelaxGraph() lhsController.mouseState = m rhsController.mouseState = m + case DocumentRequestingNaturalFocus(_) => + lhsView.requestFocus() } + } diff --git a/scala/src/main/scala/quanto/gui/ScalaEditPanel.hide b/scala/src/main/scala/quanto/gui/ScalaEditPanel.hide deleted file mode 100644 index 8b7b49f3..00000000 --- a/scala/src/main/scala/quanto/gui/ScalaEditPanel.hide +++ /dev/null @@ -1,122 +0,0 @@ -package quanto.gui - -import scala.reflect.runtime.universe._ -import scala.tools.reflect.ToolBox -import scala.swing._ -import org.gjt.sp.jedit.{Registers, Mode} -import org.gjt.sp.jedit.textarea.StandaloneTextArea -import java.awt.{Color, BorderLayout} -import java.awt.event.{KeyEvent, KeyAdapter} -import javax.swing.ImageIcon -import quanto.util.swing.ToolBar -import scala.swing.event.ButtonClicked -import quanto.util._ -import java.io.{File, PrintStream} - -class ScalaEditPanel extends BorderPanel with HasDocument { - val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask - - val sml = new Mode("StandardML") - - val scalaModeXml = if (Globals.isBundle) new File("scala.xml").getAbsolutePath - else getClass.getResource("scala.xml").getPath - sml.setProperty("file", scalaModeXml) - //println(sml.getProperty("file")) - val scalaCode = StandaloneTextArea.createTextArea() - //mlCode.setFont(new Font("Menlo", Font.PLAIN, 14)) - - val buf = new JEditBuffer1 - buf.setMode(sml) - - var scalaThread : Thread = null - - scalaCode.setBuffer(buf) - - // mlCode.addKeyListener(new KeyAdapter { - // override def keyPressed(e: KeyEvent) { - // if (e.getModifiers == CommandMask) e.getKeyChar match { - // case 'x' => Registers.cut(mlCode, '$') - // case 'c' => Registers.copy(mlCode, '$') - // case 'v' => Registers.paste(mlCode, '$') - // case _ => - // } - // } - // }) - - val document = new CodeDocument("Scala Code", "scala", this, scalaCode) - - val toolbox = runtimeMirror(getClass.getClassLoader).mkToolBox() //currentMirror.mkToolBox() - - - val textPanel = new BorderPanel { - peer.add(scalaCode, BorderLayout.CENTER) - } - - val RunButton = new Button() { - icon = new ImageIcon(GraphEditor.getClass.getResource("start.png"), "Run scala code") - tooltip = "Run Scala" - } - - val InterruptButton = new Button() { - icon = new ImageIcon(GraphEditor.getClass.getResource("stop.png"), "Interrupt execution") - tooltip = "Interrupt execution" - } - - val ScalaToolbar = new ToolBar { - contents += (RunButton, InterruptButton) - } - - val outputTextArea = new TextArea() - outputTextArea.editable = false - val textOut = new TextAreaOutputStream(outputTextArea) - - val scalaOutput = new PrintStream(new TextAreaOutputStream(outputTextArea)) - - add(ScalaToolbar, BorderPanel.Position.North) - - object Split extends SplitPane { - orientation = Orientation.Horizontal - contents_=(textPanel, new ScrollPane(outputTextArea)) - } - - add(Split, BorderPanel.Position.Center) - - listenTo(RunButton, InterruptButton) - - reactions += { - case ButtonClicked(RunButton) => - if (scalaThread == null) { - QuantoDerive.CoreStatus.text = "Running scala code" - QuantoDerive.CoreStatus.foreground = Color.BLUE - - scalaThread = new Thread(new Runnable { - def run() { - try { - val tree = toolbox.parse(scalaCode.getBuffer.getText()) - toolbox.eval(tree) - QuantoDerive.CoreStatus.text = "Scala compiled sucessfully" - QuantoDerive.CoreStatus.foreground = new Color(0, 150, 0) - } catch { - case e : Throwable => - QuantoDerive.CoreStatus.text = "Error in scala code" - QuantoDerive.CoreStatus.foreground = Color.RED - Swing.onEDT { e.printStackTrace(scalaOutput) } - } finally { - scalaThread = null - } - } - }) - scalaThread.start() - - } else { - QuantoDerive.CoreStatus.text = "Scala already running" - QuantoDerive.CoreStatus.foreground = Color.RED - } - - case ButtonClicked(InterruptButton) => - if (scalaThread != null) { - scalaThread.interrupt() - scalaThread = null - } - } -} diff --git a/scala/src/main/scala/quanto/gui/SimplifyBuiltInController.scala b/scala/src/main/scala/quanto/gui/SimplifyBuiltInController.scala deleted file mode 100644 index 6e89eba2..00000000 --- a/scala/src/main/scala/quanto/gui/SimplifyBuiltInController.scala +++ /dev/null @@ -1,368 +0,0 @@ -package quanto.gui - -import java.io.File - -import scala.swing._ -import quanto.core._ -import quanto.data._ -import quanto.data.Names._ -import quanto.util.json._ -import akka.pattern.ask - -import scala.concurrent.Future -import scala.concurrent.ExecutionContext.Implicits.global -import scala.swing.event.ButtonClicked -import quanto.cosy.SimplificationProcedure -import quanto.cosy.GraphAnalysis -import quanto.data.Derivation.DerivationWithHead - -import scala.util.Random - - -class SimplifyBuiltInController(panel: DerivationPanel) extends Publisher { - implicit val timeout = QuantoDerive.timeout - private var simpId = 0 // incrementing the simpId will (lazily) cancel any pending simplification jobs - - listenTo(panel.SimplifyBuiltInPane.SimplifyButton) - - def refreshSimprocs() { - simpId += 1 - // val res = QuantoDerive.core ? Call(theory.coreName, "simplify", "list") - // res.map { - // case Success(JsonArray(procs)) => - // Swing.onEDT { panel.SimplifyBuiltInPane.Simprocs.listData = procs.map(_.stringValue) } - // case r => println("ERROR: Unexpected result from core: " + r) // TODO: errror dialogs - // } - } - - def theory = panel.theory - - private def pullSimp(simproc: String, sid: Int, stack: String, parentOpt: Option[DSName]) { - // if (simpId == sid) { - // val res = QuantoDerive.core ? Call(theory.coreName, "simplify", "pull_next_step", - // JsonObject("stack" -> JsonString(stack))) - // - // res.map { - // case Success(JsonNull) => // out of steps - // Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false } - // simpId += 1 - // case Success(json) => - // if (simpId == sid) { - // val suggest = simproc + "-" + (json / "rule_name").stringValue.replaceFirst("^.*\\/", "") + "-0" - // val sname = panel.derivation.steps.freshWithSuggestion(DSName(suggest)) - // val step = DStep.fromJson(sname, json, theory).layout - // - // Swing.onEDT { - // panel.document.derivation = panel.document.derivation.addStep(parentOpt, step) - // panel.controller.state = HeadState(Some(step.name)) - // pullSimp(simproc, sid, stack, Some(sname)) - // } - // - // } else { - // Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false } - // QuantoDerive.core ! Call(theory.coreName, "simplify", "delete_stack", - // JsonObject("stack" -> JsonString(stack))) - // } - // case _ => println("ERROR: Unexpected result from core: " + res) // TODO: errror dialogs - // } - // } else { - // Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false } - // QuantoDerive.core ! Call(theory.coreName, "simplify", "delete_stack", - // JsonObject("stack" -> JsonString(stack))) - // } - } - - private def evaluateSimproc(): Unit = { - val d = new EvaluationInputPanel(panel.project) - d.centerOnScreen() - d.open() - - val targetString: String = d.TargetText.text.replaceAll(raw"\\", raw"\\\\") - val replacementString: String = d.ReplacementText.text.replaceAll(raw"\\", raw"\\\\") - val initialDerivation = (panel.derivation, panel.controller.state.step) - - import quanto.cosy.SimplificationProcedure.Evaluation._ - - if (targetString.length > 0) { - - val initialState: State = new State( - List(), - 0, - Some(initialDerivation.verts.size), - new Random(), - initialDerivation.verts.toList, - targetString, - replacementString, - Some(initialDerivation.verts.size) - ) - val simproc = new SimplificationProcedure[State]( - initialDerivation, - initialState, - step, - progress, - (der, state) => state.currentStep == state.maxSteps.get - ) - val evluationProgressController = new SimprocProgress[State]( - panel.project, "Evaluation", simproc - ) - evluationProgressController.centerOnScreen() - evluationProgressController.open() - updateDerivation(evluationProgressController.returningDerivation, "evaluation") - } - } - - private def annealSimproc(): Unit = { - val d = new SimulatedAnnealingDialog(panel.project) - d.centerOnScreen() - d.open() - - val timeSteps = d.MainPanel.TimeSteps.text.toInt - val vertexLimit = d.MainPanel.vertexLimit() - if (timeSteps > 0) { -import quanto.cosy.SimplificationProcedure.Annealing._ - val initialState: State = new State(allowedRules, 0, Some(timeSteps), new Random(), 3, vertexLimit) - val simproc = new SimplificationProcedure[State]( - (panel.derivation, panel.controller.state.step), - initialState, - step, - progress, - (_, state) => state.currentStep == state.maxSteps.get - ) - val simulatedAnnealingController = new SimprocProgress[State]( - panel.project, "Simulated Annealing", simproc - ) - simulatedAnnealingController.centerOnScreen() - simulatedAnnealingController.open() - updateDerivation(simulatedAnnealingController.returningDerivation, "annealing reduce") - } - } - - refreshSimprocs() - - private def pullErrorsSimproc(): Unit = { - val initialDerivation = (panel.derivation, panel.controller.state.step) - val graph = Derivation.derivationHeadPairToGraph(initialDerivation) - val boundaries = graph.verts.filter(v => graph.vdata(v).isBoundary) - val d = new SimpleSelectionPanel(panel.project, - "Select target boundaries:", - boundaries.toList.sorted.map(_.toString)) - d.centerOnScreen() - d.open() - - val targets = d.MainPanel.OptionList.selection.items.map(s => VName(s)).toList - - val e = new SimpleSelectionPanel(panel.project, - "Select greedy rules:", - allowedRules.map(_.description.name)) - e.centerOnScreen() - e.open() - - val greedyRules = e.MainPanel.OptionList.selection.items.map(s => ruleFromDesc(RuleDesc(s))).toList - - - - println(targets) - import quanto.cosy.SimplificationProcedure.PullErrors - if (targets.nonEmpty && allowedRules.nonEmpty) { - - val initialState: PullErrors.State = PullErrors.State( - allowedRules.filterNot(r => greedyRules.contains(r)), - 0, - None, - new Random(), - PullErrors.errorsDistance(targets.toSet), - greedyRules = Some(greedyRules), - currentDistance = None, - heldVertices = None, - vertexLimit = None - ) - - - val simplificationProcedure = new quanto.cosy.SimplificationProcedure[PullErrors.State]( - initialDerivation, - initialState, - PullErrors.step, - PullErrors.progress, - (_, state) => state.currentStep == state.maxSteps.getOrElse(-1) || state.currentDistance.getOrElse(2.0) < 1 - ) - val progressController = new SimprocProgress[PullErrors.State]( - panel.project, "Pull Errors Through", simplificationProcedure - ) - progressController.centerOnScreen() - progressController.open() - updateDerivation(progressController.returningDerivation, "pull errors") - } - } - - - private def pullSpecialsSimproc(): Unit = { - val initialDerivation = (panel.derivation, panel.controller.state.step) - val graph = Derivation.derivationHeadPairToGraph(initialDerivation) - val boundaries = graph.verts.filter(v => graph.vdata(v).isBoundary) - val d = new SimpleSelectionPanel(panel.project, - "Select target boundaries:", - boundaries.toList.sorted.map(_.toString)) - d.centerOnScreen() - d.open() - - val e = new SimpleSelectionPanel(panel.project, - "Select vertices to hold in place:", - graph.verts.toList.sorted.map(_.toString)) - e.centerOnScreen() - e.open() - - val targets = d.MainPanel.OptionList.selection.items.map(s => VName(s)).toList - val specials = e.MainPanel.OptionList.selection.items.map(s => VName(s)).toList - import quanto.cosy.SimplificationProcedure.LTEByWeight._ - println(targets) - if (targets.nonEmpty) { - val initialState: State = State( - allowedRules, - 0, - None, - new Random(), - quanto.cosy.GraphAnalysis.distanceSpecialFromEnds(specials)(targets), - None, - heldVertices = Some(specials.toSet), - None - ) - val simproc = new SimplificationProcedure[State]( - (panel.derivation, panel.controller.state.step), - initialState, - step, - progress, - (_, state) => state.currentStep == state.maxSteps.getOrElse(-1) || state.currentDistance.getOrElse(1) == 0 - ) - val progressController = new SimprocProgress[State]( - panel.project, "Pull Errors Through", simproc - ) - progressController.centerOnScreen() - progressController.open() - updateDerivation(progressController.returningDerivation, "pull errors") - } - } - - private def greedySimproc(): Unit = { - import quanto.cosy.SimplificationProcedure.Greedy._ - val initialState: State = new State( - allowedRules, - 0, - None, - new Random(), - allowedRules, - None) - val simplificationProcedure = new SimplificationProcedure[State]( - (panel.derivation, panel.controller.state.step), - initialState, - step, - progress, - (_, state) => state.currentStep == state.maxSteps.getOrElse(-1) || state.remainingRules.isEmpty - ) - val progressController = new SimprocProgress[State]( - panel.project, "Greedy Reduction", simplificationProcedure - ) - progressController.centerOnScreen() - progressController.open() - updateDerivation(progressController.returningDerivation, "greedy reduce") - } - - private def lteSimproc(): Unit = { - import quanto.cosy.SimplificationProcedure.LTEByWeight._ - val initialState: State = State( - rules = allowedRules, - currentStep = 0, - currentDistance = None, - maxSteps = Some(100), - seed = new Random(), - weightFunction = g => Some(g.verts.size + g.edges.size), - heldVertices = None, - vertexLimit = None) - val simproc = new SimplificationProcedure[State]( - (panel.derivation, panel.controller.state.step), - initialState, - step, - progress, - (_, state) => state.currentStep == state.maxSteps.getOrElse(-1) - ) - val progressController = new SimprocProgress[State]( - panel.project, "Greedy Reduction", simproc - ) - progressController.centerOnScreen() - progressController.open() - updateDerivation(progressController.returningDerivation, "lte reduce") - } - - - private def randomSimproc(): Unit = { - import quanto.cosy.SimplificationProcedure.LTEByWeight._ - val initialState: State = State( - rules = allowedRules, - currentStep = 0, - currentDistance = None, - maxSteps = Some(100), - seed = new Random(), - weightFunction = _ => None, - heldVertices = None, - vertexLimit = None) - val simproc = new SimplificationProcedure[State]( - (panel.derivation, panel.controller.state.step), - initialState, - step, - progress, - (_, state) => state.currentStep == state.maxSteps.getOrElse(-1) - ) - val progressController = new SimprocProgress[State]( - panel.project, "Random Rule Application", simproc - ) - progressController.centerOnScreen() - progressController.open() - updateDerivation(progressController.returningDerivation, "random apply") - } - - - val availableProcedures: Map[String, () => Unit] = Map( - "Random x 100" -> randomSimproc, - "Graph shrink x 100" -> lteSimproc, - "Greedy reduce" -> greedySimproc, - "Pull specials" -> pullSpecialsSimproc, - "Pull pi-errors" -> pullErrorsSimproc, - "Anneal" -> annealSimproc, - "Evaluate" -> evaluateSimproc - ) - Swing.onEDT { panel.SimplifyBuiltInPane.Simprocs.listData = availableProcedures.keys.toSeq } - - implicit def ruleFromDesc(ruleDesc: RuleDesc): Rule = { - Rule.fromJson(Json.parse(new File(panel.project.rootFolder + "/" + ruleDesc.name + ".qrule")), - theory, - description = Some(ruleDesc)) - } - - private def allowedRules = panel.rewriteController.rules.map(ruleFromDesc).toList - - - private def updateDerivation(derivationWithHead: DerivationWithHead, desc: String): Unit = { - println("updating derivation") - val currentDerivation = panel.document.derivation - - panel.document.undoStack.register(desc) { - updateDerivation((currentDerivation, panel.controller.state.step), desc) - } - - panel.document.derivation = derivationWithHead._1 - derivationWithHead._2 match { - case Some(stepName) => panel.controller.state = StepState(stepName) - case None => panel.controller.state = HeadState(None) - } - } - - - reactions += { - case ButtonClicked(panel.SimplifyBuiltInPane.SimplifyButton) => - if (panel.SimplifyBuiltInPane.Simprocs.selection.indices.nonEmpty) { - val procedureName: String = panel.SimplifyBuiltInPane.Simprocs.selection.items(0) - val procedure = availableProcedures(procedureName) - procedure.apply() - } - } - -} diff --git a/scala/src/main/scala/quanto/gui/SimplifyController.scala b/scala/src/main/scala/quanto/gui/SimplifyController.scala index 37af3f61..15ed1d43 100644 --- a/scala/src/main/scala/quanto/gui/SimplifyController.scala +++ b/scala/src/main/scala/quanto/gui/SimplifyController.scala @@ -6,8 +6,10 @@ import quanto.data._ import quanto.data.Names._ import quanto.util.json._ import akka.pattern.ask +import quanto.util.UserAlerts + import scala.concurrent.ExecutionContext.Implicits.global -import scala.swing.event.ButtonClicked +import scala.swing.event.{ButtonClicked, Event} import scala.util.{Failure, Success, Try} import quanto.util.UserAlerts.SelfAlertingProcess @@ -18,7 +20,19 @@ class SimplifyController(panel: DerivationPanel) extends Publisher { private var activeSimp: Option[Future[Boolean]] = None - listenTo(panel.SimplifyPane.RefreshButton, panel.SimplifyPane.SimplifyButton, panel.SimplifyPane.StopButton) + + case class StartSimproc(simproc: String) extends Event + case class RefreshSimprocs() extends Event + case class ReRunSimproc() extends Event + case class HaltSimproc() extends Event + case class SimprocHaltedWhileRunning() extends Exception("Simproc halted") + + listenTo(this, + panel.SimplifyPane.RefreshButton, + panel.SimplifyPane.SimplifyButton, + panel.SimplifyPane.StopButton, + panel, + PythonEditPanel) def theory = panel.theory @@ -33,98 +47,79 @@ class SimplifyController(panel: DerivationPanel) extends Publisher { refreshSimprocs() - private def pullSimp(simproc: String, sid: Int, stack: String, parentOpt: Option[DSName]) { - if (simpId == sid) { - //val res = QuantoDerive.core ? Call(theory.coreName, "simplify", "pull_next_step", - // JsonObject("stack" -> JsonString(stack))) - -// res.map { -// case Success(JsonNull) => // out of steps -// Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false } -// simpId += 1 -// case Success(json) => -// if (simpId == sid) { -// val suggest = simproc + "-" + (json / "rule_name").stringValue.replaceFirst("^.*\\/", "") + "-0" -// val sname = panel.derivation.steps.freshWithSuggestion(DSName(suggest)) -// val step = DStep.fromJson(sname, json, theory).layout -// -// Swing.onEDT { -// panel.document.derivation = panel.document.derivation.addStep(parentOpt, step) -// panel.controller.state = HeadState(Some(step.name)) -// pullSimp(simproc, sid, stack, Some(sname)) -// } -// -// } else { -// Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false } -// QuantoDerive.core ! Call(theory.coreName, "simplify", "delete_stack", -// JsonObject("stack" -> JsonString(stack))) -// } -// case _ => println("ERROR: Unexpected result from core: " + res) // TODO: errror dialogs -// } - None - } else { -// Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false } -// QuantoDerive.core ! Call(theory.coreName, "simplify", "delete_stack", -// JsonObject("stack" -> JsonString(stack))) - None - } - } + private var _lastRunSimproc : String = "" reactions += { - case ButtonClicked(panel.SimplifyPane.RefreshButton) => refreshSimprocs() - case ButtonClicked(panel.SimplifyPane.SimplifyButton) => - if (panel.SimplifyPane.Simprocs.selection.indices.nonEmpty) { - simpId += 1 - val simpName = panel.SimplifyPane.Simprocs.selection.items(0) - - QuantoDerive.CurrentProject.flatMap { pr => pr.simprocs.get(simpName) }.foreach { simproc => - var parentOpt = panel.controller.state.step - val processReporting = new SelfAlertingProcess("Simproc: " + simpName) - - val res = Future[Boolean] { - for ((graph, rule) <- simproc.simp(panel.LhsView.graph)) { - val suggest = simpName + "-" + rule.name.replaceFirst("^.*\\/", "") + "-0" - val step = DStep( - name = panel.derivation.steps.freshWithSuggestion(DSName(suggest)), - rule = rule, - graph = graph.minimise) // layout is already done by simproc now + case RequestReRunSimproc() => + publish(ReRunSimproc()) + case ReRunSimproc() => + publish(StartSimproc(_lastRunSimproc)) + case StartSimproc(simpName) => + _lastRunSimproc = simpName + QuantoDerive.CurrentProject.flatMap { pr => pr.simprocs.get(simpName) }.foreach { simproc => + var parentOpt = panel.controller.state.step + val sourceMessage = s"Running simproc '$simpName'".concat( + if (simproc.sourceFile != "") { + s" from ${simproc.sourceFile}" + } else { + "" + } + ) + UserAlerts.alert(sourceMessage) + val processReporting = new SelfAlertingProcess("Simproc: " + simpName) + val simpIdAtStart = simpId + val res = Future[Boolean] { + val iteratedSimp : Iterator[(Graph, Rule)] = simproc.simp(panel.LhsView.graph) + + // Don't update the derivation if the simpId call has changed + while(iteratedSimp.hasNext && simpId == simpIdAtStart){ + val (graph, rule) = iteratedSimp.next() + val suggest = simpName + "-" + rule.name.replaceFirst("^.*\\/", "") + "-0" + val step = DStep( + name = panel.derivation.steps.freshWithSuggestion(DSName(suggest)), + rule = rule, + graph = graph.minimise) // layout is already done by simproc now panel.document.derivation = panel.document.derivation.addStep(parentOpt, step) parentOpt = Some(step.name) + Swing.onEDT { + panel.controller.state = HeadState(Some(step.name)) + } + } - Swing.onEDT { panel.controller.state = HeadState(Some(step.name)) } - } - true + if(simpId != simpIdAtStart) { + throw SimprocHaltedWhileRunning() } - res.onComplete { - case Success(b) => - processReporting.finish() - case Failure(e) => - processReporting.fail() - e.printStackTrace() - } + true } - -// val res = QuantoDerive.core ? Call(theory.coreName, "simplify", "simplify", JsonObject( -// "simproc" -> JsonString(simproc), -// "graph" -> Graph.toJson(panel.LhsView.graph, theory) -// )) -// res.map { -// case Success(JsonString(stack)) => -// Swing.onEDT { -// QuantoDerive.ConsoleProgress.indeterminate = true -// pullSimp(simproc, simpId, stack, panel.controller.state.step) -// } -// -// case _ => println("ERROR: Unexpected result from core: " + res) // TODO: errror dialogs -// } + res.onComplete { + case Success(b) => + processReporting.finish() + case Failure(SimprocHaltedWhileRunning()) => + processReporting.halt() + case Failure(e) => + e.printStackTrace() + } } - case ButtonClicked(panel.SimplifyPane.StopButton) => - Swing.onEDT { QuantoDerive.ConsoleProgress.indeterminate = false } + case HaltSimproc() => simpId += 1 + case RefreshSimprocs() => + refreshSimprocs() + case ButtonClicked(panel.SimplifyPane.RefreshButton) => + publish(RefreshSimprocs()) + case ButtonClicked(panel.SimplifyPane.SimplifyButton) => + if (panel.SimplifyPane.Simprocs.selection.indices.nonEmpty) { + simpId += 1 + val simpName = panel.SimplifyPane.Simprocs.selection.items(0) + publish(StartSimproc(simpName)) + } + case ButtonClicked(panel.SimplifyPane.StopButton) => + publish(HaltSimproc()) + case SimprocsUpdated() => + refreshSimprocs() } } diff --git a/scala/src/main/scala/quanto/gui/SpecifyVariablesDialog.scala b/scala/src/main/scala/quanto/gui/SpecifyVariablesDialog.scala new file mode 100644 index 00000000..9f6475bc --- /dev/null +++ b/scala/src/main/scala/quanto/gui/SpecifyVariablesDialog.scala @@ -0,0 +1,93 @@ +package quanto.gui + +import quanto.data.Theory.ValueType + +import scala.swing._ +import scala.swing.event.{ButtonClicked, Key, KeyPressed, ValueChanged} +import quanto.data._ +import quanto.util.{Globals, UserOptions} + +import scala.util.matching +import scala.util.matching.Regex +import quanto.util.UserOptions.{scale, scaleInt} + +class SpecifyVariablesDialog(variables: List[(ValueType, String)]) extends Dialog { + modal = true + title = "Specify values for introduced variables?" + val smallGap: Int = UserOptions.scaleInt(5) + + + + val AddButton = new Button("Add") + val CancelButton = new Button("Cancel") + defaultButton = Some(AddButton) + + def result: Map[(ValueType, String), String] = { + if(!cancelled) { + variables.map(sv => sv -> textFields(sv).text).toMap + } else { + variables.map(sv => sv -> sv._2).toMap + } + } + + + + implicit def buttonIsSelected(radButton: RadioButton): Boolean = radButton.selected + var cancelled = false + + + // val dir = Files.newDirectoryStream(Paths.get(rootDir), "**/*.qrule") + // for (p <- dir.asScala) println(p) + val textFields : Map[(ValueType, String), TextField] = variables.map(name => { + name -> new TextField + }).toMap + + val MainPanel = new BoxPanel(Orientation.Vertical) { + + contents += Swing.VStrut(smallGap) + + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(smallGap), + new Label("This rule introduces new variables. Would you like to specify values?"), + Swing.HStrut(smallGap)) + } + contents += Swing.VStrut(smallGap) + + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(smallGap), + new GridPanel(variables.length, 2) { + variables.foreach(name => { + val label = new Label(s"${name._2} (${TheoryEditPanel.valueTypesAsHumanReadable(name._1)})") + label.horizontalAlignment = Alignment.Left + contents += label + val tf = textFields(name) + contents += tf + tf.text = name._2 + }) + }, + Swing.HStrut(smallGap)) + } + contents += Swing.VStrut(smallGap) + + contents += new BoxPanel(Orientation.Horizontal) { + contents += (AddButton, Swing.HStrut(smallGap), CancelButton) + } + + contents += Swing.VStrut(2*smallGap) + + } + + + contents = MainPanel + + listenTo(AddButton, CancelButton) + + reactions += { + case ButtonClicked(AddButton) => + cancelled = false + close() + case ButtonClicked(CancelButton) => + cancelled = true + close() + } +} diff --git a/scala/src/main/scala/quanto/gui/StatusBar.scala b/scala/src/main/scala/quanto/gui/StatusBar.scala new file mode 100644 index 00000000..88a8a562 --- /dev/null +++ b/scala/src/main/scala/quanto/gui/StatusBar.scala @@ -0,0 +1,118 @@ +package quanto.gui + +import java.awt.Desktop +import java.awt.event.{MouseAdapter, MouseEvent} + +import quanto.gui.QuantoDerive.{Split, listenTo, popup, CurrentProject} +import quanto.util.{UserAlerts, UserOptions} + +import scala.swing.{Action, BorderPanel, Dimension, FlowPanel, GridPanel, Label, MenuItem, ProgressBar, Publisher, Separator} +import scala.swing.event.{Event, Key} + +class StatusBar extends BorderPanel { + + + def MessagesContextMenu() : PopupMenu = new PopupMenu { + menu => + + def showAlert(a : UserAlerts.Alert) : Unit = { + menu.contents += new Label(" "+a.toString) + } + + val numAlerts : Int = UserAlerts.alerts.count(_ => true) + if(numAlerts < 6) { + // Few alerts, so list all of them + UserAlerts.alerts.reverse.foreach(showAlert) + } else { + menu.contents += new Label(" ...") + // List only first 5 + UserAlerts.alerts.slice(0,5).reverse.foreach(showAlert) + } + + + def whichLoggingTitle : String = { + if (UserOptions.logging) {"Disable logging"} else {"Enable logging"} + } + + val ShowLogAction: Action = new Action(whichLoggingTitle) { + + menu.contents += new MenuItem(this) { + mnemonic = Key.L + } + + def apply() { + UserOptions.logging = !UserOptions.logging + } + } + + menu.contents += new Separator() + + val ClearMessagesAction: Action = new Action("Clear message") { + menu.contents += new MenuItem(this) { + mnemonic = Key.C + } + + def apply() { + clearStatusBar() + } + } + + } + + + val UserMessage = new Label(UserAlerts.latestMessage.toString) + val ConsoleProgress = new ProgressBar + val ConsoleProgressLabel = new Label(" ") + val Status : GridPanel = new GridPanel(1, 2) { + contents += new FlowPanel(FlowPanel.Alignment.Left)(UserMessage) + contents += new FlowPanel(FlowPanel.Alignment.Right)(ConsoleProgressLabel, ConsoleProgress) + } + + ConsoleProgress.preferredSize = ConsoleProgressSize //Currently doesn't respond to UI scaling + + def ConsoleProgressSize: Dimension = new Dimension(UserOptions.scaleInt(100), UserOptions.scaleInt(15)) + + Status.peer.addMouseListener(new MouseAdapter { + override def mousePressed(e: MouseEvent) { + e.getButton match { + case _ => + if (e.isPopupTrigger) { + popup(MessagesContextMenu(), Some(e)) + } + } + } + + }) + + def clearStatusBar() : Unit = { + UserMessage.text = "" + ConsoleProgressLabel.text = " " + ConsoleProgress.value = 0 + ConsoleProgress.indeterminate = false + } + + + listenTo(UserAlerts.AlertPublisher) + reactions += { + case UserAlerts.UserAlertEvent(alert: UserAlerts.Alert) => + UserMessage.text = alert.toString + UserMessage.foreground = alert.color + case UserAlerts.UserProcessUpdate(_) => + UserAlerts.leastCompleteProcess match { + case Some(process) => if (process.determinate) { + ConsoleProgress.indeterminate = false + ConsoleProgress.value = process.value + } else { + ConsoleProgress.indeterminate = true + } + case _ => ConsoleProgress.value = 100 + } + val ongoing = UserAlerts.ongoingProcesses.filter(op => op.value < 100) + ongoing.count(_ => true) match { + case 0 => ConsoleProgressLabel.text = " " //keep non-empty so the progressbar stays in line with text + case 1 => ConsoleProgressLabel.text = ongoing.head.name + case n => ConsoleProgressLabel.text = n.toString + " processes ongoing" + } + } + +} diff --git a/scala/src/main/scala/quanto/gui/TextEditor.scala b/scala/src/main/scala/quanto/gui/TextEditor.scala new file mode 100644 index 00000000..7f7b96a5 --- /dev/null +++ b/scala/src/main/scala/quanto/gui/TextEditor.scala @@ -0,0 +1,107 @@ +package quanto.gui + +import java.awt.BorderLayout +import java.awt.event.{KeyAdapter, KeyEvent} +import java.io.{File, PrintWriter} + +import org.gjt.sp.jedit.buffer.JEditBuffer +import org.gjt.sp.jedit.textarea.StandaloneTextArea +import org.gjt.sp.jedit.{Mode, Registers} +import quanto.util.UserOptions + +import scala.io.Source +import scala.swing.{BorderPanel, Label} + +class TextEditor(val mode: Mode) extends BorderPanel { + lazy val Component: BorderPanel = new BorderPanel { + peer.add(TextArea, BorderLayout.CENTER) + } + + import org.gjt.sp.jedit.IPropertyManager + import org.gjt.sp.jedit.textarea.StandaloneTextArea + + // TODO: check font availability and/or allow user to select one + val props = new java.util.Properties() + val propFile = this.getClass.getResourceAsStream("jedit.props") + val keyFile = this.getClass.getResourceAsStream("jEdit_keys.props") + props.load(propFile) + props.load(keyFile) + //props.setProperty("view.font", "Arial") + props.setProperty("view.fontsize", UserOptions.fontSize.toString) + //props.setProperty("view.fontstyle", "0") + //props.putAll(loadProperties("/keymaps/jEdit_keys.props")) + //props.putAll(loadProperties("/org/gjt/sp/jedit/jedit.props")) + val TextArea = new StandaloneTextArea(new IPropertyManager() { + override def getProperty(name: String): String = props.getProperty(name) + }) + //textArea.getBuffer.setProperty("folding", "explicit") + //val TextArea: StandaloneTextArea = StandaloneTextArea.createTextArea() + + private val buf = new JEditBuffer1 + private val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask + + def registerBuffer() { + TextArea.setFont(UserOptions.font) + buf.setMode(mode) + TextArea.setBuffer(buf) + TextArea.addKeyListener(new KeyAdapter { + override def keyPressed(e: KeyEvent) { + if (e.getModifiers == CommandMask) e.getKeyChar match { + case 'x' => Registers.cut(TextArea, '$') + case 'c' => Registers.copy(TextArea, '$') + case 'v' => Registers.paste(TextArea, '$') + case _ => + } + } + }) + } + + def getText: String = { + TextArea.getBuffer.getText + } + + registerBuffer() +} + +object TextEditor { + // JEdit can't access scala resources directly, so need to make a dummy file to hold the xml + def makeDummyXMLFile(name: String): String = { + val f = File.createTempFile(name, "xml") + f.deleteOnExit() + val pr = new PrintWriter(f) + Source.fromInputStream(getClass.getResourceAsStream(s"$name.xml")).foreach(pr.print) + pr.close() + f.getCanonicalPath + } + + object Modes { + def python: Mode = { + val mode = new Mode("Python") + val modeFile = makeDummyXMLFile("python") + mode.setProperty("file", modeFile) + mode + } + + def markdown: Mode = { + val mode = new Mode("Markdown") + val modeFile = makeDummyXMLFile("markdown") + mode.setProperty("file", modeFile) + mode + } + + def blank : Mode = { + val mode = new Mode("Blank") + val modeFile = makeDummyXMLFile("blank") + mode.setProperty("file", modeFile) + mode + } + + def fromFile(path: String, identifier: String): Mode = { + val mode = new Mode(identifier) + val modeFile = path + mode.setProperty("file", modeFile) + mode + } + } + +} \ No newline at end of file diff --git a/scala/src/main/scala/quanto/gui/TheoryDocument.scala b/scala/src/main/scala/quanto/gui/TheoryDocument.scala new file mode 100644 index 00000000..f2ec1c92 --- /dev/null +++ b/scala/src/main/scala/quanto/gui/TheoryDocument.scala @@ -0,0 +1,85 @@ +package quanto.gui + +import java.io.File + +import quanto.data._ +import quanto.util.UserAlerts +import quanto.util.json.Json + +import scala.collection.mutable +import scala.ref.Reference +import scala.swing.Reactions.Reaction +import scala.swing.{Component, Publisher, RefSet} +import scala.swing.event.Event +import scala.swing.event.Event + + +case class TheoryChanged() extends Event + +// Will infer the theory from QuantoDerive! +// Otherwise can have multiple theories editing at once, but only one should be affecting the project +class TheoryDocument(val parent: Component) extends Document with Publisher { + val description = "Theory" + val fileExtension = "qtheory" + private var _theory: Theory = getTheory() + + def getTheory(): Theory = { + if (QuantoDerive.CurrentProject.isEmpty) { + blankTheory() + } else { + QuantoDerive.CurrentProject.get.theory + } + } + + protected def clearDocument() { + _theory = blankTheory() + publish(TheoryChanged()) + } + + def blankTheory(): Theory = { + new Theory(name = "unsaved", + coreName = "unsaved", + vertexTypes = Map(), edgeTypes = Map(), + defaultVertexType = "") + } + + override def loadDocument(f: File) { + // Not convinced this should ever be called + val json = Json.parse(f) + val newTheory = Theory.fromJson(json) + //theory = newTheory + publish(TheoryChanged()) + } + + // There should never be unsaved changes + override def unsavedChanges: Boolean = false + + protected def saveDocument(f: File) { + val json = Theory.toJson(theory) + json.writeTo(f) + publish(TheoryChanged()) + } + + def theory: Theory = _theory + + def theory_=(th: Theory): Unit = { + if (QuantoDerive.CurrentProject.isEmpty){ + UserAlerts.alert("Please open a project before altering theory files", UserAlerts.Elevation.WARNING) + } else { + val project = QuantoDerive.CurrentProject.get + QuantoDerive.CurrentProject = Some(new Project(th, project.projectFile, project.name)) + QuantoDerive.updateProjectFile(project.projectFile) + } + _theory = getTheory() + publish(TheoryChanged()) + } + + + override protected def exportDocument(f: File) { + showSaveAsDialog(None) + } + + override def titleDescription: String = "Theory Editor" + + +} diff --git a/scala/src/main/scala/quanto/gui/TheoryEditPanel.scala b/scala/src/main/scala/quanto/gui/TheoryEditPanel.scala new file mode 100644 index 00000000..66c41629 --- /dev/null +++ b/scala/src/main/scala/quanto/gui/TheoryEditPanel.scala @@ -0,0 +1,757 @@ +package quanto.gui + +import java.awt.{Color, Shape} + +import org.lindenb.svg.SVGUtils +import quanto.data.Theory._ +import quanto.data.{CompositeExpression, GenericParseException, Theory, TheoryLoadException} +import quanto.util.UserAlerts.{Elevation, alert} +import quanto.util.UserOptions.scaleInt +import quanto.util._ +import quanto.util.json.{Json, JsonObject} +import quanto.util.swing.ToolBar + +import scala.language.postfixOps + +import scala.swing._ +import scala.swing.event.{ButtonClicked, SelectionChanged} + +import TheoryEditPanel._ + +class TheoryEditPanel() extends BorderPanel with HasDocument { + val document = new TheoryDocument(this) + val CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask + val Toolbar = new ToolBar { + //contents + } + val EditorsCombined = new BoxPanel(Orientation.Vertical) + val TopScrollablePane = new ScrollPane(EditorsCombined) + + + + def chooseNodeDataType(node: String, current: String): Unit = { + val dialog = new DataTypePickingDialog(node, current) + dialog.centerOnScreen() + dialog.open() + if(dialog.wasAccepted){ + val typeSelected = dialog.DataComboBox.selection.item + val newTypeVector: Vector[ValueType] = if(typeSelected == "composite"){ + try{ + CompositeExpression.parseTypes(dialog.CustomText.text) + } catch { + case _ : GenericParseException => Vector() + } + } else { + CompositeExpression.parseTypes(typeSelected) + } + + + val oldVertexDesc = theory.vertexTypes(node) + val newTypeValue : ValueDesc = new ValueDesc(typ = newTypeVector, + oldVertexDesc.value.enumOptions, + oldVertexDesc.value.latexConstants, + oldVertexDesc.value.validateWithCore) + + val newVertexDesc : VertexDesc = new VertexDesc(newTypeValue, + oldVertexDesc.style, + oldVertexDesc.defaultData) + + val newVertexTypes: Map[String, VertexDesc] = theory.vertexTypes + (node -> newVertexDesc) + // Update the base document's theory + theory = theory.copy(vertexTypes = newVertexTypes) + } + } + + def chooseNodeBorderWidth(node: String, current: Int): Unit = { + + val dialog = new SizeDialog(s"$node's border", current) + dialog.centerOnScreen() + dialog.open() + val newSize = dialog.SizeComboBox.selection.item.toInt + + val oldStyle: VertexStyleDesc = theory.vertexTypes(node).style + val newStyle: VertexStyleDesc = oldStyle.copy(strokeWidth = newSize) + + val oldVertexDesc = theory.vertexTypes(node) + val newVertexDesc = new VertexDesc(oldVertexDesc.value, + newStyle, + oldVertexDesc.defaultData) + + val newVertexTypes: Map[String, VertexDesc] = theory.vertexTypes + (node -> newVertexDesc) + // Update the base document's theory + theory = theory.copy(vertexTypes = newVertexTypes) + } + + def chooseNodeColour(node: String, current: Color): Unit = { + val dialog = new ColourPickingDialog(node, current) + dialog.centerOnScreen() + dialog.open() + val newColourHex: String = dialog.CustomText.text + val newColour = Color.decode(newColourHex) + val oldStyle: VertexStyleDesc = theory.vertexTypes(node).style + val newStyle: VertexStyleDesc = oldStyle.copy(fillColor = newColour) + + val oldVertexDesc = theory.vertexTypes(node) + val newVertexDesc = new VertexDesc(oldVertexDesc.value, + newStyle, + oldVertexDesc.defaultData) + + val newVertexTypes: Map[String, VertexDesc] = theory.vertexTypes + (node -> newVertexDesc) + // Update the base document's theory + theory = theory.copy(vertexTypes = newVertexTypes) + } + + // Method for choosing the strokecolour of an edge + def chooseEdgeColour(edge: String, current: Color): Unit = { + val dialog = new ColourPickingDialog(s"Choose colour for edge $edge", current) + dialog.centerOnScreen() + dialog.open() + val newColourHex: String = dialog.CustomText.text + val newColour = Color.decode(newColourHex) + val oldStyle: EdgeStyleDesc = theory.edgeTypes(edge).style + val newStyle: EdgeStyleDesc = oldStyle.copy(strokeColor = newColour) + + val oldEdgeDesc = theory.edgeTypes(edge) + val newEdgeDesc = new EdgeDesc(oldEdgeDesc.value, + newStyle, + oldEdgeDesc.defaultData) + + val newEdgeTypes: Map[String, EdgeDesc] = theory.edgeTypes + (edge -> newEdgeDesc) + // Update the base document's theory + theory = theory.copy(edgeTypes = newEdgeTypes) + } + + // Method for choosing the strokewidth of an edge + def chooseEdgeWidth(edge: String, current: Int): Unit = { + val dialog = new SizeDialog(edge, current) + dialog.centerOnScreen() + dialog.open() + val newSize = dialog.SizeComboBox.selection.item.toInt + val oldStyle: EdgeStyleDesc = theory.edgeTypes(edge).style + val newStyle: EdgeStyleDesc = oldStyle.copy(strokeWidth = newSize) + + val oldEdgeDesc = theory.edgeTypes(edge) + val newEdgeDesc = new EdgeDesc(oldEdgeDesc.value, + newStyle, + oldEdgeDesc.defaultData) + + val newEdgeTypes: Map[String, EdgeDesc] = theory.edgeTypes + (edge -> newEdgeDesc) + // Update the base document's theory + theory = theory.copy(edgeTypes = newEdgeTypes) + } + + // Choose label position for current node + def chooseLabelPlacement(node: String, current: VertexLabelPosition): Unit = { + val dialog = new LabelPlacementDialog(node, current) + dialog.centerOnScreen() + dialog.open() + val newPlacementName: String = dialog.PlacementComboBox.item + val newPlacement: VertexLabelPosition = VertexLabelPosition.fromName(newPlacementName). + getOrElse(VertexLabelPosition.values.head) + val oldStyle: VertexStyleDesc = theory.vertexTypes(node).style + val newStyle: VertexStyleDesc = oldStyle.copy(labelPosition = newPlacement) + + val oldVertexDesc = theory.vertexTypes(node) + val newVertexDesc = new VertexDesc(oldVertexDesc.value, + newStyle, + oldVertexDesc.defaultData) + + val newVertexTypes: Map[String, VertexDesc] = theory.vertexTypes + (node -> newVertexDesc) + // Update the base document's theory + theory = theory.copy(vertexTypes = newVertexTypes) + } + + def theory: Theory = document.theory + + def theory_=(th: Theory) = document.theory_=(th) + + // Prompts the user to select a type of node, including "custom" as an option + def chooseNodeShape(node: String, current: (VertexShape, Option[Shape])): Unit = { + val currentPath = current._2 match { + case None => None + case Some(shape) => Some(SVGUtils.shapeToPath(shape)) + } + val dialog = new NodeShapeDialog(node, current._1, currentPath.getOrElse("")) + dialog.centerOnScreen() + dialog.open() + val newShapeName: String = dialog.ShapeComboBox.item + val newShape = VertexShape.fromName(newShapeName).getOrElse(VertexShape.Circle) + val oldStyle: VertexStyleDesc = theory.vertexTypes(node).style + val newStyle: VertexStyleDesc = oldStyle.copy( + newShape, + newShape match { + case VertexShape.Custom => + try { + Some(SVGUtils.pathToShape(dialog.CustomText.text)) + } catch { + case e: Exception => new TheoryLoadException("Could not interpret custom shape path", e) + None + } + case _ => None + } + ) + + val oldVertexDesc = theory.vertexTypes(node) + val newVertexDesc = new VertexDesc(oldVertexDesc.value, + newStyle, + oldVertexDesc.defaultData) + + val newVertexTypes: Map[String, VertexDesc] = theory.vertexTypes + (node -> newVertexDesc) + // Update the base document's theory + theory = theory.copy(vertexTypes = newVertexTypes) + } + + def createPageComponents(): Unit = { + + val separation = UserOptions.scaleInt(20) + + def horizontalWrap(component: Component): Unit = { + EditorsCombined.contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), component, Swing.HStrut(10)) + } + } + + val buttonSize = new Dimension(maxGridSize.width, scaleInt(30)) + + EditorsCombined.contents.clear + + EditorsCombined.contents += Swing.VStrut(separation) + horizontalWrap(new Label( + """All changes are saved immediately. Re-open graphs to see the changes.""")) + horizontalWrap(new Label( + """It is recommended that you back up your project before altering anything on this page!""")) + + EditorsCombined.contents += Swing.VStrut(separation) + EditorsCombined.contents += new Separator() + EditorsCombined.contents += Swing.VStrut(separation) + + horizontalWrap(new Label("Vertices:")) + theory.vertexTypes.toSeq.sortBy(vt => vt._1).map(vt => { + EditorsCombined.contents += new NodeEditor(vt._1, vt._2) + EditorsCombined.contents += Swing.VStrut(separation) + }) + val AddVertexButton = new Button("Add Vertex Type") + AddVertexButton.preferredSize = buttonSize + horizontalWrap(AddVertexButton) + EditorsCombined.contents += Swing.VStrut(separation) + EditorsCombined.contents += new Separator() + EditorsCombined.contents += Swing.VStrut(separation) + horizontalWrap(new Label("Edges:")) + theory.edgeTypes.toSeq.sortBy(et => et._1).map(et => { + EditorsCombined.contents += new EdgeEditor(et._1, et._2) + EditorsCombined.contents += Swing.VStrut(separation) + } + ) + val AddEdgeButton = new Button("Add Edge Type") + AddEdgeButton.preferredSize = buttonSize + horizontalWrap(AddEdgeButton) + EditorsCombined.contents += Swing.VStrut(separation) + + listenTo(AddEdgeButton, AddVertexButton) + reactions += { + case ButtonClicked(AddEdgeButton) => + addNewEdge() + case ButtonClicked(AddVertexButton) => + addNewVertex() + + } + repaint() + } + + def addNewVertex(): Unit = { + val d = new ChooseStringDialog("Name for new vertex type") + d.centerOnScreen() + d.open() + + val result = d.StringField.text + if (result != "" && !theory.vertexTypes.keys.exists(k => k == result)) { + val newVertexDesc = new VertexDesc( + value = new ValueDesc(), + style = new VertexStyleDesc(shape = VertexShape.Circle), + defaultData = JsonObject( + "type" -> Json.stringToJson(result), + "value" -> Json.stringToJson("") + ) + ) + val newVertexTypes: Map[String, Theory.VertexDesc] = theory.vertexTypes ++ Map(result -> newVertexDesc) + // Update the base document's theory + theory = theory.copy(vertexTypes = newVertexTypes) + } + } + + def addNewEdge(): Unit = { + val d = new ChooseStringDialog("Name for new edge type") + d.centerOnScreen() + d.open() + + val result = d.StringField.text + if (result != "") { + if (!theory.edgeTypes.keys.exists(k => k == result)) { + val newEdgeTypes: Map[String, Theory.EdgeDesc] = theory.edgeTypes ++ Map(result -> EdgeDesc( + value = ValueDesc(typ = Vector(ValueType.Empty)), + style = EdgeStyleDesc(), + defaultData = JsonObject("type" -> result) + )) + // Update the base document's theory + theory = new Theory(name = theory.name, + coreName = theory.coreName, + vertexTypes = theory.vertexTypes, + edgeTypes = newEdgeTypes, + defaultVertexType = theory.defaultVertexType, + defaultEdgeType = theory.defaultEdgeType + ) + } else { + UserAlerts.alert("That name is already in use", Elevation.WARNING) + } + } + } + + def colourToString(colour: Color): String = { + var returnColour = f"#${0xFFFFFF & colour.hashCode()}%06X" + for ((colourName, colourValue) <- approvedColours) { + if (colourValue == colour) { + returnColour = colourName + } + } + returnColour + } + + class ChooseStringDialog(title: String) extends Dialog { + modal = true + + + val AcceptButton = new Button("Accept") + val CancelButton = new Button("Cancel") + val StringField = new TextField("") + defaultButton = Some(AcceptButton) + val ShapeEditorPanel : BoxPanel = new BoxPanel(Orientation.Vertical) { + + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), new Label("Name:"), Swing.HStrut(5), StringField, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), AcceptButton, Swing.HStrut(5), CancelButton, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + } + + contents = ShapeEditorPanel + + listenTo(AcceptButton, CancelButton) + + reactions += { + case ButtonClicked(AcceptButton) => + close() + case ButtonClicked(CancelButton) => + StringField.text = "" + close() + } + } + + class SizeDialog(identifier: String, + current: Int, + sizeOptions: Seq[String] = (1 to 5).map(_.toString)) extends Dialog { + title = s"Choose the size of $identifier" + modal = true + + val AcceptButton = new Button("Accept") + val CancelButton = new Button("Cancel") + val SizeComboBox = new ComboBox(sizeOptions) + SizeComboBox.selection.item = current.toString + defaultButton = Some(CancelButton) + val ShapeEditorPanel : BoxPanel = new BoxPanel(Orientation.Vertical) { + + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), new Label("Size:"), Swing.HStrut(5), SizeComboBox, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), AcceptButton, Swing.HStrut(5), CancelButton, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + } + + contents = ShapeEditorPanel + + listenTo(AcceptButton, CancelButton, SizeComboBox.selection) + + reactions += { + case ButtonClicked(AcceptButton) => + close() + case ButtonClicked(CancelButton) => + SizeComboBox.item = current.toString + close() + case SelectionChanged(SizeComboBox) => + } + } + + class LabelPlacementDialog(identifier: String, current: VertexLabelPosition) extends Dialog { + title = s"Choose the label placement for $identifier" + modal = true + + val positionOptions: Seq[String] = VertexLabelPosition.values. + map(vt => vt.toString).filterNot(s => s == "custom").toSeq + + + val AcceptButton = new Button("Accept") + val CancelButton = new Button("Cancel") + val PlacementComboBox = new ComboBox(positionOptions) + PlacementComboBox.selection.item = current.toString + defaultButton = Some(CancelButton) + val ShapeEditorPanel : BoxPanel = new BoxPanel(Orientation.Vertical) { + + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), new Label("Label placement:"), Swing.HStrut(5), PlacementComboBox, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), AcceptButton, Swing.HStrut(5), CancelButton, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + } + + contents = ShapeEditorPanel + + listenTo(AcceptButton, CancelButton, PlacementComboBox.selection) + + reactions += { + case ButtonClicked(AcceptButton) => + close() + case ButtonClicked(CancelButton) => + PlacementComboBox.item = current.toString + close() + case SelectionChanged(PlacementComboBox) => + } + } + + class DataTypePickingDialog(str: String, current: String) extends Dialog { + modal = true + + private var currentForBox : String = current + + private val currentParsedType : Option[ValueType] = { + try { + val types = CompositeExpression.parseTypes(current) + if(types.length == 1) { + currentForBox = "" + Some(types.head) + } else { + currentForBox = valueTypeVectorToString(types) + None + } + } + catch { + case GenericParseException(msg) => { + currentForBox = "" + None + } + } + } + + val dataTypeOptions: Seq[String] = valueTypesAsHumanReadable.values.toSet.toSeq.sorted :+ "composite" + val AcceptButton = new Button("Accept") + val CancelButton = new Button("Cancel") + val DataComboBox = new ComboBox(dataTypeOptions) + DataComboBox.selection.item = { + if (currentParsedType.nonEmpty) { + current + } else { + "composite" + } + } + val CustomText = new TextField() + CustomText.text = currentForBox + val ShapeEditorPanel: BoxPanel = new BoxPanel(Orientation.Vertical) { + + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), new Label("Data Type:"), Swing.HStrut(5), DataComboBox, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), new Label("Composite Type:"), Swing.HStrut(5), CustomText, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), AcceptButton, Swing.HStrut(5), CancelButton, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + } + + enableCustom() + defaultButton = Some(CancelButton) + + + def enableCustom(): Unit = { + DataComboBox.selection.item match { + case "composite" => CustomText.enabled = true + case _ => CustomText.enabled = false + } + } + + contents = ShapeEditorPanel + + listenTo(AcceptButton, CancelButton, DataComboBox.selection) + + var wasAccepted : Boolean = false + + reactions += { + case ButtonClicked(AcceptButton) => + wasAccepted = true + close() + case ButtonClicked(CancelButton) => + wasAccepted = false + close() + case SelectionChanged(DataComboBox) => + enableCustom() + } + } + + // example: ColourPickingDialog("X", "foreground colour", Color(10,10,200)) + class ColourPickingDialog(title: String, current: Color) extends Dialog { + modal = true + + val currentColourName = colourToString(current) + + val colourOptions: Seq[String] = approvedColours.keys.toSet.toSeq.sorted :+ "custom" + val AcceptButton = new Button("Accept") + val CancelButton = new Button("Cancel") + val ColourComboBox = new ComboBox(colourOptions) + ColourComboBox.selection.item = currentColourName + val CustomText = new TextField() + displayColour(current) + val ShapeEditorPanel: BoxPanel = new BoxPanel(Orientation.Vertical) { + + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), new Label("Colour:"), Swing.HStrut(5), ColourComboBox, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), new Label("Custom data:"), Swing.HStrut(5), CustomText, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), AcceptButton, Swing.HStrut(5), CancelButton, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + } + + enableCustom() + defaultButton = Some(CancelButton) + + def displayColour(colour: Color): Unit = { + CustomText.text = f"#${0xFFFFFF & colour.hashCode()}%06X" + } + + def enableCustom(): Unit = { + ColourComboBox.selection.item match { + case "custom" => CustomText.enabled = true + case _ => CustomText.enabled = false + } + } + + contents = ShapeEditorPanel + + listenTo(AcceptButton, CancelButton, ColourComboBox.selection) + + reactions += { + case ButtonClicked(AcceptButton) => + close() + case ButtonClicked(CancelButton) => + ColourComboBox.item = currentColourName + close() + case SelectionChanged(ColourComboBox) => + ColourComboBox.selection.item match { + case "custom" => + case s => displayColour(approvedColours(s)) + } + enableCustom() + } + } + + class NodeShapeDialog(name: String, current: VertexShape, currentCustom: String) extends Dialog { + title = s"Choose the shape for node $name" + modal = true + + val shapeOptions: Seq[String] = VertexShape.values. + map(vt => vt.toString).filterNot(s => s == "custom").toSeq + + + val AcceptButton = new Button("Accept") + val CancelButton = new Button("Cancel") + val ShapeComboBox = new ComboBox(shapeOptions) + ShapeComboBox.selection.item = current.toString + val CustomText = new TextField(currentCustom) + enableCustom() + defaultButton = Some(CancelButton) + val ShapeEditorPanel = new BoxPanel(Orientation.Vertical) { + + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), new Label("Vertex shape:"), Swing.HStrut(5), ShapeComboBox, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), new Label("Custom data:"), Swing.HStrut(5), CustomText, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + contents += new BoxPanel(Orientation.Horizontal) { + contents += (Swing.HStrut(10), AcceptButton, Swing.HStrut(5), CancelButton, Swing.HStrut(10)) + } + contents += Swing.VStrut(10) + } + + def enableCustom(): Unit = { + ShapeComboBox.selection.item match { + case "custom" => CustomText.enabled = true + case _ => CustomText.enabled = false + } + } + + contents = ShapeEditorPanel + + listenTo(AcceptButton, CancelButton, ShapeComboBox.selection) + + reactions += { + case ButtonClicked(AcceptButton) => + close() + case ButtonClicked(CancelButton) => + ShapeComboBox.item = current.toString + close() + case SelectionChanged(ShapeComboBox) => + enableCustom() + } + } + + class NodeEditor(nodeName: String, desc: VertexDesc) extends BoxPanel(Orientation.Vertical) { + // Note that GridPanel expands to fill parent + maximumSize = maxGridSize + contents += new GridPanel(5, 2) { + + contents += new Label("Name") + contents += new Label(nodeName) // Currently doesn't support renaming nodes (and shouldn't?) + + contents += new Label("Shape") + val EditShapeButton: Button = new Button(desc.style.shape.toString) + contents += EditShapeButton + + contents += new Label("Colour") + val EditColourButton: Button = new Button(colourToString(desc.style.fillColor)) + contents += EditColourButton + + contents += new Label("Values") + val EditValueTypeButton: Button = new Button(valueTypeVectorToString(desc.value.typ)) + contents += EditValueTypeButton + + contents += new Label("Border width") + val EditBorderWidthButton: Button = new Button(desc.style.strokeWidth.toString) + contents += EditBorderWidthButton + + + //contents += new Label("Label placement") + //val EditPlacementButton: Button = new Button(desc.style.labelPosition.toString) + //contents += EditPlacementButton + //contents += new Label("Example") + //contents += new Label("Example here") + listenTo(EditShapeButton, EditColourButton, EditValueTypeButton, EditBorderWidthButton) + reactions += { + case ButtonClicked(EditShapeButton) => + chooseNodeShape(nodeName, (desc.style.shape, desc.style.customShape)) + case ButtonClicked(EditColourButton) => + chooseNodeColour(nodeName, desc.style.fillColor) + case ButtonClicked(EditValueTypeButton) => + chooseNodeDataType(nodeName, TheoryEditPanel.valueTypeVectorToString(desc.value.typ)) + case ButtonClicked(EditBorderWidthButton) => + chooseNodeBorderWidth(nodeName, desc.style.strokeWidth) + } + } + } +val maxGridSize = new Dimension(scaleInt(200), scaleInt(100)) + + class EdgeEditor(edgeName: String, desc: EdgeDesc) extends BoxPanel(Orientation.Vertical) { + // Note that GridPanel expands to fill parent + maximumSize = maxGridSize + contents += new GridPanel(3, 2) { + contents += new Label("Name") + contents += new Label(edgeName) // Currently doesn't support renaming nodes (and shouldn't?) + contents += new Label("Width") + val EditShapeButton: Button = new Button(desc.style.strokeWidth.toString) + contents += EditShapeButton + contents += new Label("Colour") + val EditColourButton: Button = new Button(colourToString(desc.style.strokeColor)) + contents += EditColourButton + //contents += new Label("Label placement") + //val EditPlacementButton: Button = new Button(desc.style.labelPosition.toString) + //contents += EditPlacementButton + //contents += new Label("Example") + //contents += new Label("Example here") + listenTo(EditShapeButton, EditColourButton) + reactions += { + case ButtonClicked(EditShapeButton) => + chooseEdgeWidth(edgeName, desc.style.strokeWidth) + case ButtonClicked(EditColourButton) => + chooseEdgeColour(edgeName, desc.style.strokeColor) + //case ButtonClicked(EditPlacementButton) => + // chooseLabelPlacement(name, desc.style.labelPosition) + } + } + } + + createPageComponents() + + + add(TopScrollablePane, BorderPanel.Position.Center) + + add(Toolbar, BorderPanel.Position.North) + + listenTo(document) + + reactions += { + case TheoryChanged() => + alert("Theory changed", Elevation.DEBUG) + createPageComponents() + + } +} + +object TheoryEditPanel { + + + implicit private def valueTypeVectorToString(vs: Vector[ValueType]) : String = + vs.map(valueTypesAsHumanReadable).mkString("",", ","") + + def colourMute(c: Color): Color = { + def m(i: Int): Int = math.floor(i * 1).toInt + + new Color(m(c.getRed), m(c.getGreen), m(c.getBlue)) + } + + val valueTypesAsHumanReadable: Map[ValueType, String] = ValueType.values.toList map (v => + v -> (v match { + case ValueType.AngleExpr => "angle" + case ValueType.Boolean => "boolean" + case ValueType.Integer => "integer" + case ValueType.Rational => "rational" + case ValueType.String => "string" + case ValueType.Long => "long" + case ValueType.Enum => "string" + case ValueType.Empty => "empty" + })) toMap + + val approvedColours: Map[String, Color] = Map( + "red" -> colourMute(new Color(255, 0, 0)), + "green" -> colourMute(new Color(0, 255, 0)), + "blue" -> colourMute(new Color(0, 0, 255)), + "white" -> colourMute(new Color(255, 255, 255)), + "black" -> colourMute(new Color(0, 0, 0)), + "yellow" -> colourMute(new Color(255, 255, 0)), + "magenta" -> colourMute(new Color(255, 0, 255)), + "cyan" -> colourMute(new Color(0, 255, 255)) + ) + +} \ No newline at end of file diff --git a/scala/src/main/scala/quanto/gui/Transformer.scala b/scala/src/main/scala/quanto/gui/Transformer.scala index ff75181c..dbed175c 100644 --- a/scala/src/main/scala/quanto/gui/Transformer.scala +++ b/scala/src/main/scala/quanto/gui/Transformer.scala @@ -3,11 +3,12 @@ package quanto.gui class Transformer( var scale: Double = 50.0f, - var origin: (Double, Double) = (250.0f, 250.0f) + var screenDrawOrigin: (Double, Double) = (250f, 250f) ) { - def toScreen(pt: (Double,Double)) = (pt._1 * scale + origin._1, origin._2 - pt._2 * scale) - def fromScreen(pt: (Double,Double)) = ((pt._1 - origin._1) / scale, ((-1 * pt._2) + origin._2) / scale) + // screenDrawOrigin is where _on the screen_ _in pixels_ we want the origin to be drawn + def toScreen(pt: (Double,Double)) = (pt._1 * scale + screenDrawOrigin._1, screenDrawOrigin._2 - pt._2 * scale) + def fromScreen(pt: (Double,Double)) = ((pt._1 - screenDrawOrigin._1) / scale, ((-1 * pt._2) + screenDrawOrigin._2) / scale) def scaleToScreen(x: Double) = x * scale def scaleFromScreen(x: Double) = x / scale diff --git a/scala/src/main/scala/quanto/gui/graphview/BBoxDisplayData.scala b/scala/src/main/scala/quanto/gui/graphview/BBoxDisplayData.scala index 34ec21b4..f984093b 100644 --- a/scala/src/main/scala/quanto/gui/graphview/BBoxDisplayData.scala +++ b/scala/src/main/scala/quanto/gui/graphview/BBoxDisplayData.scala @@ -29,17 +29,17 @@ trait BBoxDisplayData { self: VertexDisplayData => protected def computeBBoxDisplay() { bboxDisplay.clear() - var offset = Math.max(boundsForVertexSet(graph.verts).getMaxX, trans.origin._1) + var offset = Math.max(boundsForVertexSet(graph.verts).getMaxX, trans.screenDrawOrigin._1) // used to compute relative padding sizes - val em = trans.scaleToScreen(0.1) + val em = trans.scaleToScreen(0.25) graph.bboxesChildrenFirst.foreach { bbox => val vset = graph.contents(bbox) val rect = if (vset.isEmpty) { - offset += 8*em - new Rectangle2D.Double(offset, trans.origin._2 - 2*em, 4*em, 4*em) + offset += 4.0*em + new Rectangle2D.Double(offset, trans.screenDrawOrigin._2 - 2*em, 4*em, 4*em) } else { /* bounds determined by vertices of bbox */ @@ -53,18 +53,18 @@ trait BBoxDisplayData { self: VertexDisplayData => val rect = bbd.rect if(bounds.contains(rect)) { - val ulx = min(rect.getMinX - 5.0*em, bounds.getMinX) - val uly = min(rect.getMinY - 5.0*em, bounds.getMinY) - val lrx = max(rect.getMaxX + 5.0*em, bounds.getMaxX) - val lry = max(rect.getMaxY + 5.0*em, bounds.getMaxY) + val ulx = min(rect.getMinX - em, bounds.getMinX) + val uly = min(rect.getMinY - em, bounds.getMinY) + val lrx = max(rect.getMaxX + em, bounds.getMaxX) + val lry = max(rect.getMaxY + em, bounds.getMaxY) bounds = new Rectangle2D.Double(ulx, uly, lrx - ulx, lry - uly) } else if(rect.contains(bounds)) { - val ulx = min(rect.getMinX, bounds.getMinX - 5.0*em) - val uly = min(rect.getMinY, bounds.getMinY - 5.0*em) - val lrx = max(rect.getMaxX, bounds.getMaxX + 5.0*em) - val lry = max(rect.getMaxY, bounds.getMaxY + 5.0*em) + val ulx = min(rect.getMinX, bounds.getMinX - em) + val uly = min(rect.getMinY, bounds.getMinY - em) + val lrx = max(rect.getMaxX, bounds.getMaxX + em) + val lry = max(rect.getMaxY, bounds.getMaxY + em) val new_bounds = new Rectangle2D.Double(ulx, uly, lrx - ulx, lry - uly) bboxDisplay += bb -> BBDisplay(new_bounds) diff --git a/scala/src/main/scala/quanto/gui/graphview/EdgeDisplayData.scala b/scala/src/main/scala/quanto/gui/graphview/EdgeDisplayData.scala index 8f80c395..9e6f0414 100644 --- a/scala/src/main/scala/quanto/gui/graphview/EdgeDisplayData.scala +++ b/scala/src/main/scala/quanto/gui/graphview/EdgeDisplayData.scala @@ -3,9 +3,18 @@ package quanto.gui.graphview import quanto.data._ import quanto.util.RichCubicCurve._ import java.awt.geom._ + +import quanto.data.Theory.EdgeDesc +import quanto.util.UserOptions + import math._ +import scala.swing.Color -case class EDisplay(path: Path2D.Double, lines: List[Line2D.Double], label: Option[LabelDisplayData]) { +case class EDisplay(path: Path2D.Double, + width: Int, + color: Color, + lines: List[Line2D.Double], + label: Option[LabelDisplayData]) { def pointHit(pt: Point2D) = { lines exists { l => //println("line starts " + l.getP1 + " ends " + l.getP2 + ", distance to " + pt + " is " + l.ptSegDistSq(pt)) @@ -23,6 +32,7 @@ trait EdgeDisplayData { self: GraphView with VertexDisplayData => for ((v1,sd) <- graph.vdata; (v2,td) <- graph.vdata if v1 <= v2) { val edges = graph.source.codf(v1) intersect graph.target.codf(v2) val rEdges = if (v1 == v2) Set[EName]() else graph.target.codf(v1) intersect graph.source.codf(v2) + // Count total edges v1 -> v2, and reverse edges (v2 -> v1) val numEdges = edges.size + rEdges.size if (numEdges != 0) { @@ -115,7 +125,7 @@ trait EdgeDisplayData { self: GraphView with VertexDisplayData => } val labelDisplay = edgeData.typeInfo.value.typ match { - case Theory.ValueType.String => + case Vector(Theory.ValueType.String) => val fm = peer.getGraphics.getFontMetrics(GraphView.EdgeLabelFont) val text = edgeData.label if (text == "") None @@ -126,9 +136,10 @@ trait EdgeDisplayData { self: GraphView with VertexDisplayData => edgeData.typeInfo.style.labelBackgroundColor)) case _ => None } + val width = UserOptions.scaleInt(edgeData.typeInfo.style.strokeWidth) + val color = edgeData.typeInfo.style.strokeColor - - edgeDisplay(e) = EDisplay(p, lines, labelDisplay) + edgeDisplay(e) = EDisplay(p, width, color, lines, labelDisplay) i += 1 } diff --git a/scala/src/main/scala/quanto/gui/graphview/GraphView.scala b/scala/src/main/scala/quanto/gui/graphview/GraphView.scala index 8777b276..f4751913 100644 --- a/scala/src/main/scala/quanto/gui/graphview/GraphView.scala +++ b/scala/src/main/scala/quanto/gui/graphview/GraphView.scala @@ -15,7 +15,8 @@ import java.awt.font.TextLayout import java.io.File import quanto.util.FileHelper.printToFile -import quanto.util.UserOptions +import quanto.util.UserAlerts.alert +import quanto.util.{UserAlerts, UserOptions} import quanto.util.json.JsonString @@ -60,6 +61,7 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel // gets called when the component is first painted lazy val init = { + focusOnGraph() resizeViewToFit() } @@ -92,11 +94,39 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel private var _zoom = 1.0 def zoom = _zoom def zoom_=(d: Double) { + // log where the middle of the screen currently is _zoom = d trans.scale = _zoom * 50.0 - trans.origin = (_zoom * 250.0, _zoom * 250.0) - invalidateGraph(clearSelection = false) resizeViewToFit() + invalidateGraph(false) + repaint() + } + +private def viewportOffset (): (Double, Double) = { + val viewRect = peer.getVisibleRect + (viewRect.width / 2, viewRect.height / 2) +} + + def graphFocus : (Double, Double) = { + val viewShift = viewportOffset() + trans.fromScreen(viewShift) + } + + private def renderOriginAt(x : Double, y: Double): Unit ={ + trans.screenDrawOrigin = (x, y) + } + + def graphFocus_=(pointOnGraph : (Double, Double)) : Unit = { + + val screenFocus = (trans scaleToScreen pointOnGraph._1, trans scaleToScreen pointOnGraph._2) + + val viewShift = viewportOffset() + + renderOriginAt(0-screenFocus._1 + viewShift._1, 0-screenFocus._2 + viewShift._2) + + resizeViewToFit() + invalidateGraph(false) + revalidate() repaint() } @@ -142,6 +172,8 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel } def resizeViewToFit() { + + val bufferFocus = graphFocus // top left and bottom right of bounds, in screen coordinates val graphTopLeft = graph.vdata.foldLeft(0.0,0.0) { (c,v) => (min(c._1, v._2.coord._1), max(c._2, v._2.coord._2)) @@ -152,14 +184,14 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel } match {case (x,y) => trans toScreen (x + 1.0, y - 1.0)} // default bounds, based on the current position of the origin and the size of the visible region - val vRect = peer.getVisibleRect - val defaultTopLeft = (trans.origin._1 - (vRect.getWidth/2.0), trans.origin._2 - (vRect.getHeight/2.0)) - val defaultBottomRight = (trans.origin._1 + (vRect.getWidth/2.0), trans.origin._2 + (vRect.getHeight/2.0)) + val offset = viewportOffset() + val screenTopLeft = (trans.screenDrawOrigin._1 - offset._1, trans.screenDrawOrigin._2 - offset._2) + val screenBottomRight = (trans.screenDrawOrigin._1 + offset._1, trans.screenDrawOrigin._2 + offset._2) - val topLeft = (min(graphTopLeft._1, defaultTopLeft._1), - min(graphTopLeft._2, defaultTopLeft._2)) - val bottomRight = (max(graphBottomRight._1, defaultBottomRight._1), - max(graphBottomRight._2, defaultBottomRight._2)) + val topLeft = (min(graphTopLeft._1, screenTopLeft._1), + min(graphTopLeft._2, screenTopLeft._2)) + val bottomRight = (max(graphBottomRight._1, screenBottomRight._1), + max(graphBottomRight._2, screenBottomRight._2)) val (w,h) = (bottomRight._1 - topLeft._1, bottomRight._2 - topLeft._2) @@ -168,10 +200,12 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel (preferredSize.width != w) || (preferredSize.height != h) if (changed) { - trans.origin = (trans.origin._1 - topLeft._1, trans.origin._2 - topLeft._2) + trans.screenDrawOrigin = (trans.screenDrawOrigin._1 - topLeft._1, trans.screenDrawOrigin._2 - topLeft._2) + //graphFocus = bufferFocus preferredSize = new Dimension(w.toInt, h.toInt) invalidateGraph(clearSelection = false) revalidate() + repaint() } } @@ -190,6 +224,48 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel super.repaint() } + def focusOnGraph(): Unit = { + // top left and bottom right of bounds, in screen coordinates + val graphTopLeft = graph.vdata.foldLeft(0.0, 0.0) { (c, v) => + (min(c._1, v._2.coord._1), max(c._2, v._2.coord._2)) + } + + + val graphBottomRight = graph.vdata.foldLeft((0.0, 0.0)) { (c, v) => + (max(c._1, v._2.coord._1), min(c._2, v._2.coord._2)) + } + + + val defaultTopLeft = (-5,5) + val defaultBottomRight = (5,-5) + + val topLeft = (min(graphTopLeft._1, defaultTopLeft._1), + max(graphTopLeft._2, defaultTopLeft._2)) + val bottomRight = (max(graphBottomRight._1, defaultBottomRight._1), + min(graphBottomRight._2, defaultBottomRight._2)) + + val (w, h) = (bottomRight._1 - topLeft._1, + topLeft._2 - bottomRight._2) + + val graphCentre = ((topLeft._1 + bottomRight._1) / 2, (topLeft._2 + bottomRight._2) / 2) + + // default bounds, based on the current position of the origin and the size of the visible region + val vRect = peer.getVisibleRect + + val widthOnScreen = trans scaleToScreen w + val heightOnScreen = trans scaleToScreen h + + val widthChange = widthOnScreen / vRect.width + val heightChange = heightOnScreen / vRect.height + val zoomChangeNeeded = max(widthChange, heightChange) + + zoom = 0.9*zoom / zoomChangeNeeded + + //trans.scale = trans.scale * zoomChangeNeeded + + graphFocus = graphCentre + } + override def paintComponent(g: Graphics2D) { super.paintComponent(g) g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON) @@ -265,16 +341,10 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel for ((e, ed) <- edgeDisplay) { if (selectedEdges contains e) { g.setColor(Color.BLUE) - g.setStroke(new BasicStroke(2)) + g.setStroke(new BasicStroke(ed.width)) } else { - if (graph.edata(e).isDirected) { - g.setColor(Color.GRAY) - g.setStroke(new BasicStroke(1)) - } else { - g.setColor(Color.BLACK) - g.setStroke(new BasicStroke(2)) - } - + g.setColor(ed.color) + g.setStroke(new BasicStroke(ed.width)) } g.draw(ed.path) @@ -294,7 +364,7 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel g.setStroke(new BasicStroke(1)) var a = g.getColor - for ((v, VDisplay(shape,color,label)) <- vertexDisplay) { + for ((v, VDisplay(shape, strokeWidth, color,label)) <- vertexDisplay) { /* draw red line if vertex coordinates are within !-box rectangle * but the vertex is not a member of the !-box and also write @@ -343,10 +413,10 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel if (selectedVerts contains v) { g.setColor(Color.BLUE) - g.setStroke(new BasicStroke(2)) + g.setStroke(new BasicStroke(strokeWidth + 1)) } else { g.setColor(Color.BLACK) - g.setStroke(new BasicStroke(1)) + g.setStroke(new BasicStroke(strokeWidth)) } g.draw(shape) @@ -501,9 +571,12 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel } } + val names = scala.collection.mutable.Map[String,Int]() + /* Output view to a tikzit-readable file */ printToFile(f, append)(p => { - p.println("\\begin{tikzpicture}[baseline={([yshift=-.5ex]current bounding box.center)}]") + //p.println("\\begin{tikzpicture}[baseline={([yshift=-.5ex]current bounding box.center)}]") + p.println("\\begin{tikzpicture}[quanto]") p.println("\t\\begin{pgfonlayer}{nodelayer}") /* fill in all vertices */ @@ -513,7 +586,9 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel case _ : WireV => "wire" } - val number = vn.toString + names += vn.toString -> names.size + + val number = names.size - 1 val disp_rec = vertexDisplay(vn).shape.getBounds val trans_coord = trans.fromScreen(disp_rec.getCenterX, disp_rec.getCenterY) min_max(trans_coord) @@ -530,25 +605,29 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel /* fill in corners of !-boxes */ for ((bbn,bbd) <- bboxDisplay) { - val number_ul = bbn.toString + "ul" + names += (bbn.toString + "ul") -> names.size + val number_ul = names.size - 1 val trans_coord_ul = trans.fromScreen(bbd.rect.getMinX, bbd.rect.getMinY) min_max(trans_coord_ul) val coord_ul = coordToString(trans_coord_ul) p.println("\t\t\\node [style=bbox] (" + number_ul + ") at " + coord_ul + " {};") - val number_ur = bbn.toString + "ur" + names += (bbn.toString + "ur") -> names.size + val number_ur = names.size - 1 val trans_coord_ur = trans.fromScreen(bbd.rect.getMaxX, bbd.rect.getMinY) min_max(trans_coord_ur) val coord_ur = coordToString(trans_coord_ur) p.println("\t\t\\node [style=none] (" + number_ur + ") at " + coord_ur + " {};") - val number_ll = bbn.toString + "ll" + names += (bbn.toString + "ll") -> names.size + val number_ll = names.size - 1 val trans_coord_ll = trans.fromScreen(bbd.rect.getMinX, bbd.rect.getMaxY) min_max(trans_coord_ll) val coord_ll = coordToString(trans_coord_ll) p.println("\t\t\\node [style=none] (" + number_ll + ") at " + coord_ll + " {};") - val number_lr = bbn.toString + "lr" + names += (bbn.toString + "lr") -> names.size + val number_lr = names.size - 1 val trans_coord_lr = trans.fromScreen(bbd.rect.getMaxX, bbd.rect.getMaxY) min_max(trans_coord_lr) val coord_lr = coordToString(trans_coord_lr) @@ -556,16 +635,16 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel } /* output 4 nodes used for padding*/ - val pad_size = 1.0 - minX -= pad_size - maxX += pad_size - minY -= pad_size - maxY += pad_size - - p.println("\t\t\\node [style=none] (padl) at " + coordToString(minX, minY) + " {};") - p.println("\t\t\\node [style=none] (padr) at " + coordToString(maxX, maxY) + " {};") - p.println("\t\t\\node [style=none] (padu) at " + coordToString(minX, maxY) + " {};") - p.println("\t\t\\node [style=none] (padd) at " + coordToString(maxX, minY) + " {};") +// val pad_size = 1.0 +// minX -= pad_size +// maxX += pad_size +// minY -= pad_size +// maxY += pad_size +// +// p.println("\t\t\\node [style=none] (padl) at " + coordToString(minX, minY) + " {};") +// p.println("\t\t\\node [style=none] (padr) at " + coordToString(maxX, maxY) + " {};") +// p.println("\t\t\\node [style=none] (padu) at " + coordToString(minX, maxY) + " {};") +// p.println("\t\t\\node [style=none] (padd) at " + coordToString(maxX, minY) + " {};") p.println("\t\\end{pgfonlayer}") @@ -574,9 +653,9 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel /* fill in all graph edges */ for (edge_set <- graph.edgePartition) { val edge_arr : Array[EName] = edge_set.toArray - val size = edge_arr.size - val canonical_source = graph.source(edge_arr(0)).toString - val canonical_target = graph.target(edge_arr(0)).toString + val size = edge_arr.length + val canonical_source = names(graph.source(edge_arr(0)).toString) + val canonical_target = names(graph.target(edge_arr(0)).toString) if (canonical_source != canonical_target) { /* edges are between different nodes */ @@ -588,7 +667,9 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel val en = edge_arr(0) val ed = graph.edata(en) val style = if (ed.isDirected) "directed" else "simple" - p.println("\t\t\\draw [style=" + style + "] (" + graph.source(en).toString + ") to (" + graph.target(en).toString + ");" ) + p.println("\t\t\\draw [style=" + style + "] (" + + names(graph.source(en).toString) + ") to (" + + names(graph.target(en).toString) + ");" ) start = 1 } @@ -597,13 +678,13 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel var right_left = "left=" /* draw the rest of the edges as arcs by setting the bend angle */ - for (i <- start to size-1) { + for (i <- start until size) { val en = edge_arr(i) val ed = graph.edata(en) val style = if (ed.isDirected) "directed" else "simple" val angle = angle_it.toString - val source = graph.source(en).toString - val target = graph.target(en).toString + val source = names(graph.source(en).toString) + val target = names(graph.target(en).toString) /* alternate bending left or right */ if ((i-start) % 2 == 0) { @@ -625,7 +706,7 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel /* edges have same source and target, so we need to output loops */ var looseness = 4.5 - for (i <- 0 to size-1) { + for (i <- 0 until size) { val en = edge_arr(i) val ed = graph.edata(en) val style = if (ed.isDirected) "directed" else "simple" @@ -643,10 +724,10 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel for ((bbn, _) <- bboxDisplay) { /* fill in edges connecting !-box corners */ - val number_ul = bbn.toString + "ul.center" - val number_ur = bbn.toString + "ur.center" - val number_ll = bbn.toString + "ll.center" - val number_lr = bbn.toString + "lr.center" + val number_ul = names(bbn.toString + "ul") + ".center" + val number_ur = names(bbn.toString + "ur") + ".center" + val number_ll = names(bbn.toString + "ll") + ".center" + val number_lr = names(bbn.toString + "lr") + ".center" p.println("\t\t\\draw [style=blue] (" + number_ul + ") to (" + number_ur + ");" ) p.println("\t\t\\draw [style=blue] (" + number_ul + ") to (" + number_ll + ");" ) p.println("\t\t\\draw [style=blue] (" + number_ll + ") to (" + number_lr + ");" ) @@ -654,10 +735,9 @@ class GraphView(val theory: Theory, gRef: HasGraph) extends Panel /* draw edges indicating nested !-boxes */ graph.bboxParent.get(bbn) match { - case Some(bb_parent) => { - val parent_number_ul = bb_parent.toString + "ul.center" + case Some(bb_parent) => + val parent_number_ul = names(bb_parent.toString + "ul") + ".center" p.println("\t\t\\draw [style=blue] (" + number_ul + ") to (" + parent_number_ul + ");" ) - } case None => } } diff --git a/scala/src/main/scala/quanto/gui/graphview/VertexDisplayData.scala b/scala/src/main/scala/quanto/gui/graphview/VertexDisplayData.scala index 8afa91a6..e310794b 100644 --- a/scala/src/main/scala/quanto/gui/graphview/VertexDisplayData.scala +++ b/scala/src/main/scala/quanto/gui/graphview/VertexDisplayData.scala @@ -1,13 +1,15 @@ package quanto.gui.graphview -import java.awt.geom.{Rectangle2D, Ellipse2D, Point2D} -import java.awt.{FontMetrics, Color, Shape} +import java.awt.geom.{Ellipse2D, Point2D, Rectangle2D} +import java.awt.{Color, FontMetrics, Shape} + import math._ import quanto.data._ import quanto.gui._ import quanto.core.data.TexConstants +import quanto.util.UserOptions -case class VDisplay(shape: Shape, color: Color, label: Option[LabelDisplayData]) { +case class VDisplay(shape: Shape, borderWidth: Int, color: Color, label: Option[LabelDisplayData]) { def pointHit(pt: Point2D) = { val bnd = shape.getBounds2D pt.getX >= bnd.getMinX - GraphView.VertexSelectionTolerence && @@ -52,13 +54,14 @@ trait VertexDisplayData { self: GraphView => protected def computeVertexDisplay() { val trWireWidth = 0.707 * (trans scaleToScreen GraphView.WireRadius) + // Go through every vertex in the graph for ((v,data) <- graph.vdata if !vertexDisplay.contains(v)) { val (x,y) = trans toScreen data.coord vertexDisplay(v) = data match { case vertexData : NodeV => val style = vertexData.typeInfo.style - val label = if (vertexData.hasAngle) vertexData.angle.toString else vertexData.value + val label = if (vertexData.hasValue) vertexData.phaseData.toString else vertexData.value val text = if(zoom < GraphView.zoomCutOut && label != "") "~" else TexConstants.translate(label) /*vertexData.typeInfo.value.typ match { @@ -66,19 +69,26 @@ trait VertexDisplayData { self: GraphView => case _ => "" }*/ + // Build the label val fm = peer.getGraphics.getFontMetrics(GraphView.VertexLabelFont) val labelDisplay = LabelDisplayData( - text, (x,y), fm, + text, (x, y), fm, vertexData.typeInfo.style.labelForegroundColor, vertexData.typeInfo.style.labelBackgroundColor) - + // Build the shape around the label val shape = style.shape match { case Theory.VertexShape.Rectangle => + val buffer = trans.scaleToScreen(GraphView.NodeTextPadding) + val height = labelDisplay.bounds.getHeight + buffer + val widthFromLabel = labelDisplay.bounds.getWidth + buffer + // Default to square if no data, and stretch horizontally if needed + val width = max(widthFromLabel, height) + + val x = labelDisplay.bounds.getMinX - (width - labelDisplay.bounds.getWidth) / 2.0 + val y = labelDisplay.bounds.getMinY - (height - labelDisplay.bounds.getHeight) / 2.0 - new Rectangle2D.Double( - labelDisplay.bounds.getMinX - 5.0, labelDisplay.bounds.getMinY - 3.0, - labelDisplay.bounds.getWidth + 10.0, labelDisplay.bounds.getHeight + 6.0) + new Rectangle2D.Double(x, y, width, height) case Theory.VertexShape.Circle => // radius should fit to label if required val r = max( @@ -86,30 +96,38 @@ trait VertexDisplayData { self: GraphView => trans.scaleToScreen(GraphView.NodeRadius) ) - new Ellipse2D.Double( - labelDisplay.bounds.getCenterX - r, - labelDisplay.bounds.getCenterY -r, - 2.0 * r, 2.0 * r) + val midX = labelDisplay.bounds.getCenterX - r + val midY = labelDisplay.bounds.getCenterY - r + new Ellipse2D.Double(midX, midY, 2.0 * r, 2.0 * r) case _ => throw new Exception("Shape not supported yet") } - VDisplay(shape, style.fillColor, Some(labelDisplay)) + VDisplay(shape, style.strokeWidth, style.fillColor, Some(labelDisplay)) case _: WireV => VDisplay( new Rectangle2D.Double( x - trWireWidth, y - trWireWidth, 2.0 * trWireWidth, 2.0 * trWireWidth), - Color.GRAY,None) + 1, + Color.GRAY, None) } } } - protected def boundsForVertexSet(vset: Set[VName]) = { + protected def boundsForVertexSet(vset: Set[VName]): Rectangle2D.Double = { var init = false var ulx,uly,lrx,lry = 0.0 + val em = trans.scaleToScreen(0.25) + vset.foreach { v => - val rect = vertexDisplay(v).shape.getBounds + // grow the bounding box until it snaps to the grid + val bds = vertexDisplay(v).shape.getBounds + val rx = Math.ceil(bds.width / (2.0 * em)) * em + val ry = Math.ceil(bds.height / (2.0 * em)) * em + val p = trans toScreen graph.vdata(v).coord + val rect = new Rectangle2D.Double(p._1 - rx, p._2 - ry, 2.0 * rx, 2.0 * ry) + if (init) { ulx = min(ulx, rect.getX) uly = min(uly, rect.getY) @@ -125,9 +143,8 @@ trait VertexDisplayData { self: GraphView => } val bounds = new Rectangle2D.Double(ulx, uly, lrx - ulx, lry - uly) - val em = trans.scaleToScreen(0.1) - val p = (bounds.getX - 3*em, bounds.getY - 3*em) - val q = (bounds.getWidth + 6*em, bounds.getHeight + 6*em) + val p = (bounds.getX - em, bounds.getY - em) + val q = (bounds.getWidth + 2*em, bounds.getHeight + 2*em) new Rectangle2D.Double(p._1, p._2, q._1, q._2) } diff --git a/scala/src/main/scala/quanto/gui/histview/HistView.scala b/scala/src/main/scala/quanto/gui/histview/HistView.scala index d92b74db..aa7686c7 100644 --- a/scala/src/main/scala/quanto/gui/histview/HistView.scala +++ b/scala/src/main/scala/quanto/gui/histview/HistView.scala @@ -31,11 +31,10 @@ class HistView[A <: HistNode](data: TreeSeq[A]) extends ListView[(Seq[TreeSeq.De } renderer = new ListView.Renderer[(Seq[TreeSeq.Decoration[A]], A)] { - def componentFor(list: ListView[_], isSelected: Boolean, - focused: Boolean, a: (Seq[TreeSeq.Decoration[A]], A), index: Int): Component = + override def componentFor(list: ListView[_ <: (Seq[TreeSeq.Decoration[A]], A)], isSelected: Boolean, focused: Boolean, a: (Seq[TreeSeq.Decoration[A]], A), index: Int): Component = { - if (itemWidth == -1) computeItemWidth() - new HistViewItem[A](a._1, a._2, isSelected, new Dimension(itemWidth,scaleInt(30))) + if (itemWidth == -1) computeItemWidth() + new HistViewItem[A](a._1, a._2, isSelected, new Dimension(itemWidth,scaleInt(30))) } } @@ -48,6 +47,10 @@ class HistView[A <: HistNode](data: TreeSeq[A]) extends ListView[(Seq[TreeSeq.De selectIndices() } + def selectedIndex() : Int = { + peer.getSelectedIndex + } + def selectedNode: Option[A] = if (selection.indices.isEmpty) None else Some(treeData.toSeq(selection.indices.head)) diff --git a/scala/src/main/scala/quanto/layout/Constraints.scala b/scala/src/main/scala/quanto/layout/Constraints.scala index b7a1ad37..98a71552 100644 --- a/scala/src/main/scala/quanto/layout/Constraints.scala +++ b/scala/src/main/scala/quanto/layout/Constraints.scala @@ -1,58 +1,25 @@ package quanto.layout -import quanto.data._ -import quanto.data.Graph -import quanto.data.VName +import quanto.data.{Graph, VName} class ConstraintException(msg: String) extends Exception(msg) /** - * A mixin for GraphLayouts which provides distance-based constraint functionality as in - * [1] "Scalable, Versatile and Simple Constrained Graph Layout", Dwyer 2009 - */ + * A mixin for GraphLayouts which provides distance-based constraint functionality as in + * [1] "Scalable, Versatile and Simple Constrained Graph Layout", Dwyer 2009 + */ trait Constraints extends GraphLayout { - def alpha: Double // cooling factor for soft constraints + val constraints = new ConstraintSeq var constraintIterations = 10 var bug = false - val constraints = new ConstraintSeq + + def alpha: Double // cooling factor for soft constraints override def initialize(g: Graph, randomCoords: Boolean = true) { super.initialize(g, randomCoords) constraints.clear() } - def isConstraintSatisfied (c: Constraint): Boolean = { - val (v1x,v1y) = coord(c.v1) - val (v2x,v2y) = coord(c.v2) - val (v1v2x,v1v2y) = (v2x-v1x, v2y-v1y) - - val dir = c.direction match { - case Some(d) => d - case None => (0.0,0.0) - } - - val length = projectLength((v1v2x,v1v2y),dir) - if (c.order == 0) projectLength((v1v2x,v1v2y),dir) == c.length - else if(c.order > 0) projectLength((v1v2x,v1v2y),dir) >= c.length - else projectLength((v1v2x,v1v2y),dir) <= c.length - - } - // project v1 on v2 // direction is unit vector - def projectVector (v1 :(Double,Double), v2 : (Double,Double)) : (Double,Double)= { - val length = vectorProduct(v1,v2) - (length * v2._1, length * v2._2) - } - - def projectLength (v1 :(Double,Double), v2 : (Double,Double)) : Double = { - val length = vectorProduct(v1,v2) - length - } - - def vectorProduct (v1 : (Double,Double), v2 : (Double, Double)) : Double = { - v1._1*v2._1+v1._2*v2._2 - } - - def projectConstraints() { var feasible = false // flag for all constraints satisfied var maxLayer = constraints.currentLayer @@ -65,78 +32,80 @@ trait Constraints extends GraphLayout { iteration = 0 } - for ((constraint,layer) <- constraints; if layer <= maxLayer) { - - val (p1,p2) = (coord(constraint.v1), coord(constraint.v2)) - + for ((constraint, layer) <- constraints; if layer <= maxLayer) { + + val (p1, p2) = (coord(constraint.v1), coord(constraint.v2)) + val shift = constraint.direction match { case Some(dir) => - // val (dx,dy) = ((p2._1 - p1._1) * dir._1, (p2._2 - p1._2) * dir._2) + // val (dx,dy) = ((p2._1 - p1._1) * dir._1, (p2._2 - p1._2) * dir._2) // the coordinates of vector projected on direction - val p1p2 = (p2._1 - p1._1,p2._2 - p1._2) - val (dx,dy) = projectVector((p1p2._1, p1p2._2), dir) - + val p1p2 = (p2._1 - p1._1, p2._2 - p1._2) + val (dx, dy) = projectVector((p1p2._1, p1p2._2), dir) + // Add direction to the distance // if the angle is acute angle then we need to do nothing // if the angle is obtuse angle then we need to move on of the vertex. - if (vectorProduct(p1p2,dir)<0) { + if (vectorProduct(p1p2, dir) < 0) { // swap the two nodes -// val temp = coord(constraint.v1); -// setCoord(constraint.v1, coord(constraint.v2)) -// setCoord(constraint.v2,temp) - // or we can just move p2 to another side - val (nx,ny) = (2*p1._1 - p2._1, 2*p1._2 - p2._2); - setCoord(constraint.v2, (nx,ny)) -// + // val temp = coord(constraint.v1); + // setCoord(constraint.v1, coord(constraint.v2)) + // setCoord(constraint.v2,temp) + // or we can just move p2 to another side + val (nx, ny) = (2 * p1._1 - p2._1, 2 * p1._2 - p2._2) + setCoord(constraint.v2, (nx, ny)) + // } -// // if it is 0 then they are perpendicular + // // if it is 0 then they are perpendicular // this constraint cannot be projected. - if ((vectorProduct(p1p2,dir)==0.0)&& !isConstraintSatisfied(constraint)){ - if (bug) println("constraint " + constraint) - if (bug) println("vector p1 p2 is " + p1p2) - if (bug) println("vector direction " + dir) - - val (v2x,v2y) = coord(constraint.v2) - setCoord(constraint.v2, (v2x+dir._1/50,v2y+dir._2/50)) + if ((vectorProduct(p1p2, dir) == 0.0) && !isConstraintSatisfied(constraint)) { + if (bug) println("constraint " + constraint) + if (bug) println("vector p1 p2 is " + p1p2) + if (bug) println("vector direction " + dir) + + val (v2x, v2y) = coord(constraint.v2) + setCoord(constraint.v2, (v2x + dir._1 / 50, v2y + dir._2 / 50)) } - + val ideal = (dir._1 * constraint.length, dir._2 * constraint.length) ( - if ((constraint.order == 0 && dx != ideal._1) || - (constraint.order == -1 && dx > ideal._1) || - (constraint.order == 1 && dx < ideal._1)) + if ((constraint.order == 0 && dx != ideal._1) || + (constraint.order == -1 && dx > ideal._1) || + (constraint.order == 1 && dx < ideal._1)) //if (!isConstraintSatisfied(constraint)) - + dx - ideal._1 - //ideal._1 - dx + //ideal._1 - dx else 0, - if ((constraint.order == 0 && dy != ideal._2) || - (constraint.order == -1 && dy > ideal._2) || - (constraint.order == 1 && dy < ideal._2)) + if ((constraint.order == 0 && dy != ideal._2) || + (constraint.order == -1 && dy > ideal._2) || + (constraint.order == 1 && dy < ideal._2)) dy - ideal._2 - //ideal._2 - dy + //ideal._2 - dy else 0 ) - + case None => - val (dx,dy) = (p2._1 - p1._1, p2._2 - p1._2) - val length = math.sqrt(dx*dx + dy*dy) - val dir = if (length != 0) (dx/length, dy/length) else (1.0,0.0) + val (dx, dy) = (p2._1 - p1._1, p2._2 - p1._2) + val length = math.sqrt(dx * dx + dy * dy) + val dir = if (length != 0) (dx / length, dy / length) else (1.0, 0.0) - if ((constraint.order == 0 && length != constraint.length) || - (constraint.order == -1 && length > constraint.length) || - (constraint.order == 1 && length < constraint.length)) + if ((constraint.order == 0 && length != constraint.length) || + (constraint.order == -1 && length > constraint.length) || + (constraint.order == 1 && length < constraint.length)) (dir._1 * (length - constraint.length), dir._2 * (length - constraint.length)) - else (0.0,0.0) + else (0.0, 0.0) } - val (shiftX, shiftY) = shift match { case (x,y) => if (constraint.soft) (x * alpha, y * alpha) else (x,y) } + val (shiftX, shiftY) = shift match { + case (x, y) => if (constraint.soft) (x * alpha, y * alpha) else (x, y) + } if (shiftX != 0.0 || shiftY != 0.0) { feasible = false - + setCoord(constraint.v1, (p1._1 + (constraint.mv1 * shiftX), p1._2 + (constraint.mv1 * shiftY))) setCoord(constraint.v2, (p2._1 - (constraint.mv2 * shiftX), p2._2 - (constraint.mv2 * shiftY))) } @@ -145,25 +114,62 @@ trait Constraints extends GraphLayout { iteration += 1 } -// if (feasible) { -// println("feasible solution found after " + iteration + " iterations") -// } else { -// println("no feasible solution") -// } + // if (feasible) { + // println("feasible solution found after " + iteration + " iterations") + // } else { + // println("no feasible solution") + // } + } + + def isConstraintSatisfied(c: Constraint): Boolean = { + val (v1x, v1y) = coord(c.v1) + val (v2x, v2y) = coord(c.v2) + val (v1v2x, v1v2y) = (v2x - v1x, v2y - v1y) + + val dir = c.direction match { + case Some(d) => d + case None => (0.0, 0.0) + } + + if (c.order == 0) projectLength((v1v2x, v1v2y), dir) == c.length + else if (c.order > 0) projectLength((v1v2x, v1v2y), dir) >= c.length + else projectLength((v1v2x, v1v2y), dir) <= c.length + + } + + def projectLength(v1: (Double, Double), v2: (Double, Double)): Double = { + val length = vectorProduct(v1, v2) + length + } + + // project v1 on v2 // direction is unit vector + def projectVector(v1: (Double, Double), v2: (Double, Double)): (Double, Double) = { + val length = vectorProduct(v1, v2) + (length * v2._1, length * v2._2) + } + + def vectorProduct(v1: (Double, Double), v2: (Double, Double)): Double = { + v1._1 * v2._1 + v1._2 * v2._2 } } -class ConstraintSeq extends Iterable[(Constraint,Int)] { +class ConstraintSeq extends Iterable[(Constraint, Int)] { + private val cs = collection.mutable.ListBuffer[() => Iterator[(Constraint, Int)]]() private var _currentLayer = 0 - def currentLayer = _currentLayer - private val cs = collection.mutable.ListBuffer[() => Iterator[(Constraint,Int)]]() - def nextLayer() { _currentLayer += 1 } - def clear() { cs.clear(); _currentLayer = 0 } + def currentLayer: Int = _currentLayer + + def nextLayer() { + _currentLayer += 1 + } + + def clear() { + cs.clear(); _currentLayer = 0 + } def +=(c: Constraint) { val layer = _currentLayer - cs += (() => Iterator((c,layer))) + cs += (() => Iterator((c, layer))) } def ++=(cf: => Iterable[Constraint]) { @@ -171,45 +177,51 @@ class ConstraintSeq extends Iterable[(Constraint,Int)] { cs += (() => cf.iterator.zip(Iterator.continually(layer))) } - def iterator = cs.iterator.map(x => x()).foldLeft(Iterator[(Constraint,Int)]())(_ ++ _) + def iterator: Iterator[(Constraint, Int)] = cs.iterator.map(x => x()).foldLeft(Iterator[(Constraint, Int)]())(_ ++ _) } - // order: which relation with d > < or = -case class Constraint(v1: VName, v2: VName, direction: Option[(Double,Double)], length: Double, w1: Double, w2: Double, order: Int, soft: Boolean) { - lazy val mv1 = if (w1 + w2 != 0.0) w2 / (w1 + w2) else 0.5 - lazy val mv2 = 1.0 - mv1 - + +// order: which relation with d > < or = +case class Constraint(v1: VName, v2: VName, direction: Option[(Double, Double)], length: Double, w1: Double, w2: Double, order: Int, soft: Boolean) { + lazy val mv1: Double = if (w1 + w2 != 0.0) w2 / (w1 + w2) else 0.5 + lazy val mv2: Double = 1.0 - mv1 + } object Constraint { - object distance { - def from(v1: VName) = new DistanceFromExpr(v1) - } - + class DistanceFromExpr(v1: VName) { - def to(v2: VName) = DistanceExpr(v1,v2) + def to(v2: VName) = DistanceExpr(v1, v2) } - - case class DistanceExpr(v1: VName, v2: VName, direction: Option[(Double,Double)] = None, w1: Double = 1.0, w2: Double = 1.0) { - def along(dir: (Double,Double)) = copy(direction = Some(dir)) - - def weighted(w: (Double,Double)) = copy(w1 = w._1, w2 = w._2) + + case class DistanceExpr(v1: VName, v2: VName, direction: Option[(Double, Double)] = None, w1: Double = 1.0, w2: Double = 1.0) { + def along(dir: (Double, Double)): DistanceExpr = copy(direction = Some(dir)) + + def weighted(w: (Double, Double)): DistanceExpr = copy(w1 = w._1, w2 = w._2) + + def <=(len: Double) = Constraint(v1, v2, normalizedDir, len, w1, w2, -1, soft = false) private def normalizedDir = direction.map { - case (0.0,0.0) => throw new ConstraintException("'along' direction must be a non-zero vector") - case (x,y) => - val length = math.sqrt(x*x + y*y) + case (0.0, 0.0) => throw new ConstraintException("'along' direction must be a non-zero vector") + case (x, y) => + val length = math.sqrt(x * x + y * y) // return the normal (x / length, y / length) } - - - def <= (len: Double) = Constraint(v1,v2,normalizedDir,len,w1,w2,-1,soft=false) - def ===(len: Double) = Constraint(v1,v2,normalizedDir,len,w1,w2,0,soft=false) - def >= (len: Double) = Constraint(v1,v2,normalizedDir,len,w1,w2,1,soft=false) - - def ~<= (len: Double) = Constraint(v1,v2,normalizedDir,len,w1,w2,-1,soft=true) - def ~== (len: Double) = Constraint(v1,v2,normalizedDir,len,w1,w2,0,soft=true) - def ~>= (len: Double) = Constraint(v1,v2,normalizedDir,len,w1,w2,1,soft=true) + + def ===(len: Double) = Constraint(v1, v2, normalizedDir, len, w1, w2, 0, soft = false) + + def >=(len: Double) = Constraint(v1, v2, normalizedDir, len, w1, w2, 1, soft = false) + + def ~<=(len: Double) = Constraint(v1, v2, normalizedDir, len, w1, w2, -1, soft = true) + + def ~==(len: Double) = Constraint(v1, v2, normalizedDir, len, w1, w2, 0, soft = true) + + def ~>=(len: Double) = Constraint(v1, v2, normalizedDir, len, w1, w2, 1, soft = true) } + + object distance { + def from(v1: VName) = new DistanceFromExpr(v1) + } + } \ No newline at end of file diff --git a/scala/src/main/scala/quanto/layout/DeriveLayout.scala b/scala/src/main/scala/quanto/layout/DeriveLayout.scala index cfb393cc..863f1b52 100644 --- a/scala/src/main/scala/quanto/layout/DeriveLayout.scala +++ b/scala/src/main/scala/quanto/layout/DeriveLayout.scala @@ -5,21 +5,21 @@ import quanto.layout.constraint._ class DeriveLayout { - def layout(derivation: Derivation) = { + def layout(derivation: Derivation): Derivation = { var steps = Map[DSName, DStep]() val layoutProc = new ForceLayout with Clusters layoutProc.alpha0 = 0.05 layoutProc.keepCentered = false - while (steps.size < derivation.steps.size) derivation.steps.foreach { case (sname, step) => + while (steps.size < derivation.steps.size) derivation.steps.foreach { case (stepName, step) => // try to pull a parent that has already been processed, or root if step has no parent - val parentOpt = derivation.parentMap.get(sname) match { + val parentOpt = derivation.parentMap.get(stepName) match { case Some(p) => steps.get(p).map(_.graph) case None => Some(derivation.root) } - parentOpt.map { parent: Graph => + parentOpt.foreach { parent: Graph => var g = step.graph layoutProc.clearLockedVertices() @@ -27,7 +27,9 @@ class DeriveLayout { // are fresh for parent) for (v <- parent.verts) { if (g.verts.contains(v)) { - g = g.updateVData(v) { _.withCoord(parent.vdata(v).coord) } + g = g.updateVData(v) { + _.withCoord(parent.vdata(v).coord) + } layoutProc.lockVertex(v) } } @@ -35,11 +37,11 @@ class DeriveLayout { layoutProc.initialize(g, randomCoords = false) // relax a bit to layout new coords - for (i <- 1 to 10) layoutProc.step() + for (_ <- 1 to 10) layoutProc.step() layoutProc.updateGraph() // layout the graph, and add the step - steps += sname -> step.copy(graph = layoutProc.graph) + steps += stepName -> step.copy(graph = layoutProc.graph) } } diff --git a/scala/src/main/scala/quanto/layout/DotLayout.scala b/scala/src/main/scala/quanto/layout/DotLayout.scala index bf55319c..7058b96e 100644 --- a/scala/src/main/scala/quanto/layout/DotLayout.scala +++ b/scala/src/main/scala/quanto/layout/DotLayout.scala @@ -1,31 +1,55 @@ package quanto.layout -import sys.process._ -import quanto.data._ import java.io._ -import quanto.data.VName + +import quanto.data.{VName, _} + +import scala.sys.process._ class DotLayout extends GraphLayout { var dotString = "" - var dotProcess: Process = null + var dotProcess: Process = _ var dotIn: BufferedReader = _ var dotOut: BufferedWriter = _ + protected def compute() { + val (vid, dotStr) = generateDot(graph) + dotString = dotStr + var xMax = 0.0 + var yMax = 0.0 + + val Graph = """graph \d+ (\d+(\.\d+)?) (\d+(\.\d+)?)""".r + val Node = """node (\d+) (\d+(\.\d+)?) (\d+(\.\d+)?) .*""".r + + var coordMap = Map[Int, (Double, Double)]() + + ("dot -Tplain" #< new ByteArrayInputStream(dotString.getBytes("UTF-8"))).lines_!.foreach { + case Graph(xbd, _, ybd, _) => + xMax = xbd.toDouble + yMax = ybd.toDouble + case Node(id, x, _, y, _) => + coordMap += id.toInt -> ((x.toDouble - 0.5 * xMax) / 10, (0.5 * yMax - y.toDouble) / 10) + case _ => () + } + + vid.foreach { case (v, id) => setCoord(v, coordMap(id)) } + } + private def generateDot(graph: Graph) = { val sb = new StringBuilder - var vid = Map[VName,Int]() + var vid = Map[VName, Int]() var i = 0 sb ++= "digraph {\n" - graph.vdata.foreach { case (v,d) => + graph.vdata.foreach { case (v, _) => sb ++= " " + i + " [width=14,height=14]\n" vid += v -> i i += 1 } - graph.edata.foreach { case (e,d) => + graph.edata.foreach { case (e, d) => sb ++= " %d %s %d\n".format( vid(graph.source(e)), if (d.isDirected) "->" else "--", @@ -34,7 +58,7 @@ class DotLayout extends GraphLayout { i = 0 - graph.bbdata.foreach { case(bb,d) => + graph.bbdata.foreach { case (bb, _) => sb ++= " subgraph \"cluster_" + i + "\" { \n" graph.contents(bb).foreach { v => sb ++= " " + vid(v) + "\n" } sb ++= " }\n" @@ -43,31 +67,6 @@ class DotLayout extends GraphLayout { sb ++= "}\n" - (vid,sb.toString) - } - - protected def compute() { - val (vid,dotStr) = generateDot(graph) - dotString = dotStr - var xMax = 0.0 - var yMax = 0.0 - - val Graph = """graph \d+ (\d+(\.\d+)?) (\d+(\.\d+)?)""".r - val Node = """node (\d+) (\d+(\.\d+)?) (\d+(\.\d+)?) .*""".r - - var coordMap = Map[Int,(Double,Double)]() - - ("dot -Tplain" #< new ByteArrayInputStream(dotString.getBytes("UTF-8"))).lines_!.foreach { line => - line match { - case Graph(xbd,_, ybd,_) => - xMax = xbd.toDouble - yMax = ybd.toDouble - case Node(id, x,_, y,_) => - coordMap += id.toInt -> ((x.toDouble - 0.5 * xMax) / 10, (0.5 * yMax - y.toDouble) / 10) - case _ => () - } - } - - vid.foreach { case (v,id) => setCoord(v, coordMap(id)) } + (vid, sb.toString) } } diff --git a/scala/src/main/scala/quanto/layout/ForceLayout.scala b/scala/src/main/scala/quanto/layout/ForceLayout.scala index 113b250b..ea779474 100644 --- a/scala/src/main/scala/quanto/layout/ForceLayout.scala +++ b/scala/src/main/scala/quanto/layout/ForceLayout.scala @@ -1,50 +1,39 @@ package quanto.layout -import quanto.util._ import quanto.data._ -import math.{min,max,abs} import quanto.layout.constraint._ +import quanto.util._ + +import scala.math.abs /** - * Force-directed layout algorithm. Parts are based on: - * [1] force.js from the D3 javascript library (see d3js.org) - * [2] "Scalable, Versatile and Simple Constrained Graph Layout", Dwyer 2009 - * [3] "Efficient and High Quality Force-Directed Graph Drawing", Hu 2006 - */ + * Force-directed layout algorithm. Parts are based on: + * [1] force.js from the D3 javascript library (see d3js.org) + * [2] "Scalable, Versatile and Simple Constrained Graph Layout", Dwyer 2009 + * [3] "Efficient and High Quality Force-Directed Graph Drawing", Hu 2006 + */ class ForceLayout extends GraphLayout with Constraints { // repulsive force between vertices //var charge: VName => Double = (v => if (graph.vdata(v).isWireVertex) 3.0 else 5.0) var nodeCharge = 5.0 - - def charge(v:VName) = if (graph.vdata(v).isWireVertex && nodeCharge != 0.0) 1.0 else nodeCharge - // spring strength on edges var strength = 2.5 - // preferred length of edge var edgeLength = 0.5 - // (small) attractive force toward center of bounds var gravity = 1.0 - // Barnes-Hut approximation constant. Higher = coarser var theta = 0.8 - // used in Verlet integration var friction = 0.9 - // initial step size var alpha0: Double = 1.0 - // increase or decrease step size by this amount var alphaAdjust = 0.7 - // maximum iterations var maxIterations = 3000 - // re-center graph after each iteration var keepCentered = true - // step size alpha is re-computed on the fly using trust region heuristic var alpha: Double = _ var prevEnergy: Double = _ @@ -52,6 +41,8 @@ class ForceLayout extends GraphLayout with Constraints { var progress: Int = _ var iteration = 0 + def charge(v: VName): Double = if (graph.vdata(v).isWireVertex && nodeCharge != 0.0) 1.0 else nodeCharge + override def initialize(g: Graph, randomCoords: Boolean = true) { super.initialize(g, randomCoords) alpha = alpha0 @@ -60,24 +51,24 @@ class ForceLayout extends GraphLayout with Constraints { } // compute the equivalent point charge for every region of space in the quad tree - def computeCharges(tr: QuadTree[(Option[VName],Double)]): QuadTree[(Option[VName],Double)] = tr match { + def computeCharges(tr: QuadTree[(Option[VName], Double)]): QuadTree[(Option[VName], Double)] = tr match { case leaf: QuadLeaf[_] => leaf case _: QuadNode[_] => - val node = tr.asInstanceOf[QuadNode[(Option[VName],Double)]] + val node = tr.asInstanceOf[QuadNode[(Option[VName], Double)]] - val (v,nCharge) = node.value.getOrElse((None,0.0)) + val (v, nCharge) = node.value.getOrElse((None, 0.0)) val nw = computeCharges(node.nw) val ne = computeCharges(node.ne) val sw = computeCharges(node.sw) val se = computeCharges(node.se) - val (p,totalCharge) = Iterator(nw,ne,sw,se).foldLeft((node.p._1 * nCharge, node.p._2 * nCharge), nCharge) { - case ((pSum,cSum), child) => - val (_,c) = child.value.getOrElse((None,0.0)) + val (p, totalCharge) = Iterator(nw, ne, sw, se).foldLeft((node.p._1 * nCharge, node.p._2 * nCharge), nCharge) { + case ((pSum, cSum), child) => + val (_, c) = child.value.getOrElse((None, 0.0)) ((child.p._1 * c + pSum._1, child.p._2 * c + pSum._2), c + cSum) } - val center = if (totalCharge != 0.0) (p._1 / totalCharge, p._2 / totalCharge) else (0.0,0.0) - QuadNode(node.x1,node.y1,node.x2,node.y2,Some((v,totalCharge)),center,nw,ne,sw,se) + val center = if (totalCharge != 0.0) (p._1 / totalCharge, p._2 / totalCharge) else (0.0, 0.0) + QuadNode(node.x1, node.y1, node.x2, node.y2, Some((v, totalCharge)), center, nw, ne, sw, se) } // take an unconstrained step in the direction of steepest descent in energy @@ -98,13 +89,28 @@ class ForceLayout extends GraphLayout with Constraints { val oldCoords = coords + // shake overlapping elements slightly + val vertexList: List[VName] = graph.verts.toList + for (i <- vertexList; j <- vertexList if i != j) { + val p1 = coord(i) + val p2 = coord(j) + val shake = 0.2 + if (p1 == p2) { + setCoord(i, ( + p1._1 + shake * Math.random(), + p1._2 + shake * Math.random() + )) + } + } + + // apply spring forces for (e <- graph.edges) { val sp = coord(graph.source(e)) val tp = coord(graph.target(e)) - val (dx,dy) = if (this.isInstanceOf[Ranking] || this.isInstanceOf[IRanking]) (2.0*(tp._1 - sp._1), tp._2 - sp._2) - else (tp._1 - sp._1, tp._2 - sp._2) - val d = math.sqrt(dx*dx + dy*dy) + val (dx, dy) = if (this.isInstanceOf[Ranking] || this.isInstanceOf[IRanking]) (2.0 * (tp._1 - sp._1), tp._2 - sp._2) + else (tp._1 - sp._1, tp._2 - sp._2) + val d = math.sqrt(dx * dx + dy * dy) if (d != 0.0) { val displacement = d - edgeLength val k = (alpha * strength * displacement) / d @@ -127,18 +133,18 @@ class ForceLayout extends GraphLayout with Constraints { } // compute charges - val quad = computeCharges(QuadTree(graph.verts.toSeq.map { v => (coord(v), (Some(v),charge(v))) })) + val quad = computeCharges(QuadTree(graph.verts.toSeq.map { v => (coord(v), (Some(v), charge(v))) })) // apply charge forces for (v <- graph.verts if !lockedVertices.contains(v)) { var p = coord(v) quad.visit { nd => nd.value match { - case Some((optV,nodeCharge1)) => - val (dx1,dy1) = (nd.p._1 - p._1, nd.p._2 - p._2) + case Some((optV, nodeCharge1)) => + val (dx1, dy1) = (nd.p._1 - p._1, nd.p._2 - p._2) val dx = if (abs(dx1) < 0.01) 0.01 else dx1 val dy = if (abs(dy1) < 0.01) 0.01 else dy1 - val d2 = dx*dx + dy*dy + val d2 = dx * dx + dy * dy if (d2 == 0.0) false else { @@ -147,7 +153,7 @@ class ForceLayout extends GraphLayout with Constraints { energy += (charge(v) + nodeCharge1) / d2 val kx = alpha * nodeCharge1 / d2 val ky = if (this.isInstanceOf[Ranking] || this.isInstanceOf[IRanking]) kx * 1.5 else kx - p = (p._1 - dx*kx, p._2 - dy*ky) + p = (p._1 - dx * kx, p._2 - dy * ky) true } else { // if !B-H, but there is a (different) vertex here, act with the point charge @@ -155,7 +161,7 @@ class ForceLayout extends GraphLayout with Constraints { case Some(v1) if v1 != v => energy += (charge(v) + charge(v1)) / d2 val k = alpha * charge(v1) / d2 - p = (p._1 - dx*k, p._2 - dy*k) + p = (p._1 - dx * k, p._2 - dy * k) case _ => } @@ -171,24 +177,24 @@ class ForceLayout extends GraphLayout with Constraints { // position verlet integration for (v <- graph.verts) { - val (px,py) = oldCoords(v) - val (x,y) = coord(v) - setCoord(v, (x - ((px-x) * friction), y - ((py-y)*friction))) + val (px, py) = oldCoords(v) + val (x, y) = coord(v) + setCoord(v, (x - ((px - x) * friction), y - ((py - y) * friction))) } } def recenter() { - val (sumCoordx,sumCoordy) = graph.verts.foldLeft(0.0,0.0)((pos,name) - => (pos._1+coord(name)._1,pos._2+coord(name)._2)) - val (centerX,centerY) = (sumCoordx/graph.verts.size,sumCoordy/graph.verts.size) - -// if(abs(centerX)> 5|| abs(centerY)> 5){ - graph.verts.foreach(name=>{ - val (px,py) = coord(name) - setCoord(name, (px-centerX, py-centerY )) - }) - // } + val (sumCoordX, sumCoordY) = graph.verts.foldLeft(0.0, 0.0)((pos, name) + => (pos._1 + coord(name)._1, pos._2 + coord(name)._2)) + val (centerX, centerY) = (sumCoordX / graph.verts.size, sumCoordY / graph.verts.size) + + // if(abs(centerX)> 5|| abs(centerY)> 5){ + graph.verts.foreach(name => { + val (px, py) = coord(name) + setCoord(name, (px - centerX, py - centerY)) + }) + // } } def step() { diff --git a/scala/src/main/scala/quanto/layout/GraphLayout.scala b/scala/src/main/scala/quanto/layout/GraphLayout.scala index 887ba130..a8f0a1b6 100644 --- a/scala/src/main/scala/quanto/layout/GraphLayout.scala +++ b/scala/src/main/scala/quanto/layout/GraphLayout.scala @@ -1,46 +1,59 @@ package quanto.layout -import quanto.data.{VName, Graph} +import quanto.data.{Graph, VName} + +import scala.collection.mutable import scala.util.Random class LayoutUninitializedException extends Exception("Layout data read before layout() called") abstract class GraphLayout { - private var _graph: Graph = null - def graph = _graph + protected val lockedVertices: mutable.Set[VName] = collection.mutable.Set() + private val _coords = collection.mutable.Map[VName, (Double, Double)]() + private var _graph: Graph = _ - protected val lockedVertices = collection.mutable.Set[VName]() + def lockVertex(v: VName) { + lockedVertices += v + } - def lockVertex(v: VName) { lockedVertices += v } - def clearLockedVertices() { lockedVertices.clear() } + def clearLockedVertices() { + lockedVertices.clear() + } - private val _coords = collection.mutable.Map[VName,(Double,Double)]() - def setCoord(v: VName, p:(Double,Double)) { + def setCoord(v: VName, p: (Double, Double)) { if (!lockedVertices.contains(v)) _coords(v) = p } + def coord(v: VName) = _coords(v) - def coords = _coords.clone() - // override to compute layout data - protected def compute() + def coords: mutable.Map[VName, (Double, Double)] = _coords.clone() + + def layout(g: Graph, randomCoords: Boolean = true): Graph = { + initialize(g, randomCoords) + compute() + updateGraph() + + graph + } def initialize(g: Graph, randomCoords: Boolean = true) { _graph = g _coords.clear() val r = new Random(0xdeadbeef) - graph.vdata.foreach { case (v,d) => - _coords(v) = if (randomCoords) (0.5 - r.nextDouble(), 0.5 - r.nextDouble()) else d.coord } + graph.vdata.foreach { case (v, d) => + _coords(v) = if (randomCoords) (0.5 - r.nextDouble(), 0.5 - r.nextDouble()) else d.coord + } } def updateGraph() { - _graph = _coords.foldLeft(graph) { case(g,(v,c)) => g.updateVData(v) { _.withCoord(c) } } + _graph = _coords.foldLeft(graph) { case (g, (v, c)) => g.updateVData(v) { + _.withCoord(c) + } + } } - def layout(g: Graph, randomCoords: Boolean = true): Graph = { - initialize(g, randomCoords) - compute() - updateGraph() + def graph: Graph = _graph - graph - } + // override to compute layout data + protected def compute() } diff --git a/scala/src/main/scala/quanto/layout/GraphTransform.scala b/scala/src/main/scala/quanto/layout/GraphTransform.scala index c2a7ee0f..4f5d9b06 100644 --- a/scala/src/main/scala/quanto/layout/GraphTransform.scala +++ b/scala/src/main/scala/quanto/layout/GraphTransform.scala @@ -1,28 +1,27 @@ package quanto.layout -import quanto.data.Graph -import quanto.data.VData +import quanto.data.{Graph, VData} + +import scala.collection.mutable.ListBuffer // transfer each bbox to a node -class GraphTransform (g:Graph){ - - //val numOfBBox = g.bboxes.size - val bs = collection.mutable.ListBuffer[() => Iterator[VData]]() - - def transform { - g.bboxes.foreach(bb => - for(ver <- g.contents(bb)){ - val target = g.succVerts(ver) - - - - }) - - - - } +class GraphTransform(g: Graph) { + + //val numOfBBox = g.bboxes.size + val bs: ListBuffer[() => Iterator[VData]] = collection.mutable.ListBuffer[() => Iterator[VData]]() - def restore{ + def transform() { + g.bboxes.foreach(bb => + for (ver <- g.contents(bb)) { + val target = g.succVerts(ver) + + + }) - } + + } + + def restore() { + + } } \ No newline at end of file diff --git a/scala/src/main/scala/quanto/layout/constraint/Clusters.scala b/scala/src/main/scala/quanto/layout/constraint/Clusters.scala index fa27b001..f922ddf3 100644 --- a/scala/src/main/scala/quanto/layout/constraint/Clusters.scala +++ b/scala/src/main/scala/quanto/layout/constraint/Clusters.scala @@ -1,37 +1,37 @@ package quanto.layout.constraint -import quanto.data._ -import collection.mutable.ListBuffer +import quanto.data.{VName, _} import quanto.layout._ +import quanto.util.Geometry._ +import quanto.util.{QuadTree, _} + +import scala.collection.mutable.ListBuffer +import scala.math.abs -import math.{min,max,abs} -import quanto.util.QuadTree -import quanto.data.VName -import quanto.util._ -import Geometry._ -import Names._ /** - * Constraints to force nodes not in the given clusters outside of the bounding box - */ + * Constraints to force nodes not in the given clusters outside of the bounding box + */ trait Clusters extends Constraints { + import Constraint.distance + var debug = false - var clusters = ListBuffer[Set[VName]]() + var clusters: ListBuffer[Set[VName]] = ListBuffer() var clusterPadding = 0.5 var clusterRadiusPerVertex = 2.0 - + // right code? - def inRect(v:VName, r : ((Double,Double),(Double,Double))) : Boolean = { - val (px,py) = coord(v) - val ((lx,ly),(ux,uy)) = r + def inRect(v: VName, r: ((Double, Double), (Double, Double))): Boolean = { + val (px, py) = coord(v) + val ((lx, ly), (ux, uy)) = r (px > lx) && (px < ux) && (py > ly) && (py < uy) } - + override def initialize(g: Graph, randomCoords: Boolean = true) { super.initialize(g, randomCoords) constraints.nextLayer() //println("Clusters at layer " + constraints.currentLayer) - + // take each bbox as a cluster clusters.clear g.bboxes.foreach(bb => clusters += g.contents(bb)) @@ -40,43 +40,41 @@ trait Clusters extends Constraints { val coordTree = QuadTree(graph.verts.toSeq.map { v => (coord(v), v) }) - clusters.foldLeft(List[Constraint]()) { (constraints,cluster) => - bounds(cluster.map(coord(_))) match { + clusters.foldLeft(List[Constraint]()) { (constraints, cluster) => + bounds(cluster.map(coord)) match { case Some(rect) => - - val bbox = new RichRect(rect) - val (lb,ub) = bbox.pad(clusterPadding) + + val bbox = new RichRect(rect) + val (lb, ub) = bbox.pad(clusterPadding) //if (debug) println("the bound bbox boundary is " +(lb,ub)) - - // centre of the bbox - val (cx,cy) = bbox.center - //if (debug) (println("the center is " + bbox.center)) - val (wth,hgt) = bbox.size + val clusterSize = cluster.size var cons = constraints for (v1 <- cluster; v2 <- cluster; if v1 != v2) { - cons ::= { (distance from v1 to v2) <= (clusterRadiusPerVertex * (clusterSize)-1) } // * 0.95 -// cons ::= { (distance from v1 to v2 along (1.0,0.0)) <= (clusterRadiusPerVertex * (clusterSize)) } // * 0.95 -// cons ::= { (distance from v1 to v2 along (0.0,1.0)) <= (clusterRadiusPerVertex * (clusterSize)) } // * 0.95 + cons ::= { + (distance from v1 to v2) <= (clusterRadiusPerVertex * clusterSize - 1) + } // * 0.95 + // cons ::= { (distance from v1 to v2 along (1.0,0.0)) <= (clusterRadiusPerVertex * (clusterSize)) } // * 0.95 + // cons ::= { (distance from v1 to v2 along (0.0,1.0)) <= (clusterRadiusPerVertex * (clusterSize)) } // * 0.95 } - + //val verts = g.verts.filter(vname => inRect(vname,(lb,ub))) - val verts = coordTree.query(lb._1, lb._2, ub._1, ub._2) + val vertices = coordTree.query(lb._1, lb._2, ub._1, ub._2) // quadtree is not right //println("quadTree is " + verts) -// if (debug) println ("v3 's coordinate is " + coord("v3")) -// if (debug) println ("b0 's coordinate is " + coord("b0")) -// if (debug) println ("v1 's coordinate is " + coord("v1")) - // println("v0 " + "is " + coord("v0")) - // println("box " + (lb,ub)) + // if (debug) println ("v3 's coordinate is " + coord("v3")) + // if (debug) println ("b0 's coordinate is " + coord("b0")) + // if (debug) println ("v1 's coordinate is " + coord("v1")) + // println("v0 " + "is " + coord("v0")) + // println("box " + (lb,ub)) //if (debug) println ("all vertices in the bbox\n\t" + verts) - for (v <- verts; if !cluster.contains(v)) { + for (v <- vertices; if !cluster.contains(v)) { val vc = coord(v) - - // println(v + " is in the bbox but not belongs it"); + + // println(v + " is in the bbox but not belongs it"); // work out the most efficient shift ///// @@ -86,69 +84,71 @@ trait Clusters extends Constraints { */ // need a better way to shift. - + // move to the side that has more force to pull it - val (shiftDirX, shiftDirY) = (g.succVerts(v)++g.predVerts(v)). - foldLeft(0.0,0.0)((pos,name) => (pos._1+coord(name)._1-vc._1,pos._2+coord(name)._2-vc._2)) + //val (shiftDirX, shiftDirY) = (g.succVerts(v) ++ g.predVerts(v)). + // foldLeft(0.0, 0.0)((pos, name) => (pos._1 + coord(name)._1 - vc._1, pos._2 + coord(name)._2 - vc._2)) //println("node " + v + " should move direction" + (shiftDirX, shiftDirY)) -// val xShift = if (vc._1 >= lb._1 && vc._1 <= ub._1) { -// if(shiftDirX > 0) ub._1 - vc._1 else lb._1 - vc._1 -// -// } else 0 -// -// val yShift = if (vc._2 >= lb._2 && vc._2 <= ub._2) { -// if(shiftDirY > 0){ -// ub._2 - vc._2 -// } -// else{ -// lb._2 - vc._2 -// } -// } else 0 -// -// if (xShift != 0 && yShift != 0){ -// for (v1 <- cluster) { -// val v1c = coord(v1) -// val (len,dir) = if (abs(shiftDirX) > abs(shiftDirY)) { -// if (xShift < 0.0) (abs((vc._1 - v1c._1) + (xShift)), (1.0,0.0)) -// else (abs((vc._1 - v1c._1) + xShift), (-1.0,0.0)) -// } else { -// if (yShift < 0.0) (abs((vc._2 - v1c._2) + (yShift)), (0.0,1.0)) -// else (abs((vc._2 - v1c._2) + (yShift)), (0.0,-1.0)) -// } -// -// cons ::= { (distance from v to v1 along dir weighted (1.0, 1.0)) >= len } //cluster.size.toDouble+1 -// } -// } - + // val xShift = if (vc._1 >= lb._1 && vc._1 <= ub._1) { + // if(shiftDirX > 0) ub._1 - vc._1 else lb._1 - vc._1 + // + // } else 0 + // + // val yShift = if (vc._2 >= lb._2 && vc._2 <= ub._2) { + // if(shiftDirY > 0){ + // ub._2 - vc._2 + // } + // else{ + // lb._2 - vc._2 + // } + // } else 0 + // + // if (xShift != 0 && yShift != 0){ + // for (v1 <- cluster) { + // val v1c = coord(v1) + // val (len,dir) = if (abs(shiftDirX) > abs(shiftDirY)) { + // if (xShift < 0.0) (abs((vc._1 - v1c._1) + (xShift)), (1.0,0.0)) + // else (abs((vc._1 - v1c._1) + xShift), (-1.0,0.0)) + // } else { + // if (yShift < 0.0) (abs((vc._2 - v1c._2) + (yShift)), (0.0,1.0)) + // else (abs((vc._2 - v1c._2) + (yShift)), (0.0,-1.0)) + // } + // + // cons ::= { (distance from v to v1 along dir weighted (1.0, 1.0)) >= len } //cluster.size.toDouble+1 + // } + // } + // not sure if we ned the if condition // since quadtree will make sure the condition is true val xShift = if (vc._1 >= lb._1 && vc._1 <= ub._1) { - val left = lb._1 - vc._1 + val left = lb._1 - vc._1 val right = ub._1 - vc._1 if (abs(left) < right) left else right } else 0 val yShift = if (vc._2 >= lb._2 && vc._2 <= ub._2) { - val down = lb._2 - vc._2 + val down = lb._2 - vc._2 val up = ub._2 - vc._2 if (-down < up) down else up } else 0 - - //questions about following code + + //questions about following code if (xShift != 0 && yShift != 0) { // val soft = g.isBBoxed(v) for (v1 <- cluster) { val v1c = coord(v1) - val (len,dir) = if (abs(xShift) < abs(yShift)) { - if (xShift < 0.0) (abs((vc._1 - v1c._1) + (xShift)), (1.0,0.0)) - else (abs((vc._1 - v1c._1) + xShift), (-1.0,0.0)) + val (len, dir) = if (abs(xShift) < abs(yShift)) { + if (xShift < 0.0) (abs((vc._1 - v1c._1) + xShift), (1.0, 0.0)) + else (abs((vc._1 - v1c._1) + xShift), (-1.0, 0.0)) } else { - if (yShift < 0.0) (abs((vc._2 - v1c._2) + (yShift)), (0.0,1.0)) - else (abs((vc._2 - v1c._2) + (yShift)), (0.0,-1.0)) + if (yShift < 0.0) (abs((vc._2 - v1c._2) + yShift), (0.0, 1.0)) + else (abs((vc._2 - v1c._2) + yShift), (0.0, -1.0)) } - - cons ::= { (distance from v to v1 along dir weighted (1.0, 1.0)) >= len } //cluster.size.toDouble+1 -// cons ::= { (distance from v to v1 along dir weighted (1.0, cluster.size.toDouble+1)) >= len } // + + cons ::= { + (distance from v to v1 along dir weighted(1.0, 1.0)) >= len + } //cluster.size.toDouble+1 + // cons ::= { (distance from v to v1 along dir weighted (1.0, cluster.size.toDouble+1)) >= len } // } } } @@ -156,7 +156,7 @@ trait Clusters extends Constraints { cons case None => constraints } - + } } } diff --git a/scala/src/main/scala/quanto/layout/constraint/NoTweaks.scala b/scala/src/main/scala/quanto/layout/constraint/NoTweaks.scala index 49467101..c8cddec5 100644 --- a/scala/src/main/scala/quanto/layout/constraint/NoTweaks.scala +++ b/scala/src/main/scala/quanto/layout/constraint/NoTweaks.scala @@ -1,53 +1,49 @@ package quanto.layout.constraint import quanto.data._ -import collection.mutable.ListBuffer import quanto.layout._ -import math.{min,max,abs} -import quanto.util.QuadTree -import quanto.data.VName -import quanto.util._ -import Geometry._ -import Names._ //import quanto.layout.distance trait NoTweaks extends Constraints { - val radius = 0.2 - override def initialize(g: Graph, randomCoords: Boolean = true) { - super.initialize(g, randomCoords) - constraints.nextLayer() - println("Clusters at layer " + constraints.currentLayer) - - // for edges we get rid of the - - for(v <- g.verts){ - val edges = g.inEdges(v) ++ g.outEdges(v) - val compEdges = g.edges -- edges - - // for each vertex, the distance to each compEdges - // is greater than its radius. - - for (e <- compEdges ){ - val s = graph.source(e) - val t = graph.target(e) - val (sx,sy) = coord(s) - val (tx,ty) = coord(t) - val (pqx,pqy) = (sx-tx, sy-ty) - var par = (0.0, 0.0) - if (pqy == 0){ - par = (0.0,1.0) - } - else if (pqx == 0){ - par = (1.0, 0.0) - } - else { - par = (-pqy/pqx,1.0) - } - - constraints += { (Constraint.distance from v to s along par) >= radius } - } - - } - } + val radius = 0.2 + + override def initialize(g: Graph, randomCoords: Boolean = true) { + super.initialize(g, randomCoords) + constraints.nextLayer() + println("Clusters at layer " + constraints.currentLayer) + + // for edges we get rid of the + + for (v <- g.verts) { + val edges = g.inEdges(v) ++ g.outEdges(v) + val compEdges = g.edges -- edges + + // for each vertex, the distance to each compEdges + // is greater than its radius. + + for (e <- compEdges) { + val s = graph.source(e) + val t = graph.target(e) + val (sx, sy) = coord(s) + val (tx, ty) = coord(t) + val (pqx, pqy) = (sx - tx, sy - ty) + var par = (0.0, 0.0) + if (pqy == 0) { + par = (0.0, 1.0) + } + else if (pqx == 0) { + par = (1.0, 0.0) + } + else { + par = (-pqy / pqx, 1.0) + } + + constraints += { + (Constraint.distance from v to s along par) >= radius + } + } + + } + } } \ No newline at end of file diff --git a/scala/src/main/scala/quanto/layout/constraint/Ranking.scala b/scala/src/main/scala/quanto/layout/constraint/Ranking.scala index aaa669f8..0e7d0b54 100644 --- a/scala/src/main/scala/quanto/layout/constraint/Ranking.scala +++ b/scala/src/main/scala/quanto/layout/constraint/Ranking.scala @@ -4,10 +4,12 @@ import quanto.data._ import quanto.layout._ /** - * Mix in to add ranking constraints to initialization - */ + * Mix in to add ranking constraints to initialization + */ trait Ranking extends Constraints { + import Constraint.distance + var rankSep: Double = 1.0 override def initialize(g: Graph, randomCoords: Boolean = true) { @@ -25,7 +27,9 @@ trait Ranking extends Constraints { } trait IRanking extends Constraints { + import Constraint.distance + var rankSep: Double = 1.0 override def initialize(g: Graph, randomCoords: Boolean = true) { @@ -34,6 +38,8 @@ trait IRanking extends Constraints { val dag = g.dagCopy for (e <- dag.edges) - constraints += {(distance from dag.source(e) to dag.target(e) along (0,-1)) ~>= rankSep} + constraints += { + (distance from dag.source(e) to dag.target(e) along(0, -1)) ~>= rankSep + } } } diff --git a/scala/src/main/scala/quanto/layout/constraint/VerticalBoundary.scala b/scala/src/main/scala/quanto/layout/constraint/VerticalBoundary.scala index e0476a69..4f1a29c3 100644 --- a/scala/src/main/scala/quanto/layout/constraint/VerticalBoundary.scala +++ b/scala/src/main/scala/quanto/layout/constraint/VerticalBoundary.scala @@ -1,50 +1,55 @@ package quanto.layout.constraint -import quanto.layout._ import quanto.data._ +import quanto.layout._ trait VerticalBoundary extends Constraints { + import Constraint.distance - + override def initialize(g: Graph, randomCoords: Boolean = true) { super.initialize(g, randomCoords) constraints.nextLayer() //println("VerticalBoundary at layer " + constraints.currentLayer) - -// for (bnd <- Iterator(g.inputs,g.outputs); if !bnd.isEmpty) { -// val it = bnd.iterator -// val v1 = it.next() -// for (v2 <- it) { -//// if (g.isBBoxed(v1) && !g.isBBoxed(v2)) { -//// -//// } -// coord(v1) -// if (g.isBBoxed(v1) || g.isBBoxed(v2)) -// constraints += { (distance from v1 to v2 along (0.0,1.0)) ~== 0.0 } -// else -// constraints += { (distance from v1 to v2 along (0.0,1.0)) === 0.0 } -// //constraints += { (distance from v1 to v2 along (0.0,1.0)) ~== 0.0 } -// // v1 = v2 -// } -// } + + // for (bnd <- Iterator(g.inputs,g.outputs); if !bnd.isEmpty) { + // val it = bnd.iterator + // val v1 = it.next() + // for (v2 <- it) { + //// if (g.isBBoxed(v1) && !g.isBBoxed(v2)) { + //// + //// } + // coord(v1) + // if (g.isBBoxed(v1) || g.isBBoxed(v2)) + // constraints += { (distance from v1 to v2 along (0.0,1.0)) ~== 0.0 } + // else + // constraints += { (distance from v1 to v2 along (0.0,1.0)) === 0.0 } + // //constraints += { (distance from v1 to v2 along (0.0,1.0)) ~== 0.0 } + // // v1 = v2 + // } + // } g.edges.foreach { e => - val (s,t) = (g.source(e), g.target(e)) - if (g.isInput(s) || g.isOutput(t)) { - if (g.isBBoxed(s) || g.isBBoxed(t)){ - constraints += { (distance from s to t along (1.0,0.0)) ~== 0.0 } + val (s, t) = (g.source(e), g.target(e)) + if (g.isInputWire(s) || g.isOutputWire(t)) { + if (g.isBBoxed(s) || g.isBBoxed(t)) { + constraints += { + (distance from s to t along(1.0, 0.0)) ~== 0.0 + } } else - constraints += { (distance from s to t along (1.0,0.0)) === 0.0 } + constraints += { + (distance from s to t along(1.0, 0.0)) === 0.0 + } } } } - -// override def projectConstraints(){ -// -// -// } - + + // override def projectConstraints(){ + // + // + // } + } diff --git a/scala/src/main/scala/quanto/rewrite/AngleExpressionMatcher.scala b/scala/src/main/scala/quanto/rewrite/AngleExpressionMatcher.scala deleted file mode 100644 index bad80bb1..00000000 --- a/scala/src/main/scala/quanto/rewrite/AngleExpressionMatcher.scala +++ /dev/null @@ -1,42 +0,0 @@ -package quanto.rewrite - -import quanto.data._ -import quanto.util._ - - -class AngleExpressionMatcher(pVars : Vector[String], tVars : Vector[String], mat : RationalMatrix) { - val pvSet = pVars.toSet - val tvSet = tVars.toSet - - def addMatch(pExpr : AngleExpression, tExpr : AngleExpression) : Option[AngleExpressionMatcher] = { - val pVars1 = pVars ++ (pExpr.vars -- pvSet).toVector - val tVars1 = tVars ++ (tExpr.vars -- tvSet).toVector - - val r1 = pVars1.map { v => pExpr.coeffs.getOrElse(v, Rational(0)) } - val r2 = tVars1.map { v => tExpr.coeffs.getOrElse(v, Rational(0)) } - val row = r1 ++ r2 :+ (tExpr.const - pExpr.const) - - mat.padTo(pVars1.length, tVars1.length).gaussUpdate(row).map { mat1 => - new AngleExpressionMatcher(pVars1, tVars1, mat1) - } - } - - - def toMap : Map[String, AngleExpression] = - if (mat.numCols == 0) Map() - else mat.rows.foldLeft(Map[String,AngleExpression]()) { (mp, row) => - val p = RationalMatrix.findPivot(row) - var coeffs = Map[String,Rational]() - for (i <- p+1 until mat.line) - if (row(i) != Rational(0)) coeffs = coeffs + (pVars(i) -> row(i) * -1) - for (i <- mat.line to row.length - 2) - if (row(i) != Rational(0)) coeffs = coeffs + (tVars(i-mat.line) -> row(i)) - - mp + (pVars(p) -> AngleExpression(row.last, coeffs)) - } -} - -object AngleExpressionMatcher { - def apply(pVars : Vector[String], tVars : Vector[String]) = - new AngleExpressionMatcher(pVars, tVars, new RationalMatrix(Vector(), pVars.length)) -} diff --git a/scala/src/main/scala/quanto/rewrite/CompositeExpressionMatcher.scala b/scala/src/main/scala/quanto/rewrite/CompositeExpressionMatcher.scala new file mode 100644 index 00000000..9a14b80c --- /dev/null +++ b/scala/src/main/scala/quanto/rewrite/CompositeExpressionMatcher.scala @@ -0,0 +1,86 @@ +package quanto.rewrite + + +import quanto.data.Theory.ValueType +import quanto.data.{CompositeExpression, PhaseExpression} +import quanto.util.{Rational, RationalMatrix} + +class PhaseExpressionMatcher(pVars: Vector[String], tVars: Vector[String], mat: RationalMatrix) { + val pvSet: Set[String] = pVars.toSet + val tvSet: Set[String] = tVars.toSet + + def addMatch(patternExpression: PhaseExpression, targetExpression: PhaseExpression): Option[PhaseExpressionMatcher] = { + val patternVars1 = pVars ++ (patternExpression.vars.toSet -- pvSet).toVector + val targetVars1 = tVars ++ (targetExpression.vars.toSet -- tvSet).toVector + val zero = Rational(0) + val r1 = patternVars1.map { v => patternExpression.coefficients.getOrElse(v, zero) } + val r2 = targetVars1.map { v => targetExpression.coefficients.getOrElse(v, zero) } + val row = r1 ++ r2 :+ (targetExpression.constant - patternExpression.constant) + + mat.padTo(patternVars1.length, targetVars1.length).gaussUpdate(row).map { mat1 => + new PhaseExpressionMatcher(patternVars1, targetVars1, mat1) + } + } + + def toMap(valueType: ValueType): Map[String, PhaseExpression] = toMap.mapValues(_.as(valueType)) + + def toMap: Map[String, PhaseExpression] = + if (mat.numCols == 0) Map() + else mat.rows.foldLeft(Map[String, PhaseExpression]()) { (mp, row) => + val p = RationalMatrix.findPivot(row) + var coefficients = Map[String, Rational]() + for (i <- p + 1 until mat.line) + if (row(i) != Rational(0)) coefficients = coefficients + (pVars(i) -> row(i) * -1) + for (i <- mat.line to row.length - 2) + if (row(i) != Rational(0)) coefficients = coefficients + (tVars(i - mat.line) -> row(i)) + + mp + (pVars(p) -> PhaseExpression(row.last, coefficients, ValueType.Rational)) + } +} + +object PhaseExpressionMatcher { + def empty: PhaseExpressionMatcher = PhaseExpressionMatcher(Vector(), Vector(), None) + + def apply(pVars: Vector[String], tVars: Vector[String], modulus: Option[Int]) = + new PhaseExpressionMatcher(pVars, tVars, new RationalMatrix(Vector(), pVars.length, modulus)) +} + + +class CompositeExpressionMatcher(matchers: Map[ValueType, Option[PhaseExpressionMatcher]]) { + + // Add matchings by component + def addMatch(pExpr: CompositeExpression, tExpr: CompositeExpression): Option[CompositeExpressionMatcher] = { + val valuePairs = pExpr.values.zip(tExpr.values) + val typeValueTriples = pExpr.valueTypes.zip(valuePairs) + // Loop through each valueType and apply relevant matches + // If any of the matches fail then the whole thing needs to fail + typeValueTriples.foldLeft(Some(this): Option[CompositeExpressionMatcher])((om, t) => { + if (om.nonEmpty) { + om.get.addPhaseMatch(t._1, t._2._1.asInstanceOf[PhaseExpression], t._2._2.asInstanceOf[PhaseExpression]) + } else { + None + } + }) + } + + // Add a single matching to a specific valueType + def addPhaseMatch(valueType: ValueType, pExpr: PhaseExpression, tExpr: PhaseExpression): Option[CompositeExpressionMatcher] = { + val updatedSingletonMatcher = matchers. + getOrElse(valueType, Some(PhaseExpressionMatcher(Vector(), Vector(), pExpr.modulus))). + get.addMatch(pExpr, tExpr) + if (updatedSingletonMatcher.nonEmpty) { + Some(new CompositeExpressionMatcher(matchers + (valueType -> updatedSingletonMatcher))) + } else { + None + } + } + + def toMap: Map[ValueType, Map[String, PhaseExpression]] = + matchers.keySet.map( + valueType => valueType -> matchers(valueType).getOrElse(PhaseExpressionMatcher.empty).toMap.mapValues(_.as(valueType)) + ).toMap +} + +object CompositeExpressionMatcher { + def apply(): CompositeExpressionMatcher = new CompositeExpressionMatcher(Map()) +} diff --git a/scala/src/main/scala/quanto/rewrite/Match.scala b/scala/src/main/scala/quanto/rewrite/Match.scala index a660de38..c9fa44e8 100644 --- a/scala/src/main/scala/quanto/rewrite/Match.scala +++ b/scala/src/main/scala/quanto/rewrite/Match.scala @@ -1,7 +1,11 @@ package quanto.rewrite +import quanto.data.Theory.ValueType import quanto.data._ +import quanto.util.UserAlerts -class MatchException(msg: String) extends Exception(msg) +class MatchException(msg: String) extends Exception(msg) { + UserAlerts.alert(s"Match Exception: $msg", UserAlerts.Elevation.WARNING) +} case class Match(pattern0: Graph, // the pattern without bbox operations pattern: Graph, @@ -9,7 +13,8 @@ case class Match(pattern0: Graph, // the pattern without bbox operations map: GraphMap = GraphMap(), bareWireMap: Map[VName, Vector[VName]] = Map(), bbops: List[BBOp] = List(), - subst: Map[String,AngleExpression] = Map()) { + subst: Map[ValueType, Map[String,PhaseExpression]] = Map() // matching individual phases inside composite phases + ) { def addVertex(vPair: (VName, VName)): Match = { copy(map = map addVertex vPair) @@ -41,16 +46,16 @@ case class Match(pattern0: Graph, // the pattern without bbox operations bareWireMap.headOption match { case Some((tw, pw +: pws)) => val (target1, (newW1, newW2, newE)) = target.expandWire(tw) - val emap1 = map.e + (pattern.outEdges(pw).head -> newE) + val edgeMap1 = map.e + (pattern.outEdges(pw).head -> newE) - var vmap1 = map.v + var vertexMap1 = map.v for (pw1 <- map.v.codf(tw)) { - if (pattern.isInput(pw1)) vmap1 = vmap1 + (pw1 -> newW2) + if (pattern.isInputWire(pw1)) vertexMap1 = vertexMap1 + (pw1 -> newW2) } - vmap1 = vmap1 + (pw -> newW1) + (pattern.succVerts(pw).head -> newW2) + vertexMap1 = vertexMap1 + (pw -> newW1) + (pattern.succVerts(pw).head -> newW2) copy( - map = map.copy(v = vmap1, e = emap1), + map = map.copy(v = vertexMap1, e = edgeMap1), target = target1, bareWireMap = bareWireMap + (tw -> pws) ).normalize diff --git a/scala/src/main/scala/quanto/rewrite/MatchState.scala b/scala/src/main/scala/quanto/rewrite/MatchState.scala index 4eb56ad3..ecae3120 100644 --- a/scala/src/main/scala/quanto/rewrite/MatchState.scala +++ b/scala/src/main/scala/quanto/rewrite/MatchState.scala @@ -1,30 +1,30 @@ package quanto.rewrite + import quanto.data._ import scala.annotation.tailrec case class MatchState( - m: Match, // the match being built - tVerts: Set[VName], // restriction of the range of the match - angleMatcher: AngleExpressionMatcher, // state of matched angle data - pNodes: Set[VName] = Set(), // nodes with partially-mapped neighbourhood - psNodes: Set[VName] = Set(), // same, but scheduled for completion - sBBox: Option[BBName] = None, // a bbox scheduled for matching - candidateNodes: Option[Set[VName]] = None, // nodes to try matching in the target - candidateEdges: Option[Set[EName]] = None, // edges to try matching in the target - candidateWires: Option[Set[(VName,Int)]] = None, // wire-vertices to try matching bare wires on - candidateBBoxes: Option[Set[BBName]] = None, // bboxes to try matching in the target - bboxOrbits: PFun[VName, VName] = PFun(), // for smashing redundant matches - nextState: Option[MatchState] = None // next state to try after search terminates + m: Match, // the match being built + targetVertices: Set[VName], // restriction of the range of the match + expressionMatcher: CompositeExpressionMatcher, // state of matched angle data + pNodes: Set[VName] = Set(), // nodes with partially-mapped neighbourhood + psNodes: Set[VName] = Set(), // same, but scheduled for completion + sBBox: Option[BBName] = None, // a bbox scheduled for matching + candidateNodes: Option[Set[VName]] = None, // nodes to try matching in the target + candidateEdges: Option[Set[EName]] = None, // edges to try matching in the target + candidateWires: Option[Set[(VName, Int)]] = None, // wire-vertices to try matching bare wires on + candidateBBoxes: Option[Set[BBName]] = None, // bboxes to try matching in the target + bboxOrbits: PFun[VName, VName] = PFun(), // for smashing redundant matches + nextState: Option[MatchState] = None // next state to try after search terminates ) { - val uVerts: Set[VName] = m.pattern.verts.filter(v => bboxesMatched(v) && !m.map.v.domSet.contains(v)) - val uCircles: Set[VName] = uVerts.filter(m.pattern.isCircle) - lazy val uBareWires: Set[VName] = uVerts.filter(m.pattern.representsBareWire) - lazy val uNodes: Set[VName] = uVerts.filter { v => !m.pattern.vdata(v).isWireVertex } - lazy val uWires: Set[VName] = uVerts.filter { v => m.pattern.vdata(v).isWireVertex } - val uBBoxes: Set[BBName] = m.pattern.bboxes.filter(bb => parentBBoxesMatched(bb) && !m.map.bb.domSet.contains(bb)) - + lazy val uBareWires: Set[VName] = unmatchedVertices.filter(m.pattern.representsBareWire) + lazy val uNodes: Set[VName] = unmatchedVertices.filter { v => !m.pattern.vdata(v).isWireVertex } + lazy val uWires: Set[VName] = unmatchedVertices.filter { v => m.pattern.vdata(v).isWireVertex } + val unmatchedVertices: Set[VName] = m.pattern.verts.filter(v => bboxesMatched(v) && !m.map.v.domSet.contains(v)) + val uCircles: Set[VName] = unmatchedVertices.filter(m.pattern.isCircle) + val uBBoxes: Set[BBName] = m.pattern.bboxes.filter(bb => parentBBoxesMatched(bb) && !m.map.bb.domSet.contains(bb)) /** @@ -34,28 +34,36 @@ case class MatchState( */ @tailrec final def nextMatch(): Option[(Match, Option[MatchState])] = { +// if (m.map.e.domSet.size == 4) { +// print("got 4 edges") +// } + // if unmatched circles are found in the pattern, match them first if (uCircles.nonEmpty) { val pc = uCircles.min - tVerts.find(v => m.target.isCircle(v) && reflectsBBoxes(pc, v)) match { - case None => nextState match { case Some(next) => next.nextMatch(); case None => None } + targetVertices.find(v => m.target.isCircle(v) && reflectsBBoxes(pc, v)) match { + case None => nextState match { + case Some(next) => next.nextMatch(); + case None => None + } case Some(tc) => val pce = m.pattern.inEdges(pc).min val tce = m.target.inEdges(tc).min - copy(m = m.addEdge(pce -> tce, pc -> tc), tVerts = tVerts - tc).nextMatch() + copy(m = m.addEdge(pce -> tce, pc -> tc), targetVertices = targetVertices - tc).nextMatch() } - // if there is a scheduled node, try to match its neighbourhood in every possible way + // if there is a scheduled node, try to match its neighbourhood in every possible way } else if (psNodes.nonEmpty) { val np = psNodes.min - if (pVertexMayBeCompleted(np)) { + val cmp = pVertexMayBeCompleted(np) + if (cmp) { val nt = m.map.v(np) // get the next matchable edge in the neighbourhood of np val uEdges = m.pattern.adjacentEdges(np).filter(e => !m.map.e.domSet.contains(e) && - bboxesMatched(m.pattern.edgeGetOtherVertex(e, np)) + bboxesMatched(m.pattern.edgeGetOtherVertex(e, np)) ) val epOpt = if (uEdges.isEmpty) None else Some(uEdges.min) epOpt match { @@ -66,7 +74,7 @@ case class MatchState( case None => copy(candidateEdges = Some(m.target.adjacentEdges(nt).filter { e => !m.map.e.codSet.contains(e) && - tVerts.contains(m.target.edgeGetOtherVertex(e, nt)) + targetVertices.contains(m.target.edgeGetOtherVertex(e, nt)) })).nextMatch() case Some(candidateEdges1) => if (candidateEdges1.isEmpty) { @@ -78,8 +86,10 @@ case class MatchState( val et = candidateEdges1.min val next = copy(candidateEdges = Some(candidateEdges1 - et)) matchNewWire(np, ep, nt, et) match { - case Some(ms1) => ms1.copy(candidateEdges = None, nextState = Some(next)).nextMatch() - case None => next.nextMatch() + case Some(ms1) => + ms1.copy(candidateEdges = None, nextState = Some(next)).nextMatch() + case None => + next.nextMatch() } } } @@ -100,13 +110,13 @@ case class MatchState( } - // if there are no scheduled nodes, pick a new unmatched node in the pattern, match it in every possible way - // and schedule its neighbourhood for matching + // if there are no scheduled nodes, pick a new unmatched node in the pattern, match it in every possible way + // and schedule its neighbourhood for matching } else if (uNodes.nonEmpty) { val np = uNodes.min candidateNodes match { case None => - copy(candidateNodes = Some(tVerts.filter { v => + copy(candidateNodes = Some(targetVertices.filter { v => !m.target.vdata(v).isWireVertex })).nextMatch() case Some(candidateNodes1) => @@ -125,18 +135,18 @@ case class MatchState( } } - // if there are bare wires remaining, add them in all possible ways + // if there are bare wires remaining, add them in all possible ways } else if (uBareWires.nonEmpty) { val pbw = uBareWires.min candidateWires match { case None => // pull all the candidate locations for matching this bare wire. If a wire already has n bare wires matched // on it, this is n+1 possible locations. - val cwires = - for (v <- tVerts if m.target.representsWire(v) && reflectsBBoxes(pbw, v); + val cWires = + for (v <- targetVertices if m.target.representsWire(v) && reflectsBBoxes(pbw, v); i <- 0 to m.bareWireMap.get(v).map(_.length).getOrElse(0)) - yield (v,i) - copy(candidateWires = Some(cwires.toSet)).nextMatch() + yield (v, i) + copy(candidateWires = Some(cWires)).nextMatch() case Some(candidateWires1) => if (candidateWires1.isEmpty) { nextState match { @@ -156,8 +166,8 @@ case class MatchState( } } - // if all matchable verts are matched, pull the first top-level bbox and try to kill, expand, or copy+match - } else if (uBBoxes.nonEmpty && uVerts.isEmpty) { + // if all matchable verts are matched, pull the first top-level bbox and try to kill, expand, or copy+match + } else if (uBBoxes.nonEmpty && unmatchedVertices.isEmpty) { val pbb = uBBoxes.min candidateBBoxes match { case Some(candidateBBoxes1) => @@ -169,11 +179,11 @@ case class MatchState( } else { val tbb = candidateBBoxes1.min val next = copy(candidateBBoxes = Some(candidateBBoxes1 - tbb)) -// val schedule = -// m.pattern.adjacentVerts(m.pattern.contents(pbb)).filter { v => -// !m.pattern.vdata(v).isWireVertex && -// bboxesMatched(v) -// } + // val schedule = + // m.pattern.adjacentVerts(m.pattern.contents(pbb)).filter { v => + // !m.pattern.vdata(v).isWireVertex && + // bboxesMatched(v) + // } copy( m = m.addBBox(pbb -> tbb), psNodes = pNodes, // re-schedule everything @@ -192,7 +202,7 @@ case class MatchState( // if a !-box is wild, drop it, rather than copy/expand val (dropGraph, dropOp) = m.pattern.dropBBox(pbb) copy(m = m.copy(pattern = dropGraph, bbops = dropOp :: m.bbops), - nextState = Some(killState)) + nextState = Some(killState)) } else { // only expand/copy non-wild !-boxes, or we'll get infinite matchings val minV = m.pattern.contents(pbb).min val (expandGraph, expandOp) = m.pattern.expandBBox(pbb) @@ -206,7 +216,7 @@ case class MatchState( else copy( m = m.copy(pattern = copyGraph, bbops = copyOp :: m.bbops), - candidateBBoxes = Some(m.target.bboxes.filter { tbb => reflectsParentBBoxes(pbb, tbb)}), + candidateBBoxes = Some(m.target.bboxes.filter { tbb => reflectsParentBBoxes(pbb, tbb) }), bboxOrbits = bboxOrbits + (copyOp.mp.v(minV) -> minV), nextState = Some(killState) ) @@ -231,12 +241,12 @@ case class MatchState( expState.nextMatch() } - // if there is nothing left to do, check if the match is complete and return it if so. If not, continue - // the search from nextState + // if there is nothing left to do, check if the match is complete and return it if so. If not, continue + // the search from nextState } else { if (pNodes.isEmpty && m.isTotal) { if (MatchState.countMatches) MatchState.matchCounter += 1 - val ms = copy(m = m.copy(subst = angleMatcher.toMap)) + val ms = copy(m = m.copy(subst = expressionMatcher.toMap)) Some((ms.m, nextState)) } else { nextState match { @@ -256,33 +266,37 @@ case class MatchState( bboxOrbits.codf(rep).forall { pv1 => m.map.v.get(pv1) match { case None => true - case Some(tv1) => (pv <= pv1) == (tv <= tv1) + case Some(tv1) => if (pv <= pv1) { tv <= tv1 } else { true } } } case None => true } - def pVertexMayBeCompleted(vp: VName) = { - val allVerts = m.pattern.adjacentVerts(vp) - val concreteVerts = allVerts.filter(v => m.pattern.bboxesContaining(v).isEmpty) - val hasBBox = allVerts.size > concreteVerts.size + def pVertexMayBeCompleted(vp: VName): Boolean = { + val allEdges = m.pattern.adjacentEdges(vp) + val concreteEdges = allEdges.filter(e => m.pattern.bboxesContaining(m.pattern.edgeGetOtherVertex(e, vp)).isEmpty) + val hasBBox = allEdges.size > concreteEdges.size val tArity = m.target.arity(m.map.v(vp)) - - concreteVerts.size == tArity || (hasBBox && concreteVerts.size <= tArity) + concreteEdges.size == tArity || (hasBBox && concreteEdges.size <= tArity) + //val allVertices = m.pattern.adjacentVerts(vp) + //val concreteVertices = allVertices.filter(v => m.pattern.bboxesContaining(v).isEmpty) + //val hasBBox = allVertices.size > concreteVertices.size + //val tArity = m.target.arity(m.map.v(vp)) + //concreteVertices.size == tArity || (hasBBox && concreteVertices.size <= tArity) } // TODO: we may only need to check the closest parent, not all parents - def reflectsBBoxes(vp: VName, vt: VName) = + def reflectsBBoxes(vp: VName, vt: VName): Boolean = m.map.bb.directImage(m.pattern.bboxesContaining(vp)) == m.target.bboxesContaining(vt) - def bboxesMatched(vp: VName) = + def bboxesMatched(vp: VName): Boolean = m.pattern.bboxesContaining(vp).forall(m.map.bb.domSet.contains) - def reflectsParentBBoxes(bbp: BBName, bbt: BBName) = + def reflectsParentBBoxes(bbp: BBName, bbt: BBName): Boolean = m.map.bb.directImage(m.pattern.bboxParents(bbp)) == m.target.bboxParents(bbt) - def parentBBoxesMatched(bbp: BBName) = { + def parentBBoxesMatched(bbp: BBName): Boolean = { m.pattern.bboxParents(bbp).forall(m.map.bb.domSet.contains) } @@ -301,22 +315,23 @@ case class MatchState( (m.pattern.vdata(np), m.target.vdata(nt)) match { case (pd: NodeV, td: NodeV) => if (pd.typ == td.typ) { - if (pd.hasAngle) - angleMatcher.addMatch(pd.angle, td.angle).map { angleMatcher1 => - copy( - m = m.addVertex(np -> nt), - pNodes = pNodes + np, - psNodes = psNodes + np, - tVerts = tVerts - nt, - angleMatcher = angleMatcher1 - ) + if (pd.hasValue) + expressionMatcher.addMatch(pd.phaseData, td.phaseData).map { + angleMatcher1 => + copy( + m = m.addVertex(np -> nt), + pNodes = pNodes + np, + psNodes = psNodes + np, + targetVertices = targetVertices - nt, + expressionMatcher = angleMatcher1 + ) } else if (pd.value == td.value) Some(copy( m = m.addVertex(np -> nt), pNodes = pNodes + np, psNodes = psNodes + np, - tVerts = tVerts - nt + targetVertices = targetVertices - nt )) else None } else None @@ -324,24 +339,28 @@ case class MatchState( } /** - * Try to recursively add wire to matching, starting with the given head - * vertex and edge. Return NONE on failure. - * - * (ported from the ML function tryadd_wire) - * - * @param vp already-matched vertex - * @param ep unmatched edge incident to vp (other end must be in P, Uw or Un) - * @param vt target of vp - * @param et unmatched edge incident to vt - */ - def matchNewWire(vp:VName, ep:EName, vt:VName, et:EName): Option[MatchState] = { - val pdir = m.pattern.edata(ep).isDirected - val tdir = m.target.edata(et).isDirected + * Try to recursively add wire to matching, starting with the given head + * vertex and edge. Return NONE on failure. + * + * (ported from the ML function tryadd_wire) + * + * @param vp already-matched vertex + * @param ep unmatched edge incident to vp (other end must be in P, Uw or Un) + * @param vt target of vp + * @param et unmatched edge incident to vt + */ + def matchNewWire(vp: VName, ep: EName, vt: VName, et: EName): Option[MatchState] = { + val pDir = m.pattern.edata(ep).isDirected + val tDir = m.target.edata(et).isDirected val pOutEdge = m.pattern.source(ep) == vp val tOutEdge = m.target.source(et) == vt + val pType = m.pattern.edata(ep).typ + val tType = m.target.edata(et).typ // match directedness and, if the edge is directed, direction - if ((pdir && tdir && pOutEdge == tOutEdge) || (!pdir && !tdir)) { + // also match on type + // TODO: Match data on edges + if (((pDir && tDir && pOutEdge == tOutEdge) || (!pDir && !tDir)) && (pType == tType)) { val newVp = m.pattern.edgeGetOtherVertex(ep, vp) val newVt = m.target.edgeGetOtherVertex(et, vt) @@ -349,14 +368,14 @@ case class MatchState( if (m.map.v contains (newVp -> newVt)) Some(copy(psNodes = psNodes + newVp, m = m.addEdge(ep -> et))) else None - } else if (tVerts contains newVt) { + } else if (targetVertices contains newVt) { (m.pattern.vdata(newVp), m.target.vdata(newVt)) match { case (_: WireV, _: WireV) if reflectsBBoxes(newVp, newVt) && matchIsMonotone(newVp, newVt) => (m.pattern.wireVertexGetOtherEdge(newVp, ep), m.target.wireVertexGetOtherEdge(newVt, et)) match { case (Some(newEp), Some(newEt)) => copy( m = m.addEdge(ep -> et, newVp -> newVt), - tVerts = tVerts - newVt + targetVertices = targetVertices - newVt ).matchNewWire(newVp, newEp, newVt, newEt) case (Some(_), None) => None case (None, _) => @@ -376,14 +395,13 @@ case class MatchState( } object MatchState { + // use !-box orbits to ignore redundant matches + var smashSymmetries = true // for testing e.g. laziness in a single thread private var matchCounter = 0 private var countMatches = false - // use !-box orbits to ignore redundant matches - var smashSymmetries = true - - def startCountingMatches() = { + def startCountingMatches(): Unit = { matchCounter = 0 countMatches = true } diff --git a/scala/src/main/scala/quanto/rewrite/Matcher.scala b/scala/src/main/scala/quanto/rewrite/Matcher.scala index f1e462c0..80b0dbeb 100644 --- a/scala/src/main/scala/quanto/rewrite/Matcher.scala +++ b/scala/src/main/scala/quanto/rewrite/Matcher.scala @@ -1,20 +1,9 @@ package quanto.rewrite -import quanto.data._ -import scala.annotation.tailrec +import quanto.data._ object Matcher { - private def matchMain(ms: MatchState): Stream[Match] = - ms.nextMatch() match { - case Some((m1,Some(next))) => m1 #:: matchMain(next) - case Some((m1,None)) => Stream(m1) - case None => Stream() - } - def initialise(pat: Graph, tgt: Graph, restrictTo: Set[VName]): MatchState = { - // TODO: new free vars should be fresh w.r.t. vars in target - val patVars = pat.freeVars.toVector - val tgtVars = tgt.freeVars.toVector val patN = pat.normalise val tgtN = tgt.normalise val restrict0 = restrictTo intersect tgtN.verts @@ -26,8 +15,8 @@ object Matcher { MatchState( m = Match(pattern0 = patN, pattern = patN, target = tgtN), - tVerts = restrict1, - angleMatcher = AngleExpressionMatcher(patVars,tgtVars)) + targetVertices = restrict1, + expressionMatcher = CompositeExpressionMatcher()) // Create the matcher empty, it will fill itself in in time } def findMatches(pat: Graph, tgt: Graph, restrictTo: Set[VName]): Stream[Match] = { @@ -37,4 +26,11 @@ object Matcher { def findMatches(pat: Graph, tgt: Graph): Stream[Match] = findMatches(pat, tgt, tgt.verts) + private def matchMain(ms: MatchState): Stream[Match] = + ms.nextMatch() match { + case Some((m1, Some(next))) => m1 #:: matchMain(next) + case Some((m1, None)) => Stream(m1) + case None => Stream() + } + } diff --git a/scala/src/main/scala/quanto/rewrite/Rewriter.scala b/scala/src/main/scala/quanto/rewrite/Rewriter.scala index 3331ea60..ac862830 100644 --- a/scala/src/main/scala/quanto/rewrite/Rewriter.scala +++ b/scala/src/main/scala/quanto/rewrite/Rewriter.scala @@ -1,32 +1,10 @@ package quanto.rewrite -import quanto.data._ import quanto.data.Names._ +import quanto.data._ object Rewriter { - def expandRhs(m: Match, rhs: Graph): Graph = { - // ensure that *all* boundary names used in expanding bbops are avoided - val fullBoundary = m.bbops.foldRight(m.pattern0.boundary) { (bbop, vs) => - bbop match { - case BBExpand(_, mp) => vs union mp.v.directImage(vs) - case BBCopy(_, mp) => vs union mp.v.directImage(vs) - case _ => vs - } - } - - val rhs1 = m.bbops.foldRight(rhs) { (bbop, g) => g.applyBBOp(bbop, fullBoundary) } - - val vdata = rhs1.vdata.mapValues { - case d: NodeV => - val data = d.data.setPath("$.value", d.angle.subst(m.subst).toString).asObject - d.copy(data = data) - case d: WireV => d - } - - rhs1.copy(vdata = vdata) - } - - def rewrite(m: Match, rhs: Graph, desc: RuleDesc = RuleDesc()): (Graph, Rule) = { + def rewrite(m: Match, rhs: Graph, desc: RuleDesc = RuleDesc()): (Graph, Rule) = { // expand bare wires in the match val m1 = m.normalize @@ -44,23 +22,48 @@ object Rewriter { .deleteEdges(m1.map.e.codSet) .deleteVertices(m1.map.v.directImage(interiorLhs)) - val vmap = interiorRhs.foldRight(m1.map.v.restrictDom(boundary)) { (v, mp) => + val vertexMap = interiorRhs.foldRight(m1.map.v.restrictDom(boundary)) { (v, mp) => mp + (v -> (context.verts union mp.codSet).freshWithSuggestion(v)) } - val emap = rhsE.edges.foldRight(PFun[EName,EName]()) { (e, mp) => + val edgeMap = rhsE.edges.foldRight(PFun[EName, EName]()) { (e, mp) => mp + (e -> (context.edges union mp.codSet).freshWithSuggestion(e)) } // quotient the lhs and rhs such that pairs of boundaries mapped to the same vertex are identified val quotientLhs = m1.pattern.rename(m1.map.v.toMap, m1.map.e.toMap, m1.map.bb.toMap) - val quotientRhs = rhsE.rename(vmap.toMap, emap.toMap, m1.map.bb.toMap) + val quotientRhs = rhsE.rename(vertexMap.toMap, edgeMap.toMap, m1.map.bb.toMap) val ruleInst = if (desc.inverse) Rule(quotientRhs, quotientLhs, description = desc) - else Rule(quotientLhs, quotientRhs, description = desc) + else Rule(quotientLhs, quotientRhs, description = desc) // compute the pushout as a union of the context with the quotiented domain of the matching (quotientRhs.appendGraph(context), ruleInst) } + + def expandRhs(m: Match, rhs: Graph): Graph = { + // ensure that *all* boundary names used in expanding bbops are avoided + val fullBoundary = m.bbops.foldRight(m.pattern0.boundary) { (bbop, vs) => + bbop match { + case BBExpand(_, mp) => vs union mp.v.directImage(vs) + case BBCopy(_, mp) => vs union mp.v.directImage(vs) + case _ => vs + } + } + + val rhs1 = m.bbops.foldRight(rhs) { (bbop, g) => g.applyBBOp(bbop, fullBoundary) } + + // Apply all substitutions from our matches + val vdata = rhs1.vdata.mapValues { + case d: NodeV => + val data = d.data.setPath( + "$.value", m.subst.foldLeft(d.phaseData)((data, t) => data.substSubValues(t._2)).toString + ).asObject + d.copy(data = data) + case d: WireV => d + } + + rhs1.copy(vdata = vdata) + } } \ No newline at end of file diff --git a/scala/src/main/scala/quanto/rewrite/Simproc.scala b/scala/src/main/scala/quanto/rewrite/Simproc.scala index 0722e45e..849256cd 100644 --- a/scala/src/main/scala/quanto/rewrite/Simproc.scala +++ b/scala/src/main/scala/quanto/rewrite/Simproc.scala @@ -1,20 +1,31 @@ package quanto.rewrite +import quanto.cosy.{AutoReduce, RuleSynthesis} +import quanto.data.Derivation.DerivationWithHead import quanto.data._ import quanto.layout.ForceLayout +import quanto.util.UserAlerts + +import scala.util.Random abstract class Simproc { + var sourceFile: String = "" + var sourceCode: String = "" + def simp(g: Graph): Iterator[(Graph, Rule)] + // jython binding for >> + def __rshift__(t: Simproc): Simproc = this >> t + // chain two simprocs together def >>(t: Simproc) = { val s = this new Simproc { - override def simp(g: Graph): Iterator[(Graph, Rule)] = new Iterator[(Graph,Rule)] { - var iterS: Iterator[(Graph,Rule)] = s.simp(g) - var iterT: Iterator[(Graph,Rule)] = null - var lastGraphS = g + override def simp(g: Graph): Iterator[(Graph, Rule)] = new Iterator[(Graph, Rule)] { + var iterS: Iterator[(Graph, Rule)] = s.simp(g) + var iterT: Iterator[(Graph, Rule)] = _ + var lastGraphS: Graph = g override def hasNext: Boolean = if (iterT != null) iterT.hasNext @@ -28,9 +39,9 @@ abstract class Simproc { if (iterT != null) { iterT.next() } else if (iterS.hasNext) { - val (g1,r1) = iterS.next() + val (g1, r1) = iterS.next() lastGraphS = g1 - (g1,r1) + (g1, r1) } else { iterT = t.simp(lastGraphS) iterT.next() @@ -38,12 +49,66 @@ abstract class Simproc { } } } - - // jython binding for >> - def __rshift__(t: Simproc): Simproc = this >> t } object Simproc { + + // Converts a (Derivation, Head) pair into an iterated series of steps + // Allows gluing together of simprocs and derivations + implicit def fromDerivationWithHead(d: DerivationWithHead): Iterator[(Graph, Rule)] = { + if (d._2.nonEmpty) { + d._1.stepsTo(d._2.get).map(d._1.steps).map(step => (step.graph, step.rule)).toIterator + } else { + Iterator.empty + } + } + + /** + * Anneals the graph using only the rules (forwards only), using vertex size as the metric + * No initial heat specified, just accepts worse states with a (decreasing-over-time) random chance + * + * @param rules List of rules, taken forwards only + * @param steps Number of steps to be taken + * @param dilation How slowly we stop accepting worse states + * @return The resulting derivation will appear (all at once) in the side bar + */ + def ANNEAL(rules: List[Rule], + steps: Int, + dilation: Double, + seed: Random = new Random(), + vertexLimit: Option[Int] = None) = new Simproc { + override def simp(g: Graph): Iterator[(Graph, Rule)] = { + val reduced = AutoReduce.annealingReduce( + RuleSynthesis.basicGraphComparison, + RuleSynthesis.graphToDerivation(g), + rules, + steps, + dilation, + seed, + vertexLimit) + fromDerivationWithHead(reduced) + } + } + + def LOG(graphEval: Graph => String) = new Simproc { + override def simp(g: Graph): Iterator[(Graph, Rule)] = { + UserAlerts.alert(graphEval(g), UserAlerts.Elevation.NOTICE) + Iterator.empty + } + } + + // takes a list of rules and rewrites w.r.t. the first that gets a match + def REWRITE(rules: List[Rule]) = new Simproc { + override def simp(g: Graph): Iterator[(Graph, Rule)] = { + for (rule <- rules) + Matcher.findMatches(rule.lhs, g).headOption.foreach { m => + return Iterator.single(layout(Rewriter.rewrite(m, rule.rhs, rule.description))) + } + //println("got no match REWRITE: " + rules.map{_.name}.toString()) + Iterator.empty + } + } + private def layout(gr: (Graph, Rule)) = { val (graph, rule) = gr val layoutProc = new ForceLayout @@ -55,54 +120,71 @@ object Simproc { layoutProc.maxIterations = 300 //layoutProc.keepCentered = false - val rhsi = rule.rhs.verts.filter(!rule.rhs.isBoundary(_)) + val rhsInterior = rule.rhs.verts.filter(!rule.rhs.isTerminalWire(_)) //println(rhsi) - graph.verts.foreach { v => if (!rhsi.contains(v)) layoutProc.lockVertex(v) } + graph.verts.foreach { v => if (!rhsInterior.contains(v)) layoutProc.lockVertex(v) } //graph.verts.foreach { v => if (graph.isBoundary(v)) layoutProc.lockVertex(v) } (layoutProc.layout(graph, randomCoords = false).snapToGrid(), rule) //(graph, rule) } - object EMPTY extends Simproc { override def simp(g: Graph): Iterator[(Graph,Rule)] = Iterator.empty } - - // takes a list of rules and rewrites w.r.t. the first that gets a match - def REWRITE(rules: List[Rule]) = new Simproc { - override def simp(g: Graph): Iterator[(Graph, Rule)] = { - for (rule <- rules) - Matcher.findMatches(rule.lhs, g).headOption.foreach { m => - return Iterator.single(layout(Rewriter.rewrite(m, rule.rhs, rule.description))) - } - Iterator.empty - } - } - - def REWRITE_TARGETED(rule: Rule, vp: VName, targ: Graph => Option[VName]) = new Simproc { + // Applies rewrite rules, but only if the rule affects the targeted vertex + def REWRITE_TARGETED(rule: Rule, vertexInPattern: VName, target: Graph => Option[VName]) = new Simproc { override def simp(g: Graph): Iterator[(Graph, Rule)] = { - targ(g).flatMap { vt => + target(g).flatMap { vt => //println("REWRITE_TARGETED(" + rule.name + ", " + vt + ")") if (g.verts contains vt) { val ms = Matcher.initialise(rule.lhs, g, g.verts) - ms.matchNewNode(vp, vt).flatMap(_.nextMatch()) + ms.matchNewNode(vertexInPattern, vt).flatMap(_.nextMatch()) } else None } match { - case Some((m,_)) => + case Some((m, _)) => //println("SUCCESS") Iterator.single(layout(Rewriter.rewrite(m, rule.rhs, rule.description))) case None => + //println("got no match REWRITE_TARGETED: " + rule.name) //println("FAILED") Iterator.empty } } } - def REWRITE_METRIC(rules: List[Rule], metric: Graph => Int) = + def REWRITE_TARGET_LIST(rule: Rule, vp: VName, target: List[VName]): Simproc = new Simproc { + override def simp(g: Graph): Iterator[(Graph, Rule)] = { + for (vt <- target) { + if (g.verts contains vt) { + val ms = Matcher.initialise(rule.lhs, g, g.verts) + ms.matchNewNode(vp, vt).flatMap(_.nextMatch()).map { case (m, _) => + return Iterator.single(layout(Rewriter.rewrite(m, rule.rhs, rule.description))) + } + } + } + Iterator.empty + // targ(g).flatMap { vt => + // //println("REWRITE_TARGETED(" + rule.name + ", " + vt + ")") + // if (g.verts contains vt) { + // val ms = Matcher.initialise(rule.lhs, g, g.verts) + // ms.matchNewNode(vp, vt).flatMap(_.nextMatch()) + // } else None + // } match { + // case Some((m,_)) => + // //println("SUCCESS") + // Iterator.single(layout(Rewriter.rewrite(m, rule.rhs, rule.description))) + // case None => + // //println("FAILED") + // Iterator.empty + // } + } + } + + def REWRITE_METRIC(rules: List[Rule], metric: Graph => Int, target: Int = 0) = new Simproc { override def simp(g: Graph): Iterator[(Graph, Rule)] = { - if (metric(g) <= 0) return Iterator.empty + if (metric(g) <= target) return Iterator.empty for (rule <- rules) { Matcher.findMatches(rule.lhs, g).foreach { m => - val (g1,r1) = Rewriter.rewrite(m, rule.rhs, rule.description) - if (metric(g1) < metric(g)) return Iterator.single(layout((g1,r1))) + val (g1, r1) = Rewriter.rewrite(m, rule.rhs, rule.description) + if (metric(g1) < metric(g)) return Iterator.single(layout((g1, r1))) } } Iterator.empty @@ -115,43 +197,45 @@ object Simproc { if (metric(g) <= 0) return Iterator.empty for (rule <- rules) { Matcher.findMatches(rule.lhs, g).foreach { m => - val (g1,r1) = Rewriter.rewrite(m, rule.rhs, rule.description) - if (metric(g1) <= metric(g)) return Iterator.single(layout((g1,r1))) + val (g1, r1) = Rewriter.rewrite(m, rule.rhs, rule.description) + if (metric(g1) <= metric(g)) return Iterator.single(layout((g1, r1))) } } Iterator.empty } } - def REPEAT(s: Simproc): Simproc = new Simproc { - override def simp(g: Graph): Iterator[(Graph, Rule)] = new Iterator[(Graph, Rule)] { - var iterS: Iterator[(Graph,Rule)] = s.simp(g) - var lastGraphS: Graph = g - - override def hasNext: Boolean = - if (iterS.hasNext) true - else { - val iterT = s.simp(lastGraphS) - if (iterT.hasNext) { - iterS = iterT - true - } else false - } + def REPEAT(s: Simproc): Simproc = (g: Graph) => new Iterator[(Graph, Rule)] { + var iterS: Iterator[(Graph, Rule)] = s.simp(g) + var lastGraphS: Graph = g + + override def hasNext: Boolean = + if (iterS.hasNext) true + else { + val iterT = s.simp(lastGraphS) + if (iterT.hasNext) { + iterS = iterT + true + } else false + } - override def next(): (Graph,Rule) = - if (iterS.hasNext) { - val (g1,r1) = iterS.next() + override def next(): (Graph, Rule) = + if (iterS.hasNext) { + val (g1, r1) = iterS.next() + lastGraphS = g1 + (g1, r1) + } else { + val iterT = s.simp(lastGraphS) + if (iterT.hasNext) { + iterS = iterT + val (g1, r1) = iterS.next() lastGraphS = g1 - (g1,r1) - } else { - val iterT = s.simp(lastGraphS) - if (iterT.hasNext) { - iterS = iterT - val (g1,r1) = iterS.next() - lastGraphS = g1 - (g1,r1) - } else null - } - } + (g1, r1) + } else null + } + } + + object EMPTY extends Simproc { + override def simp(g: Graph): Iterator[(Graph, Rule)] = Iterator.empty } } \ No newline at end of file diff --git a/scala/src/main/scala/quanto/rewrite/rewriting.sc b/scala/src/main/scala/quanto/rewrite/rewriting.sc deleted file mode 100644 index 81308318..00000000 --- a/scala/src/main/scala/quanto/rewrite/rewriting.sc +++ /dev/null @@ -1,12 +0,0 @@ -import quanto.data._ -import quanto.rewrite._ -val e1 = AngleExpression.parse("x") -val f1 = AngleExpression.parse("2 a + b") -val p = Vector("x", "y", "z") -val t = Vector("a", "b", "c") -var m = AngleExpressionMatcher(p, t) -println(m.mat) -//m = m.mtch(e1, f1).get -//m.toMap - - diff --git a/scala/src/main/scala/quanto/util/DirectoryWatcher.scala b/scala/src/main/scala/quanto/util/DirectoryWatcher.scala index 36fd3ac8..7c07abd7 100644 --- a/scala/src/main/scala/quanto/util/DirectoryWatcher.scala +++ b/scala/src/main/scala/quanto/util/DirectoryWatcher.scala @@ -3,19 +3,17 @@ package quanto.util import java.io._ sealed abstract class FileTree + case class DirNode(name: String, files: Set[FileTree]) extends FileTree + case class FileNode(name: String) extends FileTree class DirectoryWatcher(val dir: String, onChange: FileTree => Any) extends Thread { var fileTree: FileTree = FileNode("") var poll = true - private def buildFileTree(f: File): FileTree = { - if (f.isDirectory) DirNode(f.getName, f.listFiles.toSet.map(buildFileTree)) - else FileNode(f.getName) - } override def run() { - while(poll) { + while (poll) { val ft = buildFileTree(new File(dir)) if (ft != fileTree) { fileTree = ft @@ -27,7 +25,14 @@ class DirectoryWatcher(val dir: String, onChange: FileTree => Any) extends Threa } } - def stopPolling() { poll = false } + def stopPolling() { + poll = false + } + + private def buildFileTree(f: File): FileTree = { + if (f.isDirectory) DirNode(f.getName, f.listFiles.toSet.map(buildFileTree)) + else FileNode(f.getName) + } } object DirectoryWatcher { diff --git a/scala/src/main/scala/quanto/util/FileHelper.scala b/scala/src/main/scala/quanto/util/FileHelper.scala index dec592ea..8a1010eb 100644 --- a/scala/src/main/scala/quanto/util/FileHelper.scala +++ b/scala/src/main/scala/quanto/util/FileHelper.scala @@ -1,12 +1,34 @@ package quanto.util import java.io.File +import java.net.URI import quanto.util.json.Json -import scala.util.matching.Regex +import scala.io.Source + object FileHelper { + + implicit def uriToFile(uri: URI): File = new File(uri) + + implicit def fileToURI(file: File): URI = file.toURI + + implicit def pathToFile(path: String): File = new File(path) + + val Home: URI = { + val uri = System.getProperty("user.home").toURI + if (!uri.exists) throw new IllegalStateException("Couldn't access dir: " + uri) + uri + } + + def printToFile(file_name: File, string: String, append: Boolean) { + printToFile(file_name, append) { p => { + p.println(string) + } + } + } + /** * Helper method to print to a file. * @@ -16,7 +38,7 @@ object FileHelper { */ def printToFile(file_name: File, append: Boolean = true) (op: java.io.PrintWriter => Unit) { - val p = new java.io.PrintWriter(new java.io.FileWriter(file_name, append)) + val p = new java.io.PrintWriter(new java.io.FileWriter(ensureParentFolderExists(file_name), append)) try { op(p) } finally { @@ -24,8 +46,43 @@ object FileHelper { } } + def ensureParentFolderExists(file: File): File = { + ensureFolderExists(file.getParentFile) + file + } + + def ensureFolderExists(file: File): File = { + if (!file.exists && !file.mkdirs) throw new IllegalStateException("Couldn't create dir: " + file) + file + } + + def printJson(fileName: String, json: Json): Unit = { + val targetFile = new File(fileName) + val parent = targetFile.getParentFile + if (!parent.exists && !parent.mkdirs) throw new IllegalStateException("Couldn't create dir: " + parent) + json.writeTo(new File(fileName)) + } + def readFile[T](file: File, conversion: Json => T): T = conversion(Json.parse(file)) + def readJson(file: File): Json = Json.parse(file) + + def readFile(file: File): List[String] = { + val bufferedSource = Source.fromFile(file) + val lines = bufferedSource.getLines().toList + bufferedSource.close + lines + } + + def extension(file: File): String = { + val pattern = """.*\.(\w+)""".r + file.getAbsolutePath match { + case pattern(extension) => extension + case _ => "" + } + } + + def readAllOfType[T](directory: String, regexFilter: String, conversion: Json => T): List[T] = { readJSONFromDirectory(directory, regexFilter).map(conversion) } diff --git a/scala/src/main/scala/quanto/util/Geometry.scala b/scala/src/main/scala/quanto/util/Geometry.scala index 849f6e9b..51acbd94 100644 --- a/scala/src/main/scala/quanto/util/Geometry.scala +++ b/scala/src/main/scala/quanto/util/Geometry.scala @@ -1,55 +1,75 @@ package quanto.util -import math.{min,max} +import scala.math.{max, min} class RichPt(val p: Geometry.Pt) { - def +(p1: Geometry.Pt) = (p._1 + p1._1, p._2 + p1._2) - def -(p1: Geometry.Pt) = (p._1 - p1._1, p._2 - p1._2) - def unary_-(p1: Geometry.Pt) = (-p._1, -p._2) + def +(p1: Geometry.Pt): (Double, Double) = (p._1 + p1._1, p._2 + p1._2) + + def -(p1: Geometry.Pt): (Double, Double) = (p._1 - p1._1, p._2 - p1._2) + + def unary_-(p1: Geometry.Pt): (Double, Double) = (-p._1, -p._2) } object RichPt { - implicit def richPtToPt(rp: RichPt) = rp.p - implicit def ptToRichPt(p: Geometry.Pt) = new RichPt(p) + implicit def richPtToPt(rp: RichPt): (Double, Double) = rp.p + + implicit def ptToRichPt(p: Geometry.Pt): RichPt = new RichPt(p) } class RichRect(val r: Geometry.Rect) { - def pad(padding: Double) = - r match { case (((lx,ly),(ux,uy))) => - ((lx - padding, ly - padding), (ux + padding, uy + padding)) + def pad(padding: Double): ((Double, Double), (Double, Double)) = + r match { + case (((lx, ly), (ux, uy))) => + ((lx - padding, ly - padding), (ux + padding, uy + padding)) } - def coords = (r._1,r._2) - def center: Geometry.Pt = r match { case (((lx,ly),(ux,uy))) => ((lx+ux)/2.0,(ly+uy)/2.0) } - def width: Double = r match { case ((lx,ux),_) => ux - lx } - def height: Double = r match { case (_,(ly,uy)) => uy - ly } - def size = (width, height) + + def coords: ((Double, Double), (Double, Double)) = (r._1, r._2) + + def center: Geometry.Pt = r match { + case (((lx, ly), (ux, uy))) => ((lx + ux) / 2.0, (ly + uy) / 2.0) + } + + def size: (Double, Double) = (width, height) + + def width: Double = r match { + case ((lx, ux), _) => ux - lx + } + + def height: Double = r match { + case (_, (ly, uy)) => uy - ly + } } object RichRect { - implicit def richRectToRect(rr: RichRect) = rr.r - implicit def rectToRichRect(r: Geometry.Rect) = new RichRect(r) + implicit def richRectToRect(rr: RichRect): ((Double, Double), (Double, Double)) = rr.r + + implicit def rectToRichRect(r: Geometry.Rect): RichRect = new RichRect(r) } object Geometry { - type Pt = (Double,Double) - type Rect = (Pt,Pt) - def bounds(ps: Iterable[Pt]): Option[(Pt,Pt)] = { + type Pt = (Double, Double) + type Rect = (Pt, Pt) + + def bounds(ps: Iterable[Pt]): Option[(Pt, Pt)] = { val it = ps.iterator - + if (it.hasNext) { - var upper = it.next() - var lower = upper + var upper = it.next() + var lower = upper for (p <- it) { - lower = (min(lower._1,p._1),min(lower._2,p._2)) - upper = (max(upper._1,p._1),max(upper._2,p._2)) + lower = (min(lower._1, p._1), min(lower._2, p._2)) + upper = (max(upper._1, p._1), max(upper._2, p._2)) } - Some(lower,upper) + Some(lower, upper) } else None } // implicit conversions for various geometric classes implicit def ptToPoint(p: Pt): java.awt.Point = new java.awt.Point(p._1.toInt, p._2.toInt) + implicit def ptToPoint2D(p: Pt): java.awt.geom.Point2D = new java.awt.geom.Point2D.Double(p._1, p._2) - implicit def pointToPt(p: java.awt.Point): Pt = (p.getX,p.getY) - implicit def point2DToPt(p: java.awt.geom.Point2D): Pt = (p.getX,p.getY) + + implicit def pointToPt(p: java.awt.Point): Pt = (p.getX, p.getY) + + implicit def point2DToPt(p: java.awt.geom.Point2D): Pt = (p.getX, p.getY) } diff --git a/scala/src/main/scala/quanto/util/Globals.scala b/scala/src/main/scala/quanto/util/Globals.scala index 54872e78..3207b3e1 100644 --- a/scala/src/main/scala/quanto/util/Globals.scala +++ b/scala/src/main/scala/quanto/util/Globals.scala @@ -1,14 +1,19 @@ package quanto.util -import java.io.File import java.awt.event.InputEvent +import java.io.File object Globals { - def CommandMask = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask - def CommandDownMask = if (CommandMask == InputEvent.META_MASK) InputEvent.META_DOWN_MASK - else InputEvent.CTRL_DOWN_MASK + def CommandDownMask: Int = if (CommandMask == InputEvent.META_MASK) InputEvent.META_DOWN_MASK + else InputEvent.CTRL_DOWN_MASK + + def CommandMask: Int = java.awt.Toolkit.getDefaultToolkit.getMenuShortcutKeyMask + + def isBundle: Boolean = isMacBundle || isWindowsBundle || isLinuxBundle + def isMacBundle: Boolean = new File("osx-bundle").exists + def isLinuxBundle: Boolean = new File("linux-bundle").exists + def isWindowsBundle: Boolean = new File("windows-bundle").exists - def isBundle: Boolean = isMacBundle || isWindowsBundle || isLinuxBundle } diff --git a/scala/src/main/scala/quanto/util/QuadTree.scala b/scala/src/main/scala/quanto/util/QuadTree.scala index d1889a07..8a64dd6a 100644 --- a/scala/src/main/scala/quanto/util/QuadTree.scala +++ b/scala/src/main/scala/quanto/util/QuadTree.scala @@ -1,32 +1,34 @@ package quanto.util -import math.{abs,min,max} +import scala.math.{abs, max, min} /** - * A quadtree for 2D spacial queries - */ + * A quadtree for 2D spacial queries + */ sealed abstract class QuadTree[A] { def x1: Double + def y1: Double + def x2: Double + def y2: Double def value: Option[A] - def p: (Double,Double) - def insert(p: (Double,Double), v: A): QuadTree[A] - // visit each node/leaf with 'f', recursing until f returns true or a leaf is encountered - def visit(f : QuadTree[A] => Boolean) + def p: (Double, Double) - private def intervalOverlap(s1: Double, t1: Double, s2: Double, t2: Double) = - if (s1 <= s2) t1 >= s2 else t2 >= s1 + def insert(p: (Double, Double), v: A): QuadTree[A] + + // visit each node/leaf with 'f', recursing until f returns true or a leaf is encountered + def visit(f: QuadTree[A] => Boolean) // query the tree, returning all values in regions with the given bound def query(xLower: Double, yLower: Double, xUpper: Double, yUpper: Double): Iterable[A] = { val result = collection.mutable.ListBuffer[A]() visit { tr => - if (intervalOverlap(tr.x1,tr.x2,xLower,xUpper) && intervalOverlap(tr.y1,tr.y2,yLower,yUpper)) { + if (intervalOverlap(tr.x1, tr.x2, xLower, xUpper) && intervalOverlap(tr.y1, tr.y2, yLower, yUpper)) { tr.value.map(v => result += v) false // recurse } else { @@ -35,57 +37,59 @@ sealed abstract class QuadTree[A] { } result } + + private def intervalOverlap(s1: Double, t1: Double, s2: Double, t2: Double) = + if (s1 <= s2) t1 >= s2 else t2 >= s1 } object QuadTree { - def apply[A](items: Iterable[((Double,Double),A)]): QuadTree[A] = { + def apply[A](items: Iterable[((Double, Double), A)]): QuadTree[A] = { items.headOption match { - case Some(((hx,hy),_)) => + case Some(((hx, hy), _)) => // compute bounding box - val (x1,y1,x2,y2) = items.tail.foldLeft(hx,hy,hx,hy) { - case ((minX,minY,maxX,maxY),((x,y),_)) => (min(minX,x),min(minY,y),max(maxX,x),max(maxY,y)) + val (x1, y1, x2, y2) = items.tail.foldLeft(hx, hy, hx, hy) { + case ((minX, minY, maxX, maxY), ((x, y), _)) => (min(minX, x), min(minY, y), max(maxX, x), max(maxY, y)) } // make bounds square val sz = max(x2 - x1, y2 - y1) - val emptyTree: QuadTree[A] = QuadNode[A](x1,y1,x1+sz,y1+sz) + val emptyTree: QuadTree[A] = QuadNode[A](x1, y1, x1 + sz, y1 + sz) - items.foldLeft(emptyTree) { case (tr,(p,v)) => tr.insert(p,v) } - case None => QuadNode[A](0,0,0,0) + items.foldLeft(emptyTree) { case (tr, (p, v)) => tr.insert(p, v) } + case None => QuadNode[A](0, 0, 0, 0) } } } case class QuadNode[A]( - x1: Double, y1: Double, - x2: Double, y2: Double, - - value: Option[A], - p: (Double,Double) = (0,0), - nw: QuadTree[A], - ne: QuadTree[A], - sw: QuadTree[A], - se: QuadTree[A]) extends QuadTree[A] -{ - lazy val children = Seq(nw,ne,sw,se) - val midX = (x1 + x2) / 2 - val midY = (y1 + y2) / 2 - - def insert(p: (Double,Double), v: A) = p match { - case (x,y) if x < midX && y < midY => copy(nw = nw.insert(p,v)) - case (x,y) if x >= midX && y < midY => copy(ne = ne.insert(p,v)) - case (x,y) if x < midX && y >= midY => copy(sw = sw.insert(p,v)) - case (x,y) if x >= midX && y >= midY => copy(se = se.insert(p,v)) + x1: Double, y1: Double, + x2: Double, y2: Double, + + value: Option[A], + p: (Double, Double) = (0, 0), + nw: QuadTree[A], + ne: QuadTree[A], + sw: QuadTree[A], + se: QuadTree[A]) extends QuadTree[A] { + lazy val children = Seq(nw, ne, sw, se) + val midX: Double = (x1 + x2) / 2 + val midY: Double = (y1 + y2) / 2 + + def insert(p: (Double, Double), v: A): QuadNode[A] = p match { + case (x, y) if x < midX && y < midY => copy(nw = nw.insert(p, v)) + case (x, y) if x >= midX && y < midY => copy(ne = ne.insert(p, v)) + case (x, y) if x < midX && y >= midY => copy(sw = sw.insert(p, v)) + case (x, y) if x >= midX && y >= midY => copy(se = se.insert(p, v)) } - def visit(f : QuadTree[A] => Boolean) { - if (!f(this)) { - nw.visit(f) - ne.visit(f) - sw.visit(f) - se.visit(f) - } + def visit(f: QuadTree[A] => Boolean) { + if (!f(this)) { + nw.visit(f) + ne.visit(f) + sw.visit(f) + se.visit(f) + } } //def mapValue[B](f: Option[A] => Option[B]) = QuadNode[B](x1,y1,x2,y2,f(value),p,nw,ne,sw,se) @@ -95,41 +99,43 @@ object QuadNode { // a new empty quad node with given bounds def apply[A](x1: Double, y1: Double, x2: Double, y2: Double): QuadNode[A] = { - QuadNode[A](x1,y1,x2,y2,None,(0.0,0.0)) + QuadNode[A](x1, y1, x2, y2, None, (0.0, 0.0)) } // a new node with a value, point, and the given bounds - def apply[A](x1: Double, y1: Double, x2: Double, y2: Double, value: Option[A], p: (Double,Double)): QuadNode[A] = { + def apply[A](x1: Double, y1: Double, x2: Double, y2: Double, value: Option[A], p: (Double, Double)): QuadNode[A] = { val midX = (x1 + x2) / 2 val midY = (y1 + y2) / 2 - val nw = QuadLeaf[A](x1,y1,midX,midY) - val ne = QuadLeaf[A](midX,y1,x2,midY) - val sw = QuadLeaf[A](x1,midY,midX,y2) - val se = QuadLeaf[A](midX,midY,x2,y2) - QuadNode[A](x1,y1,x2,y2,value,p,nw,ne,sw,se) + val nw = QuadLeaf[A](x1, y1, midX, midY) + val ne = QuadLeaf[A](midX, y1, x2, midY) + val sw = QuadLeaf[A](x1, midY, midX, y2) + val se = QuadLeaf[A](midX, midY, x2, y2) + QuadNode[A](x1, y1, x2, y2, value, p, nw, ne, sw, se) } } case class QuadLeaf[A]( - x1: Double, y1: Double, - x2: Double, y2: Double, - value: Option[A] = None, - p: (Double,Double) = (0,0)) extends QuadTree[A] -{ - def insert(p1: (Double,Double), v1: A) = + x1: Double, y1: Double, + x2: Double, y2: Double, + value: Option[A] = None, + p: (Double, Double) = (0, 0)) extends QuadTree[A] { + def insert(p1: (Double, Double), v1: A): QuadTree[A] = value match { case None => copy(p = p1, value = Some(v1)) case Some(v) => if (abs(p._1 - p1._1) < 0.01 && abs(p._2 - p1._2) < 0.01) { // if they are very close, place the current value on the node, and add the new value as a leaf - QuadNode[A](x1,y1,x2,y2,Some(v),p).insert(p1,v1) + QuadNode[A](x1, y1, x2, y2, Some(v), p).insert(p1, v1) } else { // otherwise place them on their own leaves - QuadNode[A](x1,y1,x2,y2).insert(p,v).insert(p1,v1) + QuadNode[A](x1, y1, x2, y2).insert(p, v).insert(p1, v1) } } - def visit(f : QuadTree[A] => Boolean) { f(this) } - def mapValue[B](f: A => B): QuadLeaf[B] = QuadLeaf[B](x1,y1,x2,y2,value.map(f(_)),p) + def visit(f: QuadTree[A] => Boolean) { + f(this) + } + + def mapValue[B](f: A => B): QuadLeaf[B] = QuadLeaf[B](x1, y1, x2, y2, value.map(f(_)), p) } diff --git a/scala/src/main/scala/quanto/util/Rational.scala b/scala/src/main/scala/quanto/util/Rational.scala index 88075b96..558ceb78 100644 --- a/scala/src/main/scala/quanto/util/Rational.scala +++ b/scala/src/main/scala/quanto/util/Rational.scala @@ -3,45 +3,60 @@ package quanto.util class RationalDivideByZeroException(r: Rational) extends Exception("Attempted to divide by 0 in (" + r.n + "/" + r.d + ")") -class Rational(numerator : Int, denominator : Int) extends Ordered[Rational] { - if (denominator == 0) throw new RationalDivideByZeroException(this) - private val r = Rational.gcd(numerator,denominator) +class Rational(numerator: Int, denominator: Int) extends Ordered[Rational] { + + private val r = Rational.gcd(numerator, denominator) private val dsn = if (denominator < 0) -1 else 1 - val n: Int = dsn * numerator/r - val d: Int = dsn * denominator/r + if (denominator == 0) throw new RationalDivideByZeroException(this) + val n: Int = dsn * numerator / r + val d: Int = dsn * denominator / r + + def +(r: Rational) = Rational(n * r.d + r.n * d, d * r.d) + + def -(r: Rational) = Rational(n * r.d - r.n * d, d * r.d) - def +(r : Rational) = Rational(n * r.d + r.n * d, d * r.d) - def -(r : Rational) = Rational(n * r.d - r.n * d, d * r.d) - def *(r : Rational) = Rational(n * r.n, d * r.d) - def *(i : Int) = Rational(n * i, d) - def /(r : Rational) = Rational(n * r.d, d * r.n) - def mod(i : Int): Rational = + def *(r: Rational) = Rational(n * r.n, d * r.d) + + def *(i: Int) = Rational(n * i, d) + + def /(r: Rational) = Rational(n * r.d, d * r.n) + + def mod(i: Int): Rational = if (n < 0) Rational((n % (d * i)) + (d * i), d) else Rational(n % (d * i), d) - def inv = Rational(d,n) - override def equals(r : Any):Boolean = r match { - case r1 : Rational => n == r1.n && d == r1.d + def inv = Rational(d, n) + + override def equals(r: Any): Boolean = r match { + case r1: Rational => n == r1.n && d == r1.d case _ => false } - override def compare(r : Rational): Int = { n * r.d - r.n * d } + + override def compare(r: Rational): Int = { + n * r.d - r.n * d + } + def isZero: Boolean = n == 0 + def isOne: Boolean = n == 1 && d == 1 - override def toString:String = if (d == 1) n.toString else "(" + n + "/" + d + ")" + override def toString: String = if (d == 1) n.toString else "(" + n + "/" + d + ")" } object Rational { def apply(numerator: Int, denominator: Int) = new Rational(numerator, denominator) - def apply(numerator: Int) = new Rational(numerator, 1) - private def gcd(a: Int,b: Int): Int = { - if(b == 0) Math.abs(a) else gcd(b, a%b) + private def gcd(a: Int, b: Int): Int = { + if (b == 0) Math.abs(a) else gcd(b, a % b) } - implicit def intToRational(i : Int) : Rational = Rational(i) + def apply(numerator: Int) = new Rational(numerator, 1) + + + + implicit def intToRational(i: Int): Rational = Rational(i) - implicit def rationalToDouble(r: Rational) : Double = r.n.toFloat / r.d.toFloat + implicit def rationalToDouble(r: Rational): Double = r.n.toFloat / r.d.toFloat } diff --git a/scala/src/main/scala/quanto/util/RationalMatrix.scala b/scala/src/main/scala/quanto/util/RationalMatrix.scala index 120bde65..42a135d3 100644 --- a/scala/src/main/scala/quanto/util/RationalMatrix.scala +++ b/scala/src/main/scala/quanto/util/RationalMatrix.scala @@ -3,48 +3,52 @@ package quanto.util /** * * A matrix of the form (M1|M2) representing a system of equations: - * M1 v = M2 c + * M1 v = M2 c * where "v" is a vector of pattern variables and "c" a vector of target variables (treated as constants) where the * final constant "pi" is treated modulo 2. * * The variables in "t" are treated as constants when judging whether a solution exists. * - * @param mat A 2D Vector of rational numbers - * @param line The position of the line between LHS and RHS matrices, i.e. the number of free variables + * @param mat A 2D Vector of rational numbers + * @param line The position of the line between LHS and RHS matrices, i.e. the number of free variables * @param constModulo Modulus to apply to the constant or -1 for no modulus. */ -class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val constModulo : Int = 2) { +class RationalMatrix(val mat: Vector[Vector[Rational]], val line: Int, val constModulo: Option[Int] = Some(2)) { + import RationalMatrix._ - def numRows : Int = mat.length - def numCols : Int = if (mat.isEmpty) 0 else mat(0).length - override def equals(that: Any) = that match { - case m1 : RationalMatrix => + def numRows: Int = mat.length + + override def equals(that: Any): Boolean = that match { + case m1: RationalMatrix => mat == m1.mat && line == line && constModulo == constModulo case _ => false } - def apply(i : Int): Vector[Rational] = mat(i) - def rows = mat + def apply(i: Int): Vector[Rational] = mat(i) - //def insertVar = new RationalMatrix(mat.map { row => ins(row, line) }, line+1, constModulo) - //def insertConst = new RationalMatrix(mat.map { row => ins(row, row.length-1) }, line, constModulo) + def rows: Vector[Vector[Rational]] = mat - def padTo(vCols: Int, cCols: Int) = { + def padTo(vCols: Int, cCols: Int): RationalMatrix = { val m = vCols - line val n = cCols - (numCols - line - 1) if (m >= 0 && n >= 0) new RationalMatrix(mat.map { row => row.slice(0, line) ++ Vector.fill(m)(Rational(0)) ++ - row.slice(line, row.length - 1) ++ Vector.fill(n)(Rational(0)) :+ - row(row.length - 1) + row.slice(line, row.length - 1) ++ Vector.fill(n)(Rational(0)) :+ + row(row.length - 1) }, math.max(line, vCols), constModulo) else this } -// private def ins(row : Vector[Rational], i : Int) : Vector[Rational] = -// row.take(i) ++ (Rational(0) +: row.takeRight(row.length - i)) + //def insertVar = new RationalMatrix(mat.map { row => ins(row, line) }, line+1, constModulo) + //def insertConst = new RationalMatrix(mat.map { row => ins(row, row.length-1) }, line, constModulo) + + def numCols: Int = if (mat.isEmpty) 0 else mat(0).length + + // private def ins(row : Vector[Rational], i : Int) : Vector[Rational] = + // row.take(i) ++ (Rational(0) +: row.takeRight(row.length - i)) def isReduced: Boolean = { var p = -1 @@ -59,8 +63,8 @@ class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val cons // translate matrix to echelon form. returns None if there is an inconsistent row, i.e. a row of the form // (0..0|v!=0..0) - def gauss : Option[RationalMatrix] = { - val empty : Option[RationalMatrix] = Some(new RationalMatrix(Vector(), line, constModulo)) + def gauss: Option[RationalMatrix] = { + val empty: Option[RationalMatrix] = Some(new RationalMatrix(Vector(), line, constModulo)) mat.foldLeft(empty) { case (Some(m), r) => m.gaussUpdate(r) case _ => None @@ -68,7 +72,7 @@ class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val cons } // add a new row, keeping the matrix in echelon form. Returns None if new row introduces an inconsistency. - def gaussUpdate(row : Vector[Rational]) : Option[RationalMatrix] = { + def gaussUpdate(row: Vector[Rational]): Option[RationalMatrix] = { var r = row // gaussian reduce the new row @@ -89,8 +93,8 @@ class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val cons } else { // otherwise use the new row to further reduce existing rows, and insert into the correct position r = normaliseRow(r, p) - val (mat1, inserted) = mat.foldLeft((Vector[Vector[Rational]](),false)) { - case ((rows, ins),row1) => + val (mat1, inserted) = mat.foldLeft((Vector[Vector[Rational]](), false)) { + case ((rows, ins), row1) => if (findPivot(row1) > p) { if (!ins) (rows :+ r :+ row1, true) else (rows :+ row1, ins) @@ -104,22 +108,22 @@ class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val cons } // multiply row by a scalar to make pivot = 1 - private def normaliseRow(row : Vector[Rational], p : Int) = { + private def normaliseRow(row: Vector[Rational], p: Int) = { val r = row(p).inv Vector.tabulate(row.length) { i => - if (constModulo != -1 && i == row.length) (row(i) * r) mod constModulo + if (constModulo.nonEmpty && i == row.length) (row(i) * r) mod constModulo.get else row(i) * r } } // subtract a multiple of row2 (with pivot p2) from row1 to make row1(p2) == 0. Assumes row2 is // normalised. - private def reduceWith(row1 : Vector[Rational], row2 : Vector[Rational], p2 : Int) = { + private def reduceWith(row1: Vector[Rational], row2: Vector[Rational], p2: Int) = { val n = row1(p2) if (!n.isZero) { Vector.tabulate(row1.length) { i => val n1 = row1(i) - (n * row2(i)) - if (constModulo != -1 && i == row1.length) n1 mod constModulo + if (constModulo.nonEmpty && i == row1.length) n1 mod constModulo.get else n1 } } else { @@ -127,11 +131,11 @@ class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val cons } } - override def toString = mat.foldLeft("\n") { (s,row) => + override def toString: String = mat.foldLeft("\n") { (s, row) => var t = "[" for (i <- row.indices) { t += " " + row(i).toString - if (i == line-1) t += " |" + if (i == line - 1) t += " |" else t += " " } s + t + "]\n" @@ -139,7 +143,7 @@ class RationalMatrix(val mat: Vector[Vector[Rational]], val line : Int, val cons } object RationalMatrix { - def findPivot(row : Vector[Rational]) : Int = + def findPivot(row: Vector[Rational]): Int = row.indexWhere(!_.isZero) } diff --git a/scala/src/main/scala/quanto/util/RichCubicCurve.scala b/scala/src/main/scala/quanto/util/RichCubicCurve.scala index 80b23501..07446fd6 100644 --- a/scala/src/main/scala/quanto/util/RichCubicCurve.scala +++ b/scala/src/main/scala/quanto/util/RichCubicCurve.scala @@ -3,21 +3,23 @@ package quanto.util import java.awt.geom.CubicCurve2D class RichCubicCurve(curve: CubicCurve2D) { + import RichCubicCurve._ - def pointAt(dist: Double) = ( + + def pointAt(dist: Double): (Double, Double) = ( bezierInterpolate(dist, curve.getX1, curve.getCtrlX1, curve.getCtrlX2, curve.getX2), bezierInterpolate(dist, curve.getY1, curve.getCtrlY1, curve.getCtrlY2, curve.getY2) ) } object RichCubicCurve { - def bezierInterpolate(dist: Double, c0: Double, c1: Double, c2: Double, c3: Double) = { - val distp = 1 - dist + def bezierInterpolate(dist: Double, c0: Double, c1: Double, c2: Double, c3: Double): Double = { + val distP = 1 - dist - (distp*distp*distp) * c0 + - 3.0 * (distp*distp) * dist * c1 + - 3.0 * (dist*dist) * distp * c2 + - (dist*dist*dist) * c3 + (distP * distP * distP) * c0 + + 3.0 * (distP * distP) * dist * c1 + + 3.0 * (dist * dist) * distP * c2 + + (dist * dist * dist) * c3 } implicit def cubicCurveToRichCubicCurve(curve: CubicCurve2D): RichCubicCurve = new RichCubicCurve(curve) diff --git a/scala/src/main/scala/quanto/util/Scripting.scala b/scala/src/main/scala/quanto/util/Scripting.scala index 10fd3ef9..2d999811 100644 --- a/scala/src/main/scala/quanto/util/Scripting.scala +++ b/scala/src/main/scala/quanto/util/Scripting.scala @@ -11,8 +11,6 @@ import quanto.util.json._ import scala.collection.JavaConverters._ import scala.concurrent.duration._ - - // object providing functions specifically for python scripting object Scripting { @@ -56,6 +54,12 @@ object Scripting { listToPyList(pyListToList[String](ss).map(load_rule)) } + def include_inverses(rs: PyList) : PyList = { + listToPyList(pyListToList[Rule](rs).flatMap(r => { + List(r, r.inverse) + }).distinct) + } + def plug(g1: Graph, g2: Graph, b1: String, b2: String) = g1.plugGraph(g2, VName(b1), VName(b2)) @@ -94,8 +98,13 @@ object Scripting { // } } + def new_graph_from_json(jsonString: String): Unit = { + val doc = QuantoDerive.newGraph() + doc.replaceJson(Json.parse(jsonString)) + } + class derivation(start : Graph) { - var d = Derivation(theory, start) + var d = Derivation(start) def rewrite(r: (String, Rule)) = { false @@ -135,7 +144,7 @@ object Scripting { } def vertex_angle_is(g: Graph, v: VName, a: String) = g.vdata(v) match { - case nv: NodeV => nv.angle == AngleExpression.parse(a) + case nv: NodeV => nv.phaseData.values == CompositeExpression.parseKnowingTypes(a, nv.phaseData.valueTypes) case _ => false } @@ -148,6 +157,39 @@ object Scripting { // python wrappers for simproc combinators val EMPTY = Simproc.EMPTY + + //takes in a python function that accepts a json version of the graph + // and outputs new json content for a new graph + def JSON_REWRITE(func: PyFunction) = new Simproc{ + override def simp(g: Graph): Iterator[(Graph, Rule)] = { + val input = g.toJson().toString() + val output = Py.tojava(func.__call__(Py.java2py(input)),classOf[String]) + Iterator.single((Graph.fromJson(Json.parse(output),project.theory),Rule(Graph(),Graph()))) + } + } + + def JSON_REWRITE_STEPS(starter: PyMethod, step_getter: PyMethod, name_getter: PyMethod) = new Simproc{ + override def simp(g: Graph): Iterator[(Graph, Rule)] = { + val input = g.toJson().toString() + val total_steps = Py.tojava(starter.__call__(Py.java2py(input)),classOf[Integer]) + var index = 0 + new Iterator[(Graph, Rule)] { + override def length = total_steps + override def hasNext: Boolean = index < total_steps + override def next(): (Graph, Rule) = { + if (index < total_steps) { + val name = Py.tojava(name_getter.__call__(Py.java2py(index)), classOf[String]) + println("Adding rewrite step: " + name) + val js = Py.tojava(step_getter.__call__(Py.java2py(index)), classOf[String]) + val rule = Rule(Graph(), Graph(), None, RuleDesc(name)) + index += 1 + (Graph.fromJson(Json.parse(js), project.theory), rule) + } + else null + } + } + } + } def REWRITE(o: Object) = o match { case list: PyList => Simproc.REWRITE(pyListToList(list)) case _ => Simproc.REWRITE(List(o.asInstanceOf[Rule])) @@ -162,6 +204,15 @@ object Scripting { Simproc.REWRITE_METRIC(rules, {g => Py.py2int(metric.__call__(Py.java2py(g)))}) } + def REWRITE_METRIC_TO(o: Object, metric: PyFunction, target: Int) = { + val rules = o match { + case list: PyList => pyListToList(list) + case _ => List(o.asInstanceOf[Rule]) + } + + Simproc.REWRITE_METRIC(rules, {g => Py.py2int(metric.__call__(Py.java2py(g)))}, target) + } + def REWRITE_WEAK_METRIC(o: Object, metric: PyFunction) = { val rules = o match { case list: PyList => pyListToList(list) @@ -180,14 +231,33 @@ object Scripting { }) } + def ANNEAL(r: PyList, steps: Int, dilation: Double) = Simproc.ANNEAL(pyListToList[Rule](r), steps, dilation) + + def LOG(graphEval: PyFunction) = Simproc.LOG( + g => { + val pyReturn = graphEval.__call__(Py.java2py(g)) + if (pyReturn.isInstanceOf[PyString]) pyReturn.toString else "" + } + ) + + def REWRITE_TARGET_LIST(rule: Rule, v: String, tlist: PyList) = Simproc.REWRITE_TARGET_LIST(rule, VName(v), pyListToList(tlist)) + def REPEAT(s: Simproc) = Simproc.REPEAT(s) // REDUCE_XXX(-) := REPEAT(REWRITE_XXX(-)) def REDUCE(o: Object) = REPEAT(REWRITE(o)) def REDUCE_TARGETED(rule: Rule, v: String, targ: PyFunction) = REPEAT(REWRITE_TARGETED(rule, v, targ)) def REDUCE_METRIC(o: Object, metric: PyFunction) = REPEAT(REWRITE_METRIC(o, metric)) + def REDUCE_METRIC_TO(o: Object, metric: PyFunction, target: Int) = REPEAT(REWRITE_METRIC_TO(o, metric, target)) def REDUCE_WEAK_METRIC(o: Object, metric: PyFunction) = REPEAT(REWRITE_WEAK_METRIC(o, metric)) + private def register_simproc(simprocName: String, simproc: Simproc, sourceFile: String): Unit = { + simproc.sourceFile = sourceFile + project.simprocs += simprocName -> simproc + } + + def register_simproc(s: String, sp: Simproc): Unit = { + register_simproc(s, sp, project.lastRunPythonFilePath.getOrElse("")) + } - def register_simproc(s: String, sp: Simproc) { project.simprocs += s -> sp } } diff --git a/scala/src/main/scala/quanto/util/SignallingStreamRedirector.scala b/scala/src/main/scala/quanto/util/SignallingStreamRedirector.scala index 61fbc8de..3c54c011 100644 --- a/scala/src/main/scala/quanto/util/SignallingStreamRedirector.scala +++ b/scala/src/main/scala/quanto/util/SignallingStreamRedirector.scala @@ -8,12 +8,19 @@ import java.io._ //case class InterruptedSignal(id: Int) extends Signal(id) sealed abstract class MessagePart + case class CodePart(c: Char) extends MessagePart + case class IntPart(i: Int) extends MessagePart + case class StringPart(s: String) extends MessagePart case class StreamMessage(parts: MessagePart*) { + def writeTo(out: OutputStream) { + writeTo(new OutputStreamWriter(out, "ISO-8859-1")) + } + def writeTo(out: Writer) { parts.foreach { case CodePart(c: Char) => @@ -28,15 +35,11 @@ case class StreamMessage(parts: MessagePart*) { out.flush() } - def writeTo(out: OutputStream) { - writeTo(new OutputStreamWriter(out, "ISO-8859-1")) - } - - def stripCodes = parts.filter{ case _: CodePart => false ; case _ => true } + def stripCodes: Seq[MessagePart] = parts.filter { case _: CodePart => false; case _ => true } } object StreamMessage { - def compileMessage(id: Int, fileName: String, code: String) = { + def compileMessage(id: Int, fileName: String, code: String): StreamMessage = { new StreamMessage( CodePart('R'), IntPart(id), CodePart(','), @@ -52,85 +55,24 @@ object StreamMessage { } /** - * A simple stream redirector that allows bits of code to listen for signals coming over the stream, all of - * this form: <<[S](id)>>, <<[F](id)>>, <<[I](id)>> indicating success, failure, and interrupt (along with - * some identifying code). - * - */ + * A simple stream redirector that allows bits of code to listen for signals coming over the stream, all of + * this form: <<[S](id)>>, <<[F](id)>>, <<[I](id)>> indicating success, failure, and interrupt (along with + * some identifying code). + * + */ class SignallingStreamRedirector(from: InputStream, to: Option[OutputStream] = None) -extends Thread("Signalling Stream Redirector") { + extends Thread("Signalling Stream Redirector") { + private val listeners = collection.mutable.Map[Int, List[(StreamMessage => Any)]]() + private val outputStreams = collection.mutable.Buffer[OutputStream]() private var state = 0 private var currentCode: Option[Char] = None private var currentId = 0 private var buf = new StringBuffer private var msgParts = Seq[MessagePart]() - private val listeners = collection.mutable.Map[Int, List[(StreamMessage => Any)]]() - private val outputStreams = collection.mutable.Buffer[OutputStream]() to.map { s => outputStreams += s } - private def fire() { - msgParts match { - case (_ :: IntPart(id) :: _) => - val msg = StreamMessage(msgParts: _*) - listeners.synchronized { - listeners.remove(id).map { _.reverse.foreach( f => f(msg)) } - } - case _ => - println("Got bad message: " + msgParts) - } - } - - // process string via tiny state machine -// private def processChar(c: Char) = state match { -// case 0 => if (c == '<') state = 1 -// case 1 => if (c == '<') state = 2 else state = 0 -// case 2 => if (c == '[') state = 3 else state = 0 -// case 3 => if (c == 'S' || c == 'F' || c == 'I') { currentCode = c ; state = 4 } -// else { currentCode = 'X' ; state = 0 } -// case 4 => if (c == ']') state = 5 -// else { currentCode = 'X' ; state = 0 } -// case 5 => if (c.isDigit) currentId = (10 * currentId) + c.toString.toInt -// else if (c == '>') state = 6 -// else { currentCode = 'X' ; currentId = 0; state = 0 } -// case 6 => if (c == '>') fire(currentCode, currentId) -// currentCode = 'X' ; currentId = 0; state = 0 -// } - - private def resetState() { - msgParts = List() - buf = new StringBuffer - currentCode = None - currentId = 0 - state = 0 - } - - private def processChar(c: Char) = - state match { - case 0 => - if (c == '\u001B') { state = 1 ; false } - else { state = 0; true } - case 1 => - msgParts :+= CodePart(c) - currentCode = Some(c.toLower) - state = 2 - false - case 2 => - if (c.isDigit) { currentId = (10 * currentId) + c.toString.toInt } - else if (c == '\u001B') { msgParts :+= IntPart(currentId); state = 3 } - else { resetState() } - false - case 3 => msgParts :+= CodePart(c) - if (Some(c) == currentCode) { fire(); resetState() } - else { state = 4 } - false - case 4 => - if (c == '\u001B') { msgParts :+= StringPart(buf.toString); buf = new StringBuffer; state = 3 } - else { buf.append(c) } - false - } - - def addListener(id: Int)(f : StreamMessage => Any) { + def addListener(id: Int)(f: StreamMessage => Any) { listeners.synchronized { listeners.put(id, listeners.get(id) match { case Some(list) => f :: list @@ -139,13 +81,29 @@ extends Thread("Signalling Stream Redirector") { } } + // process string via tiny state machine + // private def processChar(c: Char) = state match { + // case 0 => if (c == '<') state = 1 + // case 1 => if (c == '<') state = 2 else state = 0 + // case 2 => if (c == '[') state = 3 else state = 0 + // case 3 => if (c == 'S' || c == 'F' || c == 'I') { currentCode = c ; state = 4 } + // else { currentCode = 'X' ; state = 0 } + // case 4 => if (c == ']') state = 5 + // else { currentCode = 'X' ; state = 0 } + // case 5 => if (c.isDigit) currentId = (10 * currentId) + c.toString.toInt + // else if (c == '>') state = 6 + // else { currentCode = 'X' ; currentId = 0; state = 0 } + // case 6 => if (c == '>') fire(currentCode, currentId) + // currentCode = 'X' ; currentId = 0; state = 0 + // } + def addOutputStream(out: OutputStream) { outputStreams.synchronized { outputStreams += out } } - def removeOutputStream(out : OutputStream) { + def removeOutputStream(out: OutputStream) { outputStreams.synchronized { outputStreams -= out } @@ -154,12 +112,12 @@ extends Thread("Signalling Stream Redirector") { override def run() { try { val buffer: Array[Byte] = new Array[Byte](200) - val buffer1 : Array[Byte] = new Array[Byte](200) + val buffer1: Array[Byte] = new Array[Byte](200) var count: Int = from.read(buffer) while (count != -1) { var j = 0 - for (i <- 0 to count - 1) { + for (i <- 0 until count) { if (processChar(buffer(i).toChar)) { buffer1(j) = buffer(i) j += 1 @@ -181,4 +139,69 @@ extends Thread("Signalling Stream Redirector") { ex.printStackTrace() } } + + private def processChar(c: Char) = + state match { + case 0 => + if (c == '\u001B') { + state = 1; false + } + else { + state = 0; true + } + case 1 => + msgParts :+= CodePart(c) + currentCode = Some(c.toLower) + state = 2 + false + case 2 => + if (c.isDigit) { + currentId = (10 * currentId) + c.toString.toInt + } + else if (c == '\u001B') { + msgParts :+= IntPart(currentId); state = 3 + } + else { + resetState() + } + false + case 3 => msgParts :+= CodePart(c) + if (currentCode.contains(c)) { + fire(); resetState() + } + else { + state = 4 + } + false + case 4 => + if (c == '\u001B') { + msgParts :+= StringPart(buf.toString); buf = new StringBuffer; state = 3 + } + else { + buf.append(c) + } + false + } + + private def fire() { + msgParts match { + case (_ :: IntPart(id) :: _) => + val msg = StreamMessage(msgParts: _*) + listeners.synchronized { + listeners.remove(id).foreach { + _.reverse.foreach(f => f(msg)) + } + } + case _ => + println("Got bad message: " + msgParts) + } + } + + private def resetState() { + msgParts = List() + buf = new StringBuffer + currentCode = None + currentId = 0 + state = 0 + } } diff --git a/scala/src/main/scala/quanto/util/SwingTimer.scala b/scala/src/main/scala/quanto/util/SwingTimer.scala index e824d848..1267e128 100644 --- a/scala/src/main/scala/quanto/util/SwingTimer.scala +++ b/scala/src/main/scala/quanto/util/SwingTimer.scala @@ -3,7 +3,7 @@ package quanto.util object SwingTimer { def apply(interval: Int, repeats: Boolean = true)(op: => Unit) { val timeOut = new javax.swing.AbstractAction() { - def actionPerformed(e : java.awt.event.ActionEvent) = op + def actionPerformed(e: java.awt.event.ActionEvent): Unit = op } val t = new javax.swing.Timer(interval, timeOut) t.setRepeats(repeats) diff --git a/scala/src/main/scala/quanto/util/TreeSeq.scala b/scala/src/main/scala/quanto/util/TreeSeq.scala index d705325c..25e192e6 100644 --- a/scala/src/main/scala/quanto/util/TreeSeq.scala +++ b/scala/src/main/scala/quanto/util/TreeSeq.scala @@ -3,29 +3,31 @@ package quanto.util class TreeSeqFormatException(msg: String) extends Exception(msg) /** - * An abstract class representing a tree whose elements also have a sequential ordering, as in a branching history. - */ + * An abstract class representing a tree whose elements also have a sequential ordering, as in a branching history. + */ abstract class TreeSeq[A] { + import TreeSeq._ + def toSeq: Seq[A] + def indexOf(a: A): Int + def parent(a: A): Option[A] - def children(a: A): Seq[A] - private def padding(size: Int) = - Seq.fill[WhiteSpace[A]](size) { WhiteSpace[A](collapseBottom = false, collapseTop = true) } + def children(a: A): Seq[A] def flatten: Seq[(Seq[Decoration[A]], A)] = toSeq.foldLeft(Seq[(Seq[Decoration[A]], A)]()) { (rows, a) => val node = NodeLink(parent(a), children(a)) val prev = if (rows.isEmpty) Seq[Decoration[A]]() - else rows.last._1 + else rows.last._1 var inserted = false var pad = 0 // traverse the previous decoration list from left to right and construct the current decoration list - val current = prev.foldLeft(Seq[Decoration[A]]()) { (cols,col) => + val current = prev.foldLeft(Seq[Decoration[A]]()) { (cols, col) => col match { case (WireLink(dest)) => if (inserted) cols :+ WireLink(dest) @@ -43,7 +45,7 @@ abstract class TreeSeq[A] { pad += 1 cols :+ WhiteSpace[A](collapseBottom = true, collapseTop = false) } - else outs.foldLeft(cols) { (outCols,out) => + else outs.foldLeft(cols) { (outCols, out) => if (inserted) outCols :+ WireLink(out) else { if (out == a) { @@ -55,17 +57,17 @@ abstract class TreeSeq[A] { } } } - case WhiteSpace(false,_) => + case WhiteSpace(false, _) => pad += 1 cols :+ WhiteSpace[A](collapseBottom = true, collapseTop = false) - case WhiteSpace(true,_) => cols + case WhiteSpace(true, _) => cols } } rows :+ (if (!inserted) { node.input match { - case Some (p) => + case Some(p) => throw new TreeSeqFormatException("Node '" + a.toString + "' occurs before its parent '" + p + "'") case None => current :+ node } @@ -73,19 +75,17 @@ abstract class TreeSeq[A] { current }, a) } + + private def padding(size: Int) = + Seq.fill[WhiteSpace[A]](size) { + WhiteSpace[A](collapseBottom = false, collapseTop = true) + } } object TreeSeq { - sealed abstract class Decoration[A] - case class NodeLink[A](input: Option[A], outputs: Seq[A]) extends Decoration[A] - case class WireLink[A](dest: A) extends Decoration[A] - - // placeholder for whitespace. if collapse is false, the space propagates to the next rank. if it is true, it - // disappears at the next rank. - case class WhiteSpace[A](collapseBottom: Boolean, collapseTop: Boolean) extends Decoration[A] // figure out how wide a given decoration sequence is - def decorationWidth[A](dec: Seq[Decoration[A]]) = { + def decorationWidth[A](dec: Seq[Decoration[A]]): Int = { var topIndex = 0 var bottomIndex = 0 var sz = 0 @@ -93,11 +93,11 @@ object TreeSeq { case WireLink(_) => topIndex += 1 bottomIndex += 1 - sz = math.max(topIndex,bottomIndex) + sz = math.max(topIndex, bottomIndex) case NodeLink(inputOpt, outputs) => topIndex += inputOpt.size bottomIndex += math.max(1, outputs.size) - sz = math.max(topIndex,bottomIndex) + sz = math.max(topIndex, bottomIndex) case WhiteSpace(collapseBottom, collapseTop) => if (!collapseBottom) bottomIndex += 1 if (!collapseTop) topIndex += 1 @@ -105,4 +105,14 @@ object TreeSeq { sz } + + sealed abstract class Decoration[A] + + case class NodeLink[A](input: Option[A], outputs: Seq[A]) extends Decoration[A] + + case class WireLink[A](dest: A) extends Decoration[A] + + // placeholder for whitespace. if collapse is false, the space propagates to the next rank. if it is true, it + // disappears at the next rank. + case class WhiteSpace[A](collapseBottom: Boolean, collapseTop: Boolean) extends Decoration[A] } diff --git a/scala/src/main/scala/quanto/util/UserAlerts.scala b/scala/src/main/scala/quanto/util/UserAlerts.scala index e21c0ec0..d5a1f8ef 100644 --- a/scala/src/main/scala/quanto/util/UserAlerts.scala +++ b/scala/src/main/scala/quanto/util/UserAlerts.scala @@ -1,9 +1,11 @@ package quanto.util -import java.util.{Calendar, Date, UUID} +import java.io.File +import java.util.concurrent.TimeUnit +import java.util.{Calendar, Date} -import scala.swing.{Color, Dialog, Publisher} import scala.swing.event.Event +import scala.swing.{Color, Dialog, Publisher} // Universal system for alerting the user @@ -12,13 +14,79 @@ import scala.swing.event.Event // Listen to events via AlertPublisher object UserAlerts { + var ongoingProcesses: List[UserStartedProcess] = List() + var alerts: List[Alert] = List() + + def leastCompleteProcess: Option[UserStartedProcess] = { + if (ongoingProcesses.isEmpty) None else { + val indeterminate = ongoingProcesses.find(op => !op.determinate) + if (indeterminate.nonEmpty) indeterminate else { + Some(ongoingProcesses.minBy(op => op.value)) + } + } + } + + def latestMessage: Alert = { + if (alerts.headOption.nonEmpty) alerts.head else { + alert("Quantomatic starting up") + latestMessage + } + } + + def alert(message: Any): Unit = alert(message.toString, Elevation.NOTICE) + + def debug(message: String): Unit = alert(message, Elevation.DEBUG) + + def errorBox(message: String): Unit = { + alert(message, Elevation.ERROR) + Dialog.showMessage( + title = "Error", + message = message, + messageType = Dialog.Message.Error) + } + + def alert(message: String, elevation: Elevation.Elevation): Unit = { + val newAlert = Alert(Calendar.getInstance().getTime, elevation, message) + println(newAlert.toString) + alerts = newAlert :: alerts + AlertPublisher.publish(UserAlertEvent(newAlert)) + writeToLogFile(newAlert) + } + + def writeToLogFile(alert: Alert, force: Boolean = false): Unit = { + val elevation = alert.elevationText match { + case "" => "" + case e => s"[$e]" + } + if (logFile.nonEmpty && (UserOptions.logging || force)) { + FileHelper.printToFile(logFile.get)( + p => p.println(UserOptions.preferredTimeFormat.format(alert.time) + ": " + elevation + alert.message) + ) + + } + } + + def registerLogFile(optionFile: Option[File]) : Unit = { + _logFile = optionFile + } + + private var _logFile : Option[File] = None + + def logFile: Option[File] = _logFile + case class UserAlertEvent(alert: Alert) extends Event - case class UserProcessUpdate(ongoingProcess: UserStartedProcess) extends Event - object AlertPublisher extends Publisher + case class UserProcessUpdate(ongoingProcess: UserStartedProcess) extends Event class SelfAlertingProcess(name: String) extends UserStartedProcess(name) { alert(name + ": Started") + val startTime: Long = Calendar.getInstance().getTimeInMillis + + + override def halt(): Unit = { + super.halt() + alert(name + ": Halted", Elevation.NOTICE) + } override def fail(): Unit = { super.fail() @@ -27,26 +95,21 @@ object UserAlerts { override def finish(): Unit = { super.finish() - alert(name + ": Finished") + val timeTaken = TimeUnit.SECONDS.toSeconds(Calendar.getInstance().getTimeInMillis - startTime) + alert(name + s": Finished ${timeTaken / 1000.0}s") } } class UserStartedProcess(val name: String) { //private val uuid : UUID = UUID.randomUUID() //Will need for log files private var _determinate: Boolean = false - private var _value : Int = 0 - private var _failed : Boolean = false + private var _value: Int = 0 + private var _failed: Boolean = false - def failed : Boolean = _failed + def failed: Boolean = _failed - def determinate : Boolean = _determinate + def determinate: Boolean = _determinate - def value: Int = _value - def value_=(newValue: Int) : Unit = { - _value = newValue - _determinate = true - AlertPublisher.publish(UserProcessUpdate(this)) - } def setIndeterminate(): Unit = { _value = 0 _determinate = false @@ -58,29 +121,31 @@ object UserAlerts { value = 100 } + def value: Int = _value + + def value_=(newValue: Int): Unit = { + _value = newValue + _determinate = true + AlertPublisher.publish(UserProcessUpdate(this)) + } + def finish(): Unit = { value = 100 } - ongoingProcesses = this :: ongoingProcesses - AlertPublisher.publish(UserProcessUpdate(this)) - } - - var ongoingProcesses : List[UserStartedProcess] = List() - - def leastCompleteProcess : Option[UserStartedProcess] = { - if (ongoingProcesses.isEmpty) None else { - val indeterminate = ongoingProcesses.find(op => !op.determinate) - if (indeterminate.nonEmpty) indeterminate else { - Some(ongoingProcesses.minBy(op => op.value)) - } + def halt(): Unit = { + _failed = false + value = 0 } - } + ongoingProcesses = this :: ongoingProcesses + AlertPublisher.publish(UserProcessUpdate(this)) + } - case class Alert(time : Date, elevation: Elevation.Elevation, message: String) { + case class Alert(time: Date, elevation: Elevation.Elevation, message: String) { override def toString: String = UserOptions.preferredTimeFormat.format(time) + ": " + message - def color : Color= { + + def color: Color = { elevation match { case Elevation.ERROR => new Color(150, 0, 0) // Something broke case Elevation.ALERT => new Color(150, 150, 0) // Something soon to break @@ -89,36 +154,22 @@ object UserAlerts { case Elevation.NOTICE => new Color(0, 0, 150) // Nothing is broken } } - } - object Elevation extends Enumeration { - type Elevation = Value - val ALERT, ERROR, WARNING, NOTICE, DEBUG = Value - } - - var Alerts : List[Alert] = List() - - def latestMessage : Alert = { - if (Alerts.headOption.nonEmpty) Alerts.head else { - alert("Quantomatic starting up") - latestMessage + def elevationText: String = { + elevation match { + case Elevation.ERROR => "ERROR" // Something broke + case Elevation.ALERT => "ALERT" // Something soon to break + case Elevation.WARNING => "WARNING" // That would have caused something to break + case Elevation.DEBUG => "DEBUG" // I want to know how it would have broken + case Elevation.NOTICE => "" // Nothing is broken + } } } - def alert(message: String) : Unit = alert(message, Elevation.NOTICE) - - def errorbox(message: String) : Unit = { - alert(message, Elevation.ERROR) - Dialog.showMessage( - title = "Error", - message = message, - messageType = Dialog.Message.Error) - } + object AlertPublisher extends Publisher - def alert(message: String, elevation: Elevation.Elevation): Unit ={ - val newAlert = Alert(Calendar.getInstance().getTime, elevation, message) - println(newAlert.toString) - Alerts = newAlert :: Alerts - AlertPublisher.publish(UserAlertEvent(newAlert)) + object Elevation extends Enumeration { + type Elevation = Value + val ALERT, ERROR, WARNING, NOTICE, DEBUG = Value } } diff --git a/scala/src/main/scala/quanto/util/UserOptions.scala b/scala/src/main/scala/quanto/util/UserOptions.scala index fd0bfeb5..6719791d 100644 --- a/scala/src/main/scala/quanto/util/UserOptions.scala +++ b/scala/src/main/scala/quanto/util/UserOptions.scala @@ -15,16 +15,42 @@ class UserOptions { object UserOptions { + val prefs: Preferences = Preferences.userRoot().node(this.getClass.getName) + // A scaling of 1 corresponds to a font size of 14 + // changing the font size will also change the rest of the scaling + private var _uiScale: Double = prefs.getDouble("uiScale", 1) + private var _fontDecoration = Font.PLAIN + private var _fontFamily = "defaultFont" + private var _logging: Boolean = prefs.getBoolean("logging", false) + private var _graphScale: Double = 1 + private var _preferredTimeFormat: SimpleDateFormat = new SimpleDateFormat("HH:mm:ss") + private var _preferredDateTimeFormat: SimpleDateFormat = new SimpleDateFormat("yy-MM-dd.HH:mm:ss") - case class UIRedrawRequest() extends Event + def scaleInt(d: Double): Int = math.floor(scale(d)).toInt - object OptionsChanged extends Publisher + def scale(d: Double): Double = { + d * uiScale + } + + def resetScale(): Unit = { + uiScale = 1 + } + + def uiScale: Double = _uiScale + + def uiScale_=(d: Double) { + _uiScale = d + _uiScale = math.max(_uiScale, 0.5) // Limit scaling to equivalent of 7pt font + _uiScale = math.min(_uiScale, 4) // Limit scaling to equivalent of 56pt font + setUIFont(new FontUIResource(_fontFamily, _fontDecoration, fontSize)) + prefs.putDouble("uiScale", _uiScale) + requestUIRefresh() + } private def requestUIRefresh(): Unit = { OptionsChanged.publish(UIRedrawRequest()) } - // Changes the default font but doesn't request redraw private def setUIFont(f: FontUIResource): Unit = { val keys = UIManager.getDefaults.keys while ( { @@ -36,67 +62,70 @@ object UserOptions { } } - // A scaling of 1 corresponds to a font size of 14 - // changing the font size will also change the rest of the scaling - private var _uiScale : Double = 1 - def uiScale : Double = _uiScale - def uiScale_=(d: Double){ - _uiScale = math.max(d, 0.5) // Limit scaling to equivalent of 7pt font - setUIFont(new FontUIResource(_fontFamily, _fontDecoration, fontSize)) - val prefs = Preferences.userRoot().node(this.getClass.getName) - prefs.putDouble("uiScale", _uiScale) - requestUIRefresh() + def font: Font = { + new Font(_fontFamily, _fontDecoration, fontSize) } - def scale(d: Double) : Double = { - d * uiScale + def fontSize: Int = { + Math.floor(14 * _uiScale).toInt } - def scaleInt(d: Double) : Int = math.floor(scale(d)).toInt - - def resetScale(): Unit ={ - uiScale = 1 + def fontSize_=(n: Int) { + uiScale = n / 14.0 //changing uiScale triggers redraw and font size changes } - def font : Font = { - new Font(_fontFamily, _fontDecoration, fontSize) - } + def fontDecoration: Int = _fontDecoration - private var _fontDecoration = Font.PLAIN - def fontDecoration : Int = _fontDecoration - def fontDecoration_=(n: Int){ + def fontDecoration_=(n: Int) { _fontDecoration = n } - private var _fontFamily = "defaultFont" - def fontFamily : String = _fontFamily - def fontFamily_=(s: String): Unit ={ + def fontFamily: String = _fontFamily + + def fontFamily_=(s: String): Unit = { _fontFamily = s } - def fontSize : Int = { - Math.floor(14*_uiScale).toInt - } - def fontSize_=(n: Int){ - uiScale = n / 14.0 //changing uiScale triggers redraw and font size changes + def logging: Boolean = _logging + + def logging_=(b: Boolean): Unit = { + prefs.putBoolean("logging", b) + _logging = b } - private var _graphScale : Double = 1 - def graphScale : Double = _graphScale - def graphScale_=(n: Double): Unit ={ + def graphScale: Double = _graphScale + + def graphScale_=(n: Double): Unit = { _graphScale = n } - private var _preferredTimeFormat : SimpleDateFormat = new SimpleDateFormat("HH:mm:ss") - def preferredTimeFormat : SimpleDateFormat = _preferredTimeFormat + def preferredTimeFormat: SimpleDateFormat = _preferredTimeFormat + def preferredTimeFormat_=(format: String): Unit = { try { _preferredTimeFormat = new SimpleDateFormat(format) } - catch { - case e: Exception => - e.printStackTrace() - } + catch { + case e: Exception => + e.printStackTrace() + } + } + + def preferredDateTimeFormat: SimpleDateFormat = _preferredDateTimeFormat + + def preferredDateTimeFormat_=(format: String): Unit = { + try { + _preferredDateTimeFormat = new SimpleDateFormat(format) + } + catch { + case e: Exception => + e.printStackTrace() + } } + case class UIRedrawRequest() extends Event + + object OptionsChanged extends Publisher + + } diff --git a/scala/src/main/scala/quanto/util/WebHelper.scala b/scala/src/main/scala/quanto/util/WebHelper.scala index 8f8ad34c..f3c3ee63 100644 --- a/scala/src/main/scala/quanto/util/WebHelper.scala +++ b/scala/src/main/scala/quanto/util/WebHelper.scala @@ -1,6 +1,7 @@ package quanto.util -import java.net.URL + import java.awt.Desktop +import java.net.URL object WebHelper { diff --git a/scala/src/test/scala/quanto/cosy/test/BlockEnumerationSpec.scala b/scala/src/test/scala/quanto/cosy/test/BlockEnumerationSpec.scala index c1047ac1..04c165b6 100644 --- a/scala/src/test/scala/quanto/cosy/test/BlockEnumerationSpec.scala +++ b/scala/src/test/scala/quanto/cosy/test/BlockEnumerationSpec.scala @@ -2,13 +2,16 @@ package quanto.cosy.test import quanto.cosy._ import org.scalatest.FlatSpec -import quanto.data.Rule +import quanto.data.{Graph, GraphTikz, Rule} import quanto.rewrite.Matcher +import quanto.cosy.BlockRowMaker._ +import quanto.util.FileHelper /** * Created by hector on 28/06/17. */ + class BlockEnumerationSpec extends FlatSpec { implicit def quickList(n: Int): List[Int] = { @@ -22,31 +25,31 @@ class BlockEnumerationSpec extends FlatSpec { behavior of "Block Enumeration" it should "build a small ZW row" in { - var rowsAllowed = BlockRowMaker(1, allowedBlocks = BlockRowMaker.ZW) + var rowsAllowed = BlockRowMaker(1, allowedBlocks = BlockGenerators.ZW) println(rowsAllowed) } it should "build bigger ZW rows" in { - var rowsAllowed = BlockRowMaker(2, allowedBlocks = BlockRowMaker.ZW) + var rowsAllowed = BlockRowMaker(2, allowedBlocks = BlockGenerators.ZW) println(rowsAllowed) assert(rowsAllowed.length == 11 * 11 + 11) } it should "stack rows" in { - var rowsAllowed = BlockRowMaker(1, allowedBlocks = BlockRowMaker.ZW) + var rowsAllowed = BlockRowMaker(1, allowedBlocks = BlockGenerators.ZW) var stacks = BlockStackMaker(2, rowsAllowed) println(stacks) } it should "limit wires" in { - var rowsAllowed = BlockRowMaker(2, allowedBlocks = BlockRowMaker.ZW, maxInOut = Option(2)) + var rowsAllowed = BlockRowMaker(2, allowedBlocks = BlockGenerators.ZW, maxInOut = Option(2)) var stacks = BlockStackMaker(2, rowsAllowed) println(stacks) assert(stacks.forall(s => (s.inputs.length <= 2) && (s.outputs.length <= 2))) } it should "compute tensors" in { - var rowsAllowed = BlockRowMaker(2, allowedBlocks = BlockRowMaker.ZW) + var rowsAllowed = BlockRowMaker(2, allowedBlocks = BlockGenerators.ZW) var stacks = BlockStackMaker(2, rowsAllowed) for (elem <- stacks) { println("---\n" + elem.toString + " = \n" + elem.tensor) @@ -66,7 +69,7 @@ class BlockEnumerationSpec extends FlatSpec { it should "find wire identities" in { var rowsAllowed = BlockRowMaker(1, allowedBlocks = List( - Block(1, 1, " 1 ", Tensor.idWires(1)), + Block(1, 1, " 1 ", Tensor.idWires(1), new Graph()), Block(1, 1, " w ", new Tensor(Array(Array[Complex](1, 0), Array[Complex](0, -1)))), Block(1, 1, " b ", new Tensor(Array(Array[Complex](0, 1), Array[Complex](1, 0)))) )) @@ -109,42 +112,42 @@ class BlockEnumerationSpec extends FlatSpec { behavior of "Stack to Graph" it should "convert a block to a graph" in { - var b = BlockRowMaker.Bian2Qubit(2) // T-gate - var g = BlockRowMaker.Bian2QubitToGraph(b) + var b = BlockGenerators.Bian2Qubit(2) // T-gate + var g = BlockGenerators.Bian2QubitToGraph(b) println(g.toString) } it should "convert a row to a graph" in { - var B2 = BlockRowMaker.Bian2Qubit + var B2 = BlockGenerators.Bian2Qubit var r = new BlockRow(List(B2(2), B2(3))) // T x H - var g = BlockRowMaker.rowToGraph(r, BlockRowMaker.Bian2QubitToGraph) + var g = BlockRowMaker.predicateRowToGraph(r, BlockGenerators.Bian2QubitToGraph) println(g.toString) } it should "convert a stack to a graph" in { - var B2 = BlockRowMaker.Bian2Qubit + var B2 = BlockGenerators.Bian2Qubit var r = new BlockRow(List(B2(2), B2(3))) // T x H - var g = BlockRowMaker.stackToGraph( + var g = BlockRowMaker.predicateStackToGraph( new BlockStack(List(r, r)), - BlockRowMaker.Bian2QubitToGraph) + BlockGenerators.Bian2QubitToGraph) println(g) } behavior of "qutrits and qudits" it should "generate enough qutrit generators" in { - assert(BlockRowMaker.ZXQutrit(9).length == (10 + 2 * 81)) + assert(BlockGenerators.ZXQutrit(9).length == (10 + 2 * 81)) } it should "generate enough qudit generators" in { - assert(BlockRowMaker.ZXQudit(3, 9).length == (10 + 2 * 81)) + assert(BlockGenerators.ZXQudit(3, 9).length == (10 + 2 * 81)) // And check it is the correct swap tensor: - assert(BlockRowMaker.ZXQudit(3, 9)(1).tensor == BlockRowMaker.ZXQutrit(9)(1).tensor) - assert(BlockRowMaker.ZXQudit(4, 8).length == (10 + 2 * math.pow(8, 4 - 1)).toInt) + assert(BlockGenerators.ZXQudit(3, 9)(1).tensor == BlockGenerators.ZXQutrit(9)(1).tensor) + assert(BlockGenerators.ZXQudit(4, 8).length == (10 + 2 * math.pow(8, 4 - 1)).toInt) } it should "have spider rules for qudits" in { - var Q4 = BlockRowMaker.ZXQudit(4, 8) + var Q4 = BlockGenerators.ZXQudit(4, 8) var r760 = Q4.find(p => p.name == "r|7|6|0") var r230 = Q4.find(p => p.name == "r|2|3|0") var r110 = Q4.find(p => p.name == "r|1|1|0") @@ -157,7 +160,7 @@ class BlockEnumerationSpec extends FlatSpec { behavior of "Bell Simple" it should "display quantum teleportation" in { - var BSRow = BlockRowMaker(2, BlockRowMaker.BellTeleportation, Option(3)) + var BSRow = BlockRowMaker(2, BlockGenerators.BellTeleportation, Option(3)) var BSStacks = BlockStackMaker(4, BSRow) var tp = BSStacks. //filterNot(x => x.toString.matches(raw".*\(w\d \).*")). @@ -169,4 +172,59 @@ class BlockEnumerationSpec extends FlatSpec { tp.foreach(x => println("---- \n " + x.toString + "\n" + x.tensor)) assert(tp.length == 4) } + + behavior of "ZX Clifford" + + it should "Find CZ gate" in { + var BSRow = BlockRowMaker(2, BlockGenerators.ZXClifford, Option(2)) + var BSStacks = BlockStackMaker(3, BSRow) + var tp = BSStacks. + filter(x=> x.tensor.isSameShapeAs(Tensor.idWires(2))). + filter(x => x.tensor.isRoughlyUpToScalar( + Tensor(Array(Array(1, 0, 0, 0), Array(0, 1, 0, 0), Array(0, 0, 1, 0), Array(0, 0, 0, -1)))) + ) + tp.foreach(x => println("---- \n " + x.toString + "\n" + x.tensor)) + assert(tp.nonEmpty) + } + + behavior of "graph generation" + + val ZXClifford = BlockGenerators.ZXClifford + + it should "make a pair horizontally" in { + var row = new BlockRow(List(ZXClifford(1), ZXClifford(3))) + var g = row.graph + assert(g.verts.size == 6) + } + + it should "make a pair vertically" in { + var row = new BlockRow(List(ZXClifford(0))) + var g = BlockStack(List(row, row)).graph + assert(g.verts.size == 4) + assert(g.edges.size == 3) + assert(g.minimise.edges.size == 1) + } + + it should "make a 2x2" in { + var row = new BlockRow(List(ZXClifford(5), ZXClifford(1))) + var row2 = new BlockRow(List(ZXClifford(1), ZXClifford(5))) + var g = BlockStack(List(row, row2)).graph + assert(g.verts.size == 12) + } + + it should "cache rows" in { + var row = new BlockRow(List(ZXClifford(0), ZXClifford(1))) + var row2 = new BlockRow(List(ZXClifford(5), ZXClifford(0))) + var row3 = new BlockRow(List(ZXClifford(0), ZXClifford(5))) + var b1 = BlockStack(List(row, row2, row3)) + var b2 = BlockStack(List(row2, row3)).append(row) + var b3 = BlockStack(List(row3)).append(row2).append(row) + var g1 = b1.graph + var g2 = b2.graph + assert(b1.tensor == b2.tensor) + assert(b1.tensor == b3.tensor) + assert(g1.edges.size == g2.edges.size) + assert(g1.verts.toList.sorted == g2.verts.toList.sorted) + } + } \ No newline at end of file diff --git a/scala/src/test/scala/quanto/cosy/test/BlockGeneratorsSpec.scala b/scala/src/test/scala/quanto/cosy/test/BlockGeneratorsSpec.scala new file mode 100644 index 00000000..c23bcf53 --- /dev/null +++ b/scala/src/test/scala/quanto/cosy/test/BlockGeneratorsSpec.scala @@ -0,0 +1,70 @@ +package quanto.cosy.test + +import org.scalatest.FlatSpec +import quanto.cosy.BlockGenerators._ +import quanto.cosy._ + +import scala.concurrent.duration.Duration +import java.io.File + +import quanto.data._ +import quanto.rewrite.Matcher +import quanto.util.{FileHelper, Rational} +import quanto.util.json.Json +import quanto.data.Names._ + +import scala.util.Random + +class BlockGeneratorsSpec extends FlatSpec { + + + behavior of "ZX generators" + + it should "make the right number of generators" in { + // CNOTs start at width 2 + assert(zxCNOTs(2).size == 1) + assert(zxCNOTs(4).size == 3) + assert(zxTONCs(4).size == 3) + // Make X and Z twists + val twists = zxQubitTwists(8) + assert(twists.size == 16) + } + + it should "make the right angles" in { + val twists = zxQubitTwists(8) + assert(twists(2).graph.vdata("X").asInstanceOf[NodeV].phaseData.values.head.constant == Rational(1, 4)) + assert(twists(15).graph.vdata("Z").asInstanceOf[NodeV].phaseData.values.head.constant == Rational(7, 4)) + } + + it should "make hadamards" in { + val twists = zxQubitTwists(8) + assert((zxQubitHadamard o twists(2) o zxQubitHadamard).isRoughlyUpToScalar(twists(3))) + } + + it should "make correct CNOTs and TNOCs" in { + assert((zxCNOT(4) o zxCNOT(4)).isRoughlyUpToScalar(Tensor.idWires(4))) + assert((zxTONC(4) o zxTONC(4)).isRoughlyUpToScalar(Tensor.idWires(4))) + assert((zxCNOT(4) o zxTONC(4) o zxCNOT(4)).isRoughlyUpToScalar(Tensor.swap(List(3, 1, 2, 0)))) + val c4 = zxCNOT(4) + val c4graphInterpretation = Interpreter.interpretZXGraph(c4.graph, + List("i-0", "i-1", "i-2", "i-3"), + List("o-0", "o-1", "o-2", "o-3")) + assert(c4.tensor.isRoughlyUpToScalar(c4graphInterpretation)) + + val t4 = zxTONC(4) + val t4graphInterpretation = Interpreter.interpretZXGraph(t4.graph, + List("i-0", "i-1", "i-2", "i-3"), + List("o-0", "o-1", "o-2", "o-3")) + assert(t4.tensor.isRoughlyUpToScalar(t4graphInterpretation)) + + } + + it should "not have string edges, only rails" in { + val blocks: List[Block] = BlockGenerators.ZXGates(4, 1) + val rows: List[BlockRow] = BlockRowMaker.makeRowsUpToSize(1, blocks, Some(1)) + val stacks = BlockStackMaker.makeStacksOfSize(2, rows) + val graphs = stacks.map(_.graph) + val graphsWithStrings = graphs.filter(g => g.edata.values.exists(ed => ed.typ == "string")) + assert(graphsWithStrings.isEmpty) + } +} diff --git a/scala/src/test/scala/quanto/cosy/test/CoSyRunSpec.scala b/scala/src/test/scala/quanto/cosy/test/CoSyRunSpec.scala new file mode 100644 index 00000000..e4b2cebb --- /dev/null +++ b/scala/src/test/scala/quanto/cosy/test/CoSyRunSpec.scala @@ -0,0 +1,263 @@ +package quanto.cosy.test + +import org.scalatest.FlatSpec +import quanto.cosy.CoSyRuns._ +import quanto.cosy._ + +import scala.concurrent.duration.Duration +import java.io.File + +import quanto.data.Theory.VertexDesc +import quanto.data._ +import quanto.rewrite.Matcher +import quanto.util.FileHelper +import quanto.util.json.{Json, JsonObject} + +import scala.util.Random + +/** + * Created by hector on 24/05/17. + */ +class CoSyRunSpec extends FlatSpec { + + behavior of "ZX" + + it should "do a small run" in { + var theory = Theory.fromFile("red_green") + var CR = new CoSyRuns.CoSyZX(duration = Duration.Inf, + numBoundaries = List(0, 1, 2), + outputDir = None, + scalars = false, + numVertices = 2, + rulesDir = new File("./cosy/"), theory = theory, + numAngles = 4) + + def interpret(g: Graph) = Interpreter.interpretZXGraph(g, g.verts.filter(g.isBoundary).toList, List()) + // Don't test this here + // It isn't a standard feature + // And pollutes the filesystem + // Ask hmillerbakewell@gmail.com for more information + + // CR.begin() + /* + val reduced = RuleSynthesis.minimiseRuleset(CR.reductionRules, theory, new Random(1)) + //reduced.foreach(r => FileHelper.printJson(s"./cosy/${r.lhs.hashCode}-${r.rhs.hashCode}.qrule", Rule.toJson(r, theory))) + reduced.foreach(r => + assert(interpret(r.lhs).isRoughly(interpret(r.rhs)))) + */ + + } + + + behavior of "ZX with bool" + + it should "do a small run" in { + val theory = Theory.fromJson( + """ + |{ + | "name": "ZXH", + | "core_name": "zxh", + | "vertex_types": { + | "Z": { + | "value": { + | "type": "angle_expr, bool", + | "latex_constants": true, + | "validate_with_core": false + | }, + | "style": { + | "label": { + | "position": "center", + | "fg_color": [ + | 0.0, + | 0.0, + | 0.0 + | ] + | }, + | "stroke_color": [ + | 0.0, + | 0.0, + | 0.0 + | ], + | "fill_color": [ + | 0.0, + | 1.0, + | 0.0 + | ], + | "shape": "circle" + | }, + | "default_data": { + | "type": "Z", + | "value": "0, 0" + | } + | }, + | "X": { + | "value": { + | "type": "angle_expr, bool", + | "latex_constants": true, + | "validate_with_core": false + | }, + | "style": { + | "label": { + | "position": "center", + | "fg_color": [ + | 0.0, + | 0.0, + | 0.0 + | ] + | }, + | "stroke_color": [ + | 0.0, + | 0.0, + | 0.0 + | ], + | "fill_color": [ + | 1.0, + | 0.0, + | 0.0 + | ], + | "shape": "circle" + | }, + | "default_data": { + | "type": "X", + | "value": "0, 0" + | } + | }, + | "hadamard": { + | "value": { + | "type": "empty", + | "latex_constants": true, + | "validate_with_core": false + | }, + | "style": { + | "label": { + | "position": "center", + | "fg_color": [ + | 0.0, + | 0.0, + | 0.0 + | ] + | }, + | "stroke_color": [ + | 0.0, + | 0.0, + | 0.0 + | ], + | "fill_color": [ + | 1.0, + | 1.0, + | 0.0 + | ], + | "shape": "rectangle" + | }, + | "default_data": { + | "type": "hadamard", + | "value": "" + | } + | }, + | "dummyBoundary": { + | "value": { + | "type": "empty", + | "latex_constants": true, + | "validate_with_core": false + | }, + | "style": { + | "label": { + | "position": "center", + | "fg_color": [ + | 0.0, + | 0.0, + | 0.0 + | ] + | }, + | "stroke_color": [ + | 0.0, + | 0.0, + | 0.0 + | ], + | "fill_color": [ + | 0.0, + | 1.0, + | 1.0 + | ], + | "shape": "rectangle" + | }, + | "default_data": { + | "type": "dummyBoundary", + | "value": "" + | } + | } + | }, + | "default_vertex_type": "Z", + | "default_edge_type": "plain", + | "edge_types": { + | "plain": { + | "value": { + | "type": "empty", + | "latex_constants": false, + | "validate_with_core": false + | }, + | "style": { + | "stroke_color": [ + | 0.0, + | 0.0, + | 0.0 + | ], + | "stroke_width": 1, + | "label": { + | "position": "auto", + | "fg_color": [ + | 0.0, + | 0.0, + | 0.0 + | ] + | } + | }, + | "default_data": { + | "type": "plain" + | } + | } + | } + |} + """.stripMargin) + var CR = new CoSyRuns.CoSyZXBool(duration = Duration.Inf, + numBoundaries = List(0, 1, 2), + outputDir = None, + scalars = false, + numVertices = 3, + rulesDir = new File("./cosy/"), theory = theory, + numAngles = 4) + + // Don't test this here + // It isn't a standard feature + // And pollutes the filesystem + // Ask hmillerbakewell@gmail.com for more information + + //CR.begin() + assert(1 == 1) + /* + val reduced = RuleSynthesis.minimiseRuleset(CR.reductionRules, theory, new Random(1)) + //reduced.foreach(r => FileHelper.printJson(s"./cosy/${r.lhs.hashCode}-${r.rhs.hashCode}.qrule", Rule.toJson(r, theory))) + reduced.foreach(r => + assert(interpret(r.lhs).isRoughly(interpret(r.rhs)))) + */ + + } + + behavior of "ZX Circuit" + + it should "do a small run" in { + val ZXRails = Theory.fromFile("ZXRails") + + // THIS WILL RUN UNTIL THE TIME RUNS OUT. + var CR = new CoSyRuns.CoSyCircuit(duration = Duration(4, "minutes"), numBoundaries = 3, + outputDir = Some(new File("./cosy_synth/")), + rulesDir = new File("./cosy_synth/"), theory = ZXRails) + //var rules = CR.begin() + /* + val reduced = RuleSynthesis.minimiseRuleset(CR.reductionRules, theory, new Random(1)) + reduced.foreach(r => FileHelper.printJson(s"./cosy/${r.lhs.hashCode}-${r.rhs.hashCode}.qrule", Rule.toJson(r, theory))) + */ + assert(1 == 1) + } + +} diff --git a/scala/src/test/scala/quanto/cosy/test/ColbournReadEnumSpec.scala b/scala/src/test/scala/quanto/cosy/test/ColbournReadEnumSpec.scala index 1d2ea1cf..7f78cf95 100644 --- a/scala/src/test/scala/quanto/cosy/test/ColbournReadEnumSpec.scala +++ b/scala/src/test/scala/quanto/cosy/test/ColbournReadEnumSpec.scala @@ -2,6 +2,8 @@ package quanto.cosy.test import org.scalatest._ import quanto.cosy._ +import quanto.data.{Graph, NodeV, Theory} +import quanto.util.json.JsonObject class ColbournReadEnumSpec extends FlatSpec { @@ -198,4 +200,28 @@ class ColbournReadEnumSpec extends FlatSpec { assert(numBi(6) === 1 + 1 + 1 + 1 + 3 + 5 + 17) assert(numBi(7) === 1 + 1 + 1 + 1 + 3 + 5 + 17 + 44) } + + behavior of "conversion to graph" + + it should "convert small amats into ZX graphs" in { + val rg = Theory.fromFile("red_green") + + val pi = math.Pi + val rdata = Vector( + NodeV(data = JsonObject("type" -> "X", "value" -> "0"), theory = rg), + NodeV(data = JsonObject("type" -> "X", "value" -> pi.toString), theory = rg), + NodeV(data = JsonObject("type" -> "X", "value" -> (0.5 * pi).toString), theory = rg), + NodeV(data = JsonObject("type" -> "X", "value" -> (-0.5 * pi).toString), theory = rg) + ) + val gdata = Vector( + NodeV(data = JsonObject("type" -> "Z", "value" -> "0"), theory = rg), + NodeV(data = JsonObject("type" -> "Z", "value" -> pi.toString), theory = rg), + NodeV(data = JsonObject("type" -> "Z", "value" -> (0.5 * pi).toString), theory = rg), + NodeV(data = JsonObject("type" -> "Z", "value" -> (-0.5 * pi).toString), theory = rg) + ) + + var one = Complex.one + def quickGraph(amat: AdjMat) : Graph = Graph.fromAdjMat(amat, rdata, gdata) + ColbournReadEnum.enumerate(2,2,2,2).map(quickGraph) + } } diff --git a/scala/src/test/scala/quanto/cosy/test/EQCAnalysisSpec.scala b/scala/src/test/scala/quanto/cosy/test/EQCAnalysisSpec.scala index e7e70076..69e6a623 100644 --- a/scala/src/test/scala/quanto/cosy/test/EQCAnalysisSpec.scala +++ b/scala/src/test/scala/quanto/cosy/test/EQCAnalysisSpec.scala @@ -68,7 +68,7 @@ class EQCAnalysisSpec extends FlatSpec { assert(!adjacencyMatrix._2(bIndex)(vIndex)) - val ghostedErrors = GraphAnalysis.bypassSpecial(GraphAnalysis.detectErrors)(targetGraph, adjacencyMatrix) + val ghostedErrors = GraphAnalysis.bypassSpecial(GraphAnalysis.detectPiNodes)(targetGraph, adjacencyMatrix) assert(ghostedErrors._2(bIndex)(eIndex)) assert(ghostedErrors._2(bIndex)(vIndex)) } diff --git a/scala/src/test/scala/quanto/cosy/test/EquivalenceClassesSpec.scala b/scala/src/test/scala/quanto/cosy/test/EquivalenceClassesSpec.scala index 7c3a2947..d4cca9a2 100644 --- a/scala/src/test/scala/quanto/cosy/test/EquivalenceClassesSpec.scala +++ b/scala/src/test/scala/quanto/cosy/test/EquivalenceClassesSpec.scala @@ -6,6 +6,7 @@ import java.nio.file.Paths import org.scalatest._ import quanto.cosy._ import quanto.data._ +import quanto.util.FileHelper import quanto.util.json.{JsonArray, JsonObject} import scala.util.parsing.json.JSON @@ -123,18 +124,6 @@ class EquivalenceClassesSpec extends FlatSpec { behavior of "IO" - it should "save run results to file" in { - var results = EquivClassRunAdjMat(numAngles = 4, - tolerance = EquivClassRunAdjMat.defaultTolerance, - rulesList = emptyRuleList, - theory = rg) - var testFile = new File("test_run_output.qrun") - quanto.util.FileHelper.printToFile(testFile, append = false)( - p => p.println(results.toJSON.toString()) - ) - assert(testFile.delete()) - } - it should "output to and input from JSON" in { var results = EquivClassRunAdjMat(numAngles = 4, tolerance = EquivClassRunAdjMat.defaultTolerance, @@ -161,26 +150,6 @@ class EquivalenceClassesSpec extends FlatSpec { behavior of "batch runner" - it should "create an output qrun file" in { - EquivClassBatchRunner(4, 2, 2, "test.qrun") - var testFile = new File(EquivClassBatchRunner.outputPath + "/" + "test.qrun") - assert(testFile.exists()) - assert(testFile.delete()) - } - - it should "allow outputs to home directory" in { - EquivClassBatchRunner.outputPath = Paths.get(System.getProperty("user.home"), "cosy_synth").toString - println(EquivClassBatchRunner.outputPath) - EquivClassBatchRunner.outputPath = "cosy_synth" // reset to avoid problems in later tests - } - - it should "create an output qtensor file" in { - TensorBatchRunner(1, 2, 2) - var testFile = new File(TensorBatchRunner.outputPath + "/" + "tensors-1-2-2.qtensor") - assert(testFile.exists()) - assert(testFile.delete()) - } - it should "be writing legible JSON" in { var results = EquivClassRunAdjMat(numAngles = 4, tolerance = EquivClassRunAdjMat.defaultTolerance, @@ -200,7 +169,7 @@ class EquivalenceClassesSpec extends FlatSpec { it should "put stacks into equivalence classes" in { var allowedStacks = BlockStackMaker(maxRows = 2, - BlockRowMaker(maxBlocks = 1, BlockRowMaker.ZX(8), maxInOut = Option(2))) + BlockRowMaker(maxBlocks = 1, BlockGenerators.ZXGates(8), maxInOut = Option(2))) var eqc = new EquivClassRunBlockStack() allowedStacks.foreach(s => eqc.add(s)) println(eqc.equivalenceClassesNormalised.foreach( @@ -212,7 +181,7 @@ class EquivalenceClassesSpec extends FlatSpec { it should "put stacks into equivalence classes" in { var allowedStacks = BlockStackMaker(maxRows = 2, - BlockRowMaker(maxBlocks = 2, BlockRowMaker.ZW, maxInOut = Option(2))) + BlockRowMaker(maxBlocks = 2, BlockGenerators.ZW, maxInOut = Option(2))) var eqc = new EquivClassRunBlockStack() allowedStacks.foreach(s => eqc.add(s)) println(eqc.equivalenceClassesNormalised.foreach( diff --git a/scala/src/test/scala/quanto/cosy/test/GraphAnalysisSpec.scala b/scala/src/test/scala/quanto/cosy/test/GraphAnalysisSpec.scala index c04fe5a9..f510b6dd 100644 --- a/scala/src/test/scala/quanto/cosy/test/GraphAnalysisSpec.scala +++ b/scala/src/test/scala/quanto/cosy/test/GraphAnalysisSpec.scala @@ -3,11 +3,15 @@ package quanto.cosy.test import java.io.File import org.scalatest.FlatSpec +import quanto.cosy.BlockGenerators.QuickGraph import quanto.cosy.RuleSynthesis._ +import quanto.cosy.GraphAnalysis._ import quanto.cosy._ import quanto.data._ import quanto.util.Rational +import scala.util.matching.Regex + /** * Test files for the Graph Analysis Object * It is used to calculate various graph properties @@ -25,13 +29,15 @@ class GraphAnalysisSpec extends FlatSpec { behavior of "Graph Analysis" + + private val errorGates = quanto.util.FileHelper.readFile[Graph]( + new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"), + Graph.fromJson(_, rg) + ) + it should "compute adjacency matrices" in { - val targetGraph = quanto.util.FileHelper.readFile[Graph]( - new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"), - Graph.fromJson(_, rg) - ) - val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(targetGraph) + val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(errorGates) // boundary b2, error v2, next gate vertex v3 val bIndex = adjacencyMatrix._1.indexOf(VName("b2")) val eIndex = adjacencyMatrix._1.indexOf(VName("v2")) @@ -40,18 +46,14 @@ class GraphAnalysisSpec extends FlatSpec { assert(!adjacencyMatrix._2(bIndex)(vIndex)) - val ghostedErrors = GraphAnalysis.bypassSpecial(GraphAnalysis.detectErrors)(targetGraph, adjacencyMatrix) + val ghostedErrors = GraphAnalysis.bypassSpecial(GraphAnalysis.detectPiNodes)(errorGates, adjacencyMatrix) assert(ghostedErrors._2(bIndex)(eIndex)) assert(ghostedErrors._2(bIndex)(vIndex)) } it should "calculate distance from ends" in { - val targetGraph = quanto.util.FileHelper.readFile[Graph]( - new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"), - Graph.fromJson(_, rg) - ) - val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(targetGraph) + val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(errorGates) val errorName = VName("v2") val leftBoundary = VName("b2") @@ -73,11 +75,7 @@ class GraphAnalysisSpec extends FlatSpec { it should "find neighbours" in { - val targetGraph = quanto.util.FileHelper.readFile[Graph]( - new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"), - Graph.fromJson(_, rg) - ) - val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(targetGraph) + val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(errorGates) val errorName = VName("v2") @@ -87,12 +85,8 @@ class GraphAnalysisSpec extends FlatSpec { it should "calculate distance of a given, ignored set from ends" in { - val targetGraph = quanto.util.FileHelper.readFile[Graph]( - new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"), - Graph.fromJson(_, rg) - ) - val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(targetGraph) - val ghostedErrors = GraphAnalysis.bypassSpecial(GraphAnalysis.detectErrors)(targetGraph, adjacencyMatrix) + val adjacencyMatrix = GraphAnalysis.adjacencyMatrix(errorGates) + val ghostedErrors = GraphAnalysis.bypassSpecial(GraphAnalysis.detectPiNodes)(errorGates, adjacencyMatrix) val errorName = VName("v2") val leftBoundary = VName("b2") @@ -119,9 +113,163 @@ class GraphAnalysisSpec extends FlatSpec { // Now check with the simproc methods - val eDistances = SimplificationProcedure.PullErrors.errorsDistance(rightBoundaries)(targetGraph, Set(errorName)) + val eDistances = SimplificationProcedure.PullErrors.errorsDistance(rightBoundaries)(errorGates, Set(errorName)) assert(eDistances.get == 2.0) } + behavior of "Graph comparison" + + it should "compare all these graphs" in { + def compare(a: Graph, b: Graph) = GraphAnalysis.zxGraphCompare(a, b) + + implicit def stackToGraph(s: BlockStack) : Graph = s.graph + implicit def rowToGraph(s: BlockRow) : Graph = s.graph + implicit def blockToGraph(s: Block) : Graph = s.graph + + def zx(s: String) : Graph = BlockGenerators.ZXClifford.filter(b => b.name == s).head + + val hadamard = zx(" H ") + val id = zx(" 1 ") + + assert(compare(hadamard, id) > 0) + } + + behavior of "Connectiviy analysis" + + it should "leave empty graph disconnected" in { + var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1) + amat = amat.addVertex(Vector()) + amat = amat.addVertex(Vector(false)) + amat = amat.addVertex(Vector(false, false)) + val cc = connectionClasses(amat) + assert(cc == (0 until 3).toVector) + } + + + it should "connect all in line graph" in { + var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1) + amat = amat.addVertex(Vector()) + amat = amat.addVertex(Vector(true)) + amat = amat.addVertex(Vector(false, true)) + val cc = connectionClasses(amat) + assert(cc == Vector(0,0,0)) + } + + + it should "find two classes" in { + var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1) + amat = amat.addVertex(Vector()) + amat = amat.addVertex(Vector(false)) + amat = amat.addVertex(Vector(true, false)) + val cc = connectionClasses(amat) + assert(cc == Vector(0,1,0)) + } + + it should "detect no scalars in all boundaries" in { + var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1) + amat = amat.addVertex(Vector()) + amat = amat.addVertex(Vector(false)) + amat = amat.addVertex(Vector(true, false)) + assert(!containsScalars(amat)) + } + + it should "detect no scalars with some non-boundaries" in { + var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1) + amat = amat.addVertex(Vector()) + amat = amat.addVertex(Vector(false)) + amat = amat.addVertex(Vector(true, false)) + amat = amat.nextType.get + amat = amat.addVertex(Vector(false, true, false)) + assert(!containsScalars(amat)) + } + + + it should "detect scalars with some non-boundaries" in { + var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1) + amat = amat.addVertex(Vector()) + amat = amat.addVertex(Vector(false)) + amat = amat.addVertex(Vector(true, false)) + amat = amat.nextType.get + amat = amat.addVertex(Vector(false, false, false)) + amat = amat.addVertex(Vector(false, false, false, true)) + assert(containsScalars(amat)) + } + + behavior of "circuit analysis" + + // Only works for circuits generated inside CoSy + // Note that this will not magically give you reductions - it is not graph-invariant by circuit-invariant + // Turning a graph upside down gives very a very different measure + + it should "distill circuit placement from name via regex" in { + val example = "r-2-bl-1-h-1" + val output = CircuitPlacementParser.p(example) + assert (output == (2,1,"h")) + } + + it should "bias circuits to the left" in { + val blocks: List[Block] = BlockGenerators.ZXGates(1) + val rows: List[BlockRow] = BlockRowMaker.makeRowsUpToSize(2, blocks, Some(2)) + val stacks = BlockStackMaker.makeStacksOfSize(1, rows) + val e1 = stacks.find(_.toString == "( 1 x 0Z1)").get + val e2 = stacks.find(_.toString == "(0Z1 x 1 )").get + assert(zxCircuitCompare(e1.graph, e2.graph) > 0) + } + + it should "bias circuits down" in { + val blocks: List[Block] = BlockGenerators.ZXGates(1) + val rows: List[BlockRow] = BlockRowMaker.makeRowsUpToSize(1, blocks, Some(1)) + val stacks = BlockStackMaker.makeStacksOfSize(2, rows) + val e1 = stacks.find(_.toString == "( 1 ) o (0Z1)").get + val e2 = stacks.find(_.toString == "(0Z1) o ( 1 )").get + assert(zxCircuitCompare(e1.graph, e2.graph) < 0) + } + + behavior of "Checking isomorphisms" + + // Boundaries have names like i-20 or o-1 + val boundaryRegex : Option[Regex] = Some(raw"""(i|o)-(\d+)""".r) + def ZXiso(left: Graph, right: Graph): Boolean = checkIsomorphic(rg, boundaryRegex)(left, right) + + it should "not match empty onto non-empty" in { + val g = QuickGraph(rg) + assert(!ZXiso(g, g.node("Z", "0"))) + } + + it should "match discrete graphs onto only themselves" in { + var g = QuickGraph(rg) + val graphs = (for (_ <- 1 to 5) yield { + g = g.node("Z", "0") + g + }).toList.zipWithIndex + for(i <- graphs; j <- graphs) { + // Check we have made alrge enough graphs; index is out by 1 + assert(i._1.verts.size == (i._2 + 1)) + if(i._2 == j._2){ + assert(ZXiso(i._1, j._1)) + } else { + assert(!ZXiso(i._1, j._1)) + } + } + } + + it should "not match different data" in { + val g = QuickGraph(rg) + assert(!ZXiso(g.node("Z","0"), g.node("Z", "1"))) + assert(!ZXiso(g.node("Z","0"), g.node("X", "0"))) + } + + it should "match outputs with the same name" in { + val g1 = QuickGraph(rg).node("Z", angle = "0", nodeName = "v").addInput().join("v", "i-0") + val g2 = QuickGraph(rg).node("Z", angle = "0", nodeName = "w").addInput().join("w", "i-0") + assert(ZXiso(g1,g2)) + } + + it should "not match boundaries with different names" in { + val g1 = QuickGraph(rg).node("Z", angle = "0", nodeName = "v").addInput().join("v", "i-0") + val g2 = QuickGraph(rg).node("Z", angle = "0", nodeName = "v").addOutput().join("v", "o-0") + assert(!ZXiso(g1,g2)) + } + } diff --git a/scala/src/test/scala/quanto/cosy/test/InterpreterSpec.scala b/scala/src/test/scala/quanto/cosy/test/InterpreterSpec.scala index 3dd95070..ddea9606 100644 --- a/scala/src/test/scala/quanto/cosy/test/InterpreterSpec.scala +++ b/scala/src/test/scala/quanto/cosy/test/InterpreterSpec.scala @@ -1,32 +1,186 @@ package quanto.cosy.test import org.scalatest.FlatSpec -import quanto.cosy.Interpreter.AngleMap -import quanto.cosy.{AdjMat, Complex, Interpreter, Tensor} -import quanto.data.{NodeV, Theory} +import quanto.cosy.BlockGenerators.QuickGraph +import quanto.cosy.Interpreter._ +import quanto.cosy._ +import quanto.data.Theory.ValueType +import quanto.data._ import quanto.util.json.JsonObject /** * Created by hector on 24/05/17. */ class InterpreterSpec extends FlatSpec { + + behavior of "Connecting graphs" + + + implicit def vname(str: String): VName = VName(str) + + implicit def vname2(strs: (String, String)): (VName, VName) = (vname(strs._1), vname(strs._2)) + + implicit def ename(str: String): EName = EName(str) + + + val cap : Tensor = Tensor(Array(Array[Complex](1, 0, 0, 1))) + + it should "make two id wires" in { + var g = new Graph(). + addVertex(vname("v0"), WireV()). + addVertex(vname("v1"), WireV()). + addVertex(vname("v2"), WireV()). + addVertex(vname("v3"), WireV()). + addEdge("e0", UndirEdge(), "v0" -> "v1"). + addEdge("e1", UndirEdge(), "v2" -> "v3") + val tensor = stringGraph(g, cap, List("v0", "v2"), List("v1", "v3")) + assert(tensor == Tensor.idWires(2)) + } + + it should "make two caps" in { + var g = new Graph(). + addVertex(vname("v0"), WireV()). + addVertex(vname("v1"), WireV()). + addVertex(vname("v2"), WireV()). + addVertex(vname("v3"), WireV()). + addEdge("e0", UndirEdge(), "v0" -> "v2"). + addEdge("e1", UndirEdge(), "v1" -> "v3") + val tensor = stringGraph(g, cap, List("v0", "v2", "v1","v3"), List()) + assert(tensor == (cap x cap)) + } + + + it should "make cap and wire" in { + var g = new Graph(). + addVertex(vname("v0"), WireV()). + addVertex(vname("v1"), WireV()). + addVertex(vname("v2"), WireV()). + addVertex(vname("v3"), WireV()). + addEdge("e0", UndirEdge(), "v0" -> "v3"). + addEdge("e1", UndirEdge(), "v1" -> "v2") + val tensor = stringGraph(g, cap, List("v1", "v2", "v0"), List("v3")) + assert(tensor == ( cap x Tensor.idWires(1))) + } + + + it should "make wire and cap" in { + var g = new Graph(). + addVertex(vname("v0"), WireV()). + addVertex(vname("v1"), WireV()). + addVertex(vname("v2"), WireV()). + addVertex(vname("v3"), WireV()). + addEdge("e0", UndirEdge(), "v0" -> "v3"). + addEdge("e1", UndirEdge(), "v1" -> "v2") + val tensor = stringGraph(g, cap, List("v0", "v1", "v2"), List("v3")) + assert(tensor == (Tensor.idWires(1) x cap)) + } + + + it should "make cup and wire" in { + var g = new Graph(). + addVertex(vname("v0"), WireV()). + addVertex(vname("v1"), WireV()). + addVertex(vname("v2"), WireV()). + addVertex(vname("v3"), WireV()). + addEdge("e0", UndirEdge(), "v0" -> "v3"). + addEdge("e1", UndirEdge(), "v1" -> "v2") + val tensor = stringGraph(g, cap, List("v0"), List("v1", "v2", "v3")) + assert(tensor == ( cap.transpose x Tensor.idWires(1))) + } + + + + it should "make wire and cup" in { + var g = new Graph(). + addVertex(vname("v0"), WireV()). + addVertex(vname("v1"), WireV()). + addVertex(vname("v2"), WireV()). + addVertex(vname("v3"), WireV()). + addEdge("e0", UndirEdge(), "v0" -> "v3"). + addEdge("e1", UndirEdge(), "v1" -> "v2") + val tensor = stringGraph(g, cap, List("v0"), List("v3", "v1", "v2")) + assert(tensor == (Tensor.idWires(1) x cap.transpose)) + } + + + + it should "make crossing" in { + var g = new Graph(). + addVertex(vname("v0"), WireV()). + addVertex(vname("v1"), WireV()). + addVertex(vname("v2"), WireV()). + addVertex(vname("v3"), WireV()). + addEdge("e0", UndirEdge(), "v0" -> "v3"). + addEdge("e1", UndirEdge(), "v1" -> "v2") + val tensor = stringGraph(g, cap, List("v0", "v1"), List("v2", "v3")) + assert(tensor == Tensor.swap(List(1,0))) + } + + + + it should "make cap and cup" in { + var g = new Graph(). + addVertex(vname("v0"), WireV()). + addVertex(vname("v1"), WireV()). + addVertex(vname("v2"), WireV()). + addVertex(vname("v3"), WireV()). + addEdge("e0", UndirEdge(), "v0" -> "v3"). + addEdge("e1", UndirEdge(), "v1" -> "v2") + val tensor = stringGraph(g, cap, List("v0", "v3"), List("v1", "v2")) + assert(tensor == Tensor(Array(Array(1,0,0,1),Array(0,0,0,0), Array(0,0,0,0), Array(1,0,0,1)))) + } + + + + it should "make (id x s) o (s x id)" in { + var g = QuickGraph.apply(BlockGenerators.ZXTheory).addInput(3).addOutput(3) + .join("i-0", "o-2") + .join("i-1", "o-0") + .join("i-2", "o-1") + val tensor = stringGraph(g, cap, List("i-0","i-1","i-2"), List("o-0","o-1","o-2")) + assert(tensor == Tensor.swap(List(2, 0, 1))) + } + + + it should "make (s x id) o (id x s)" in { + var g = QuickGraph.apply(BlockGenerators.ZXTheory).addInput(3).addOutput(3) + .join("i-0", "o-1") + .join("i-1", "o-2") + .join("i-2", "o-0") + val tensor = stringGraph(g, cap, List("i-0","i-1","i-2"), List("o-0","o-1","o-2")) + assert(tensor == Tensor.swap(List(1,2,0))) + } + behavior of "ZX" val rg = Theory.fromFile("red_green") val pi = math.Pi val rdata = Vector( NodeV(data = JsonObject("type" -> "X", "value" -> "0"), theory = rg), - NodeV(data = JsonObject("type" -> "X", "value" -> pi.toString), theory = rg), - NodeV(data = JsonObject("type" -> "X", "value" -> (0.5 * pi).toString), theory = rg), - NodeV(data = JsonObject("type" -> "X", "value" -> (-0.5 * pi).toString), theory = rg) + NodeV(data = JsonObject("type" -> "X", "value" -> "1"), theory = rg), + NodeV(data = JsonObject("type" -> "X", "value" -> "1/2"), theory = rg), + NodeV(data = JsonObject("type" -> "X", "value" -> "-1/2"), theory = rg) ) val gdata = Vector( NodeV(data = JsonObject("type" -> "Z", "value" -> "0"), theory = rg), - NodeV(data = JsonObject("type" -> "Z", "value" -> pi.toString), theory = rg), - NodeV(data = JsonObject("type" -> "Z", "value" -> (0.5 * pi).toString), theory = rg), - NodeV(data = JsonObject("type" -> "Z", "value" -> (-0.5 * pi).toString), theory = rg) + NodeV(data = JsonObject("type" -> "Z", "value" -> "1"), theory = rg), + NodeV(data = JsonObject("type" -> "Z", "value" -> "1/2"), theory = rg), + NodeV(data = JsonObject("type" -> "Z", "value" -> "-1/2"), theory = rg) ) + //Change the last number for larger tests + //Don't include boundaries as the methods can give permutations of each other's answers + val smallAdjMats: Stream[AdjMat] = ColbournReadEnum.enumerate(2, 2, 2, 2) + + implicit def quickGraph(amat: AdjMat): Graph = Graph.fromAdjMat(amat, rdata, gdata) + + implicit def stringToPhase(s: String): PhaseExpression = { + PhaseExpression.parse(s, ValueType.AngleExpr) + } + var one = Complex.one + var zero = Complex.zero + + def amatToZXTensor(adjMat: AdjMat) = Interpreter.interpretZXAdjMat(adjMat, rdata, gdata) it should "make hadamards" in { // Via "new" @@ -40,39 +194,39 @@ class InterpreterSpec extends FlatSpec { } it should "make Green Spiders" in { - def gs(angle: Double, in: Int, out: Int) = Interpreter.interpretZXSpider(true, angle, in, out) + def gs(angle: String, in: Int, out: Int) = zxSpider(true, angle, in, out) - var g12a0 = Interpreter.interpretZXSpider(true, 0, 1, 2) - var g21aPi = Interpreter.interpretZXSpider(true, math.Pi, 2, 1) - var g11aPi2 = Interpreter.interpretZXSpider(true, math.Pi / 2.0, 1, 1) - var t = gs(math.Pi / 4, 1, 1) - var g2 = gs(math.Pi / 2, 1, 1) + var g12a0 = zxSpider(true, "0", 1, 2) + var g21aPi = zxSpider(true, "pi", 2, 1) + var g11aPi2 = zxSpider(true, "pi/2", 1, 1) + var t = gs("pi/4", 1, 1) + var g2 = gs("pi/2", 1, 1) assert((t o t).isRoughly(g2)) - assert((g2 o g2).isRoughly(gs(math.Pi, 1, 1))) + assert((g2 o g2).isRoughly(gs("pi", 1, 1))) } it should "make red spiders" in { - def rs(angle: Double, in: Int, out: Int) = Interpreter.interpretZXSpider(false, angle, in, out) + def rs(angle: String, in: Int, out: Int) = zxSpider(false, angle, in, out) - assert(rs(math.Pi, 1, 1).isRoughly(Tensor(Array(Array(0, 1), Array(1, 0))))) - var rp2 = rs(math.Pi / 2, 1, 1) - assert((rp2 o rp2).isRoughly(rs(math.Pi, 1, 1))) + assert(rs("pi", 1, 1).isRoughlyUpToScalar(Tensor(Array(Array(0, 1), Array(1, 0))))) + var rp2 = rs("pi/2", 1, 1) + assert((rp2 o rp2).isRoughly(rs("pi", 1, 1))) assert((rp2 o rp2 o rp2 o rp2).isRoughly(Tensor.id(2))) - assert(Interpreter.cached.contains("ZX:false:3.141592653589793:1:1")) + assert(Interpreter.cached.contains("ZX:red:1:1:1")) } it should "respect spider law" in { - var g1 = Interpreter.interpretZXSpider(true, math.Pi / 8, 1, 2) - var g2 = Interpreter.interpretZXSpider(true, math.Pi / 8, 1, 1) - var g3 = Interpreter.interpretZXSpider(true, math.Pi / 4, 1, 2) + var g1 = zxSpider(true, "pi/8", 1, 2) + var g2 = zxSpider(true, "pi/8", 1, 1) + var g3 = zxSpider(true, "pi/4", 1, 2) assert(((Tensor.id(2) x g2) o g1).isRoughly(g3)) } it should "apply Hadamards" in { var h1 = Tensor.hadamard var h2 = h1 x h1 - var g = Interpreter.interpretZXSpider(true, math.Pi / 4, 1, 2) - var r = Interpreter.interpretZXSpider(false, math.Pi / 4, 1, 2) + var g = zxSpider(true, "pi/4", 1, 2) + var r = zxSpider(false, "pi/4", 1, 2) assert((h2 o r o h1).isRoughly(g)) } @@ -83,7 +237,9 @@ class InterpreterSpec extends FlatSpec { amat = amat.nextType.get amat = amat.addVertex(Vector(true, true)) println(amat) - val i1 = Interpreter.interpretZXAdjMat(amat, redAM = rdata, greenAM = gdata) + val i1 = amatToZXTensor(amat) + val i2 = amatToGraphToZXTensor(amat) + assert(i1.isRoughly(i2)) assert(i1.isRoughly(Tensor(Array(Array(1, 0, 0, 1))))) } @@ -97,11 +253,23 @@ class InterpreterSpec extends FlatSpec { amat = amat.nextType.get // Red pi amat = amat.addVertex(Vector(true, true)) + amat = amat.nextType.get + // Green 0 + amat = amat.nextType.get + // Green pi println(amat) - val i1 = Interpreter.interpretZXAdjMat(amat, redAM = rdata, greenAM = gdata) - assert(i1.isRoughly(Interpreter.interpretZXSpider(green = false, math.Pi, 2, 0))) + val i1 = amatToZXTensor(amat) + val i2 = amatToGraphToZXTensor(amat) + assert(i1.isRoughly(i2)) + assert(i1.isRoughly(zxSpider(false, "pi", 2, 0))) } - var zero = Complex.zero + + def amatToGraphToZXTensor(adjMat: AdjMat) = { + val asGraph = quickGraph(adjMat) + Interpreter.interpretZXGraph(asGraph, + asGraph.verts.filter(asGraph.isTerminalWire).toList.sortBy(vn => vn.s), List()) + } + it should "process red spider law" in { // Simple red and green identities var amat = new AdjMat(numRedTypes = 4, numGreenTypes = 4) @@ -117,9 +285,13 @@ class InterpreterSpec extends FlatSpec { amat = amat.addVertex(Vector(false, true, true)) amat = amat.nextType.get //red -pi/2 - val i1 = Interpreter.interpretZXAdjMat(amat, redAM = rdata, greenAM = gdata) - assert(i1.isRoughly(Interpreter.interpretZXSpider(false, -pi / 2, 2, 0))) + amat = amat.nextType.get + val i1 = amatToZXTensor(amat) + val i2 = amatToGraphToZXTensor(amat) + assert(i1.isRoughly(i2)) + assert(i1.isRoughly(zxSpider(false, "-pi / 2", 2, 0))) } + it should "satisfy the Euler identity" in { // Euler identity var amat = new AdjMat(numRedTypes = 4, numGreenTypes = 4) @@ -137,55 +309,55 @@ class InterpreterSpec extends FlatSpec { amat = amat.nextType.get amat = amat.nextType.get amat = amat.addVertex(Vector(false, false, true, true)) + amat = amat.nextType.get println(amat) - val i2 = Interpreter.interpretZXAdjMat(amat, redAM = rdata, greenAM = gdata) - val i3 = Interpreter.interpretZXSpider(true, 0, 2, 0) o (Tensor.id(2) x Tensor.hadamard) - println(i3.scaled(i2.contents(0)(0) / i3.contents(0)(0))) - assert(i2.isRoughly(i3.scaled(i2.contents(0)(0) / i3.contents(0)(0)))) - assert(i2.isRoughlyUpToScalar(Tensor(Array(Array(1, 1, 1, -1))))) + val i1 = Interpreter.interpretZXAdjMat(amat, redAM = rdata, greenAM = gdata) + val i2 = amatToGraphToZXTensor(amat) + assert(i2.isRoughly(i2)) + val i3 = zxSpider(true, "0", 2, 0) o (Tensor.id(2) x Tensor.hadamard) + println(i3.scaled(i1.contents(0)(0) / i3.contents(0)(0))) + assert(i1.isRoughly(i3.scaled(i1.contents(0)(0) / i3.contents(0)(0)))) + assert(i1.isRoughlyUpToScalar(Tensor(Array(Array(1, 1, 1, -1))))) } behavior of "ZW" it should "Evaluate the four-output GHZ spider" in { - var t = Interpreter.interpretZWSpider(black = true, 4) - assert(t == Tensor(Array(Array(1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0))).transpose) + val t1 = Interpreter.interpretZWSpiderNoInputs(black = true, 4) + val g = QuickGraph(Theory.fromFile("ZW")).addOutput(4).node("w", nodeName = "w") + .join("w", "o-0") + .join("w", "o-1") + .join("w", "o-2") + .join("w", "o-3") + + val t2 = Interpreter.interpretSpiderGraph(zwSpiderInterpreter)(g, List(), List("o-0", "o-1", "o-2", "o-3")) + val t3 = Tensor(Array(Array(1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0))).transpose + assert(t1 == t3) + assert(t2 == t3) } - /* Too intensive! - it should "agree on rule 5d" in { - var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1) - amat = amat.addVertex(Vector()) - amat = amat.addVertex(Vector(false)) - amat = amat.nextType.get - // Blacks (reds) - amat = amat.addVertex(Vector(true, false)) - amat = amat.addVertex(Vector(false, false, true)) - amat = amat.addVertex(Vector(false, false, false, true)) - amat = amat.addVertex(Vector(false, true, false, false, true)) - amat = amat.nextType.get - // Whites (greens) - amat = amat.addVertex(Vector(false, false, false, true, true, false)) + it should "make w spiders" in { + val g = QuickGraph(Theory.fromFile("ZW")).addInput().addOutput().node("w", nodeName = "w") + .join("i-0", "w").join("w", "o-0") - var lhs = amat.copy() - amat = new AdjMat(numBoundaries = 2, numRedTypes = 1, numGreenTypes = 1) - amat = amat.addVertex(Vector()) - amat = amat.addVertex(Vector(false)) - amat = amat.nextType.get - // Blacks (reds) - amat = amat.addVertex(Vector(true, false)) - amat = amat.addVertex(Vector(false, false, true)) - amat = amat.addVertex(Vector(false, false, false, false)) - amat = amat.addVertex(Vector(false, true, false, false, true)) + val t1 = Interpreter.interpretSpiderGraph(zwSpiderInterpreter)(g, List("i-0"), List("o-0")) + assert(t1.isRoughly(Tensor(Array(Array(0, 1), Array(1, 0))))) + + + val t2 = Interpreter.interpretSpiderGraph(zwSpiderInterpreter)(g, List("i-0", "o-0"), List()) + assert(t2.isRoughly(Tensor(Array(Array(0, 1, 1, 0))))) - var t1 = Interpreter.interpretZWAdjMat(lhs) - var t2 = Interpreter.interpretZWAdjMat(amat) - println(t1) + + val g3 = QuickGraph(Theory.fromFile("ZW")).addInput().addOutput(2).node("w", nodeName = "w") + .join("i-0", "w").join("w", "o-0").join("w", "o-1") + + + val t3 = Interpreter.interpretSpiderGraph(zwSpiderInterpreter)(g3, List("i-0"), List("o-0", "o-1")) + assert(t3.isRoughly(Tensor(Array(Array(0, 1, 1, 0), Array(1, 0, 0, 0))).transpose)) } - */ - it should "agree on rule 3a" in { + it should "agree on rule nat_c^n" in { var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1) amat = amat.addVertex(Vector()) @@ -214,23 +386,155 @@ class InterpreterSpec extends FlatSpec { amat = amat.addVertex(Vector(false, false, false, true, true)) amat = amat.addVertex(Vector(true, false, false, false, false, true)) - var t1 = Interpreter.interpretZWAdjMat(rhs) - var t2 = Interpreter.interpretZWAdjMat(amat) - assert(t1 == t2) + var t1 = Interpreter.interpretZWAdjMat(rhs, List("v0"), List("v1", "v2")) + var t2 = Interpreter.interpretZWAdjMat(amat, List("v0"), List("v1", "v2")) + assert(t1.isRoughly(t2)) } - it should "agree on rule 5c" in { + it should "agree on rule inv" in { + + val g = QuickGraph(Theory.fromFile("ZW")).addInput().addOutput().node("w",nodeName = "w1").node("w",nodeName = "w2") + .join("i-0", "w1").join("w1","w2").join("w2", "o-0") + + + var t1 = Interpreter.interpretSpiderGraph(zwSpiderInterpreter)(g, List("i-0"),List("o-0")) + assert(t1.isRoughly(Tensor.idWires(1))) + } + + behavior of "block stacks and graph interpreters" + + def zxSpider(isGreen: Boolean, angle: String, inputs: Int, outputs: Int): Tensor = + Interpreter.interpretZXSpider(ZXAngleData(isGreen, angle), inputs, outputs) + + it should "interpret Z pi/4" in { + + val cnotGraph = QuickGraph(BlockGenerators.ZXTheory).addInput(1).addOutput(1).node("Z", nodeName = "z", angle = "1/4") + .join("i-0", "z") + .join("o-0", "z") + + val gt = Interpreter.interpretZXGraph(cnotGraph, List("i-0"), List("o-0")) + assert(gt.isRoughly(zxSpider(true, "1/4", 1, 1))) + } + + + it should "interpret X pi/4" in { + + val cnotGraph = QuickGraph(BlockGenerators.ZXTheory).addInput(1).addOutput(1).node("X", nodeName = "x", angle = "1/4") + .join("i-0", "x") + .join("o-0", "x") + + val gt = Interpreter.interpretZXGraph(cnotGraph, List("i-0"), List("o-0")) + assert(gt.isRoughly(zxSpider(false, "1/4", 1, 1))) + } - var amat = new AdjMat(numRedTypes = 1, numGreenTypes = 1) - amat = amat.nextType.get - // Blacks (reds) - amat = amat.addVertex(Vector()) - amat = amat.addVertex(Vector(true)) - amat = amat.addVertex(Vector(false, true)) - amat = amat.addVertex(Vector(false, false, true)) - amat = amat.nextType.get - var t1 = Interpreter.interpretZWAdjMat(amat) - assert(t1 == Tensor.id(1)) + it should "interpret X pi" in { + + val cnotGraph = QuickGraph(BlockGenerators.ZXTheory).addInput(1).addOutput(1).node("X", nodeName = "x", angle = "1") + .join("i-0", "x") + .join("o-0", "x") + + val gt = Interpreter.interpretZXGraph(cnotGraph, List("i-0"), List("o-0")) + assert(gt.isRoughly(zxSpider(false, "1", 1, 1))) + } + + + it should "interpret CNOT" in { + val cnotGraph = QuickGraph(BlockGenerators.ZXTheory).addInput(2).addOutput(2).node("Z", nodeName = "z").node("X", xCoord = 1, nodeName = "x") + .join("i-0", "z").join("i-1", "x") + .join("o-0", "z").join("o-1", "x") + .join("z", "x") + + + assert(BlockGenerators.ZXClifford.filter(_.name == "CNT").head.tensor.isRoughlyUpToScalar( + Tensor(Array(Array(1, 0, 0, 0), Array(0, 1, 0, 0), Array(0, 0, 0, 1), Array(0, 0, 1, 0))))) + + assert(Interpreter.interpretZXGraph(cnotGraph, List("i-0", "i-1"), List("o-0", "o-1")).isRoughlyUpToScalar( + Tensor(Array(Array(1, 0, 0, 0), Array(0, 1, 0, 0), Array(0, 0, 0, 1), Array(0, 0, 1, 0))))) + } + + it should "interpret CNOT x id" in { + val cnotGraph = QuickGraph(BlockGenerators.ZXTheory).addInput(3).addOutput(3).node("Z", nodeName = "z").node("X", xCoord = 1, nodeName = "x") + .join("i-0", "z").join("i-1", "x") + .join("o-0", "z").join("o-1", "x") + .join("z", "x") + .join("i-2", "o-2") + + val height = 1 + val width = 3 + val rows = BlockRowMaker.makeRowsUpToSize(width, BlockGenerators.ZXCNOT, Some(width)) + val stacks = BlockStackMaker.makeStacksOfSize(height, rows) + var gate = stacks.filter(_.toString == "(CNT x 1 )") + + val breakdown = Tensor.swap(List(2,1,0)) o + (Tensor.idWires(2) x Tensor(Array(Array(1,0,0,0), Array(0,0,0,1)))) o + Tensor.swap(List(2,0,1,3)) o + (Tensor.idWires(2) x Tensor(Array(Array(1,0,0,1), Array(0,1,1,0))).transpose) o + Tensor.swap(List(0,2,1)) + + + val breakdown2 = Tensor(Array(Array(1,0,0,0), Array(0,1,0,0), Array(0,0,0,1), Array(0,0,1,0))) x Tensor.idWires(1) + + assert(breakdown.isRoughlyUpToScalar(breakdown2)) + + + assert(Interpreter.interpretZXGraph(cnotGraph, List("i-0", "i-1","i-2"), List("o-0", "o-1","o-2")).isRoughlyUpToScalar( + breakdown + )) + + + assert(Interpreter.interpretZXGraph(cnotGraph, List("i-0", "i-1","i-2"), List("o-0", "o-1","o-2")).isRoughlyUpToScalar( + gate.head.tensor + )) + + } + + + it should "interpret green strings" in { + val cnotGraph = QuickGraph(BlockGenerators.ZXTheory).addInput(3).addOutput(3) + .node("Z", nodeName = "z2") + .node("Z", nodeName = "z1") + .node("Z", nodeName = "z0") + .join("i-0", "z0").join("z0", "o-0") + .join("i-1", "z1").join("z1", "o-1") + .join("i-2", "z2").join("z2", "o-2") + + assert(Interpreter.interpretZXGraph(cnotGraph, List("i-0", "i-1", "i-2"), List("o-0", "o-1", "o-2")).isRoughlyUpToScalar( + Tensor.idWires(3))) + + } + + + it should "agree on small adjmats" in { + val errors: List[AdjMat] = smallAdjMats.filterNot(adj => { + val am = amatToZXTensor(adj) + val gr = amatToGraphToZXTensor(adj) + am.isRoughly(gr) + }).toList + assert(errors.isEmpty) + } + + it should "agree between block stacks and spiders" in { + // Will only work up to width 10 unless you change how boundaries are sorted + val height = 1 + val width = 3 + val rows = BlockRowMaker.makeRowsUpToSize(width, BlockGenerators.ZXClifford.filterNot(b => b.toString == " H "), Some(width)) + val stacks = BlockStackMaker.makeStacksOfSize(height, rows) + stacks.foreach(bs => { + val asGraph = bs.graph + val tensorFromGraph = Interpreter.interpretZXGraph(asGraph, + asGraph.verts.filter(_.s.matches(raw"r-0-i-\d+")).toList.sortBy(vn => vn.s), + asGraph.verts.filter(_.s.matches(raw"r-"+(height-1)+raw"-o-\d+")).toList.sortBy(vn => vn.s)) + assert(bs.tensor.isRoughlyUpToScalar(tensorFromGraph)) + val asGraph2 = bs.graph + val tensorFromGraph2 = Interpreter.interpretZXGraph(asGraph, + asGraph.verts.filter(_.s.matches(raw"r-0-i-\d+")).toList.sortBy(vn => vn.s), + asGraph.verts.filter(_.s.matches(raw"r-"+(height-1)+raw"-o-\d+")).toList.sortBy(vn => vn.s)) + assert(bs.tensor.isRoughlyUpToScalar(tensorFromGraph)) + } + ) + } + + } diff --git a/scala/src/test/scala/quanto/cosy/test/RuleSynthesisSpec.scala b/scala/src/test/scala/quanto/cosy/test/RuleSynthesisSpec.scala index eae9e6ec..072166f6 100644 --- a/scala/src/test/scala/quanto/cosy/test/RuleSynthesisSpec.scala +++ b/scala/src/test/scala/quanto/cosy/test/RuleSynthesisSpec.scala @@ -6,10 +6,11 @@ import quanto.cosy._ import org.scalatest.FlatSpec import quanto.data._ import quanto.rewrite.{Matcher, Rewriter} -import quanto.data.Derivation.DerivationWithHead import quanto.cosy.RuleSynthesis._ import quanto.cosy.AutoReduce._ +import quanto.cosy.BlockGenerators.QuickGraph import quanto.data +import quanto.util.json.JsonObject import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration.Duration @@ -81,7 +82,8 @@ class RuleSynthesisSpec extends FlatSpec { ) var r1 = ruleList.head var m = Matcher.findMatches(r1.lhs, r1.lhs) - var shrunkRules = RuleSynthesis.discardDirectlyReducibleRules(rules = ruleList, rg, seed = new Random(1)) + var shrunkRules = RuleSynthesis.discardDirectlyReducibleRules( + comparison = basicGraphComparison, rules = ruleList, seed = new Random(1)) println(shrunkRules) assert(ruleList.length > shrunkRules.length) } @@ -90,7 +92,7 @@ class RuleSynthesisSpec extends FlatSpec { it should "find small rules" in { var results = EquivClassRunBlockStack(1e-14) - var rowsAllowed = BlockRowMaker(2, BlockRowMaker.Bian2Qubit, maxInOut = Option(2)) + var rowsAllowed = BlockRowMaker(2, BlockGenerators.Bian2Qubit, maxInOut = Option(2)) var stacks = BlockStackMaker(2, rowsAllowed) stacks.foreach(s => results.add(s)) results.equivalenceClassesNormalised @@ -102,7 +104,7 @@ class RuleSynthesisSpec extends FlatSpec { it should "find small rules" in { var results = EquivClassRunBlockStack(1e-14) - var rowsAllowed = BlockRowMaker(1, BlockRowMaker.ZXQutrit(3), maxInOut = Option(2)) + var rowsAllowed = BlockRowMaker(1, BlockGenerators.ZXQutrit(3), maxInOut = Option(2)) var stacks = BlockStackMaker(2, rowsAllowed) stacks.foreach(s => results.add(s)) results.equivalenceClassesNormalised @@ -115,7 +117,7 @@ class RuleSynthesisSpec extends FlatSpec { it should "find small rules" in { var results = EquivClassRunBlockStack(1e-14) - var rowsAllowed = BlockRowMaker(1, BlockRowMaker.ZXQudit(4, 2), maxInOut = Option(2)) + var rowsAllowed = BlockRowMaker(1, BlockGenerators.ZXQudit(4, 2), maxInOut = Option(2)) var stacks = BlockStackMaker(2, rowsAllowed) stacks.foreach(s => results.add(s)) results.equivalenceClasses @@ -130,7 +132,7 @@ class RuleSynthesisSpec extends FlatSpec { // Pick out S1, S2 and REDUCIBLE var smallRules = ctRules.filter(_.name.matches(raw"S\d|RED.*")) var reducibleGraph = smallRules.filter(_.name.matches(raw"RED.*")).head.lhs - var resultingDerivation = greedyReduce(RuleSynthesis.graphToDerivation(reducibleGraph, rg), smallRules) + var resultingDerivation = greedyReduce(basicGraphComparison, graphToDerivation(reducibleGraph), smallRules) // println(resultingDerivation.stepsTo(resultingDerivation.firstHead)) assert(Derivation.derivationHeadPairToGraph(resultingDerivation).verts.size < reducibleGraph.verts.size) } @@ -138,10 +140,13 @@ class RuleSynthesisSpec extends FlatSpec { it should "automatically reduce" in { var ctRules = ZXRules // Pick out S1, S2 and REDUCIBLE - var smallRules = ctRules.filter(_.name.matches(raw"S\d.*")) - var minimisedRules = RuleSynthesis.minimiseRuleset(smallRules ::: smallRules.map(_.inverse), rg) + def compare(left: Graph, right: Graph): Int = GraphAnalysis.zxGraphCompare(left, right) + var smallRules : List[Rule] = + ctRules.filter(_.name.matches(raw"REDUCIBLE")) ::: + ctRules.filter(_.name.matches(raw"S[12]")) + var minimisedRules = RuleSynthesis.greedyReduceRules(compare)(smallRules) minimisedRules.foreach(println) - assert(minimisedRules.exists(_.name.matches(raw".*reduced.*"))) + assert(minimisedRules.head.lhs.verts.size == 2) } it should "make a long derivation from annealing" in { @@ -149,21 +154,152 @@ class RuleSynthesisSpec extends FlatSpec { var target = ctRules.filter(_.name.matches(raw"RED.*")).head.lhs var remaining = ctRules.filterNot(_.name.matches(raw"RED.*")) var annealed = annealingReduce( - RuleSynthesis.graphToDerivation(target, rg), + GraphAnalysis.zxGraphCompare, + graphToDerivation(target), remaining ::: remaining.map(_.inverse), 100, 3, new Random(3), None) assert(annealed._1.steps.size > target.verts.size) + assert(quanto.rewrite.Simproc.fromDerivationWithHead(annealed).hasNext) } it should "randomly apply appropriate rules" in { var ctRules = ZXRules var target = ctRules.filter(_.name.matches(raw"RED.*")).head.lhs var remaining = ctRules.filter(_.name.matches(raw"S\d+.*")) - val reducedDerivation = randomApply((new Derivation(rg, target), None), + val reducedDerivation = randomApply((new Derivation(target), None), remaining, 100, alwaysTrue, new Random(1)) assert(reducedDerivation._1.steps(reducedDerivation._2.get).graph < target) } + behavior of "reducing rulesets" + + it should "greedily reduce one of these rules" in { + // Starting with rules a-> b and b -> c + // End with a -> c and b -> c + def reduce(listRules : List[Rule]): List[Rule] = greedyReduceRules(GraphAnalysis.zxGraphCompare)(listRules) + + val theory = BlockGenerators.ZXTheory + val Z = NodeV(data = JsonObject("type" -> "Z"), theory = theory) + val r1 = new Rule( + QuickGraph(theory).addVertex(VName("v1"), Z) + .addVertex(VName("v2"), Z) + .addEdge(EName("e1"), UndirEdge(), (VName("v1"), VName("v2"))), + QuickGraph(theory).addVertex(VName("v1"), Z)) + var r2 = new Rule( + QuickGraph(theory).addVertex(VName("v1"), Z), + QuickGraph(theory) + ) + val start: List[Rule] = List(r1, r2) + val end : List[Rule] = reduce(start) + assert(start.toSet.intersect(end.toSet).size == 1) + + // Check that it only runs rules forwards: + val startFlipped = start.map(_.inverse) + val endFlipped = reduce(startFlipped) + assert(startFlipped.toSet.intersect(endFlipped.toSet).size == 2) + // Two identical rules should not annihilate each other + // (One should be reduced, the other left as-is) + val duplicate = List(r1,r1.copy(description = "Different")) + val duplicateReduced = reduce(duplicate) + assert(duplicateReduced.toSet.intersect(duplicate.toSet).size == 1) + } + + behavior of "colour swapping" + + it should "Send X to Z" in { + val theory = rg + val g = QuickGraph(theory).addInput(). + node(nodeType = "Z",nodeName = "z"). + node(nodeType = "X",nodeName = "x"). + node(nodeType = "hadamard",nodeName = "h"). + join("x", "z") + val r = new Rule(g,g) + val r2 = r.colourSwap(Map("Z" -> "X", "X" -> "Z")) + assert(r2.lhs.vdata(VName("z")).typ == "X") + assert(r2.lhs.vdata(VName("x")).typ == "Z") + assert(r2.lhs.vdata(VName("h")).typ == "hadamard") + } + + behavior of "extending rules" + + it should "not try and extend the following rules" in { + val theory = rg + val r1l = QuickGraph(theory).addInput().node(nodeType = "Z",nodeName = "v").join("v","i-0") + val r1r = QuickGraph(theory).addInput().node(nodeType = "X",nodeName = "v").join("v","i-0") + val r1 = Rule(r1l, r1r) + assert(extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex) == r1) + } + + it should "try and extend the following" in { + val theory = rg + val Z = NodeV(data = JsonObject("type" -> "Z"), theory = theory) + val X = NodeV(data = JsonObject("type" -> "X"), theory = theory) + val r1l = QuickGraph(theory).addInput().node(nodeType = "Z",nodeName = "v").join("v","i-0") + val r1 = Rule(r1l, r1l) + val extended = extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex) + assert(extended != r1) + assert(extended.hasBBoxes) + assert(extended.lhs.bboxesContaining(VName("i-0")).nonEmpty) + } + + it should "satisfy one-input, one-node expansion" in { + val g = QuickGraph(rg) + val g1 = g.addInput().node(nodeType = "Z", nodeName = "z").join("i-0","z") + val r1 = Rule(g1,g1) + val ext = extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex) + assert(ext.lhs.bboxesContaining(VName("i-0")).nonEmpty) + assert(ext.lhs.bboxesContaining(VName("z")).isEmpty) + } + + it should "satisfy two-input, two-node expansion" in { + val g = QuickGraph(rg) + val g1 = g.addInput(2). + node(nodeType = "Z", nodeName = "z").join("i-0","z"). + node(nodeType = "X", nodeName = "x").join("i-1","x"). + join("z","x") + val g2 = g.addInput(2). + node(nodeType = "Z", nodeName = "z2").join("i-0","z2"). + node(nodeType = "X", nodeName = "x2").join("i-1","x2"). + join("z2", "x2") + val r1 = Rule(g1,g2) + val ext = extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex) + assert(ext.lhs.bboxesContaining(VName("i-0")).nonEmpty) + assert(ext.lhs.bboxesContaining(VName("i-1")).nonEmpty) + assert(ext.lhs.bboxesContaining(VName("z")).isEmpty) + assert(ext.lhs.bboxesContaining(VName("x")).isEmpty) + } + + + it should "satisfy two-input, one-node expansion" in { + val g = QuickGraph(rg) + val g1 = g.addInput(2). + node(nodeType = "Z", nodeName = "z"). + join("i-0","z"). + join("i-1","z") + val r1 = Rule(g1,g1) + val ext = extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex) + assert(ext.lhs.adjacentNodesAndBoundaries(VName("z")).size == 1) + assert(ext.lhs.bboxesContaining(VName("z")).isEmpty) + } + + + it should "satisfy one-input, two-node expansion" in { + val g = QuickGraph(rg) + val g1 = g.addInput().node(nodeType = "Z", angle = "pi", nodeName = "z").join("i-0","z"). + node(nodeType = "X", nodeName = "x").join("z","x") + val r1 = Rule(g1,g1) + val ext = extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex) + assert(ext.lhs.bboxesContaining(VName ("i-0")).nonEmpty) + } + + it should "not match different colours, one-input, one-node" in { + val g = QuickGraph(rg) + val g1 = g.addInput().node(nodeType = "Z", nodeName = "v").join("i-0","v") + val g2 = g.addInput().node(nodeType = "X", nodeName = "v").join("i-0","v") + val r1 = Rule(g1,g2) + assert(extendMatchingSpidersWithBBoxes(r1, QuickGraph.boundaryRegex) == r1) + } + } \ No newline at end of file diff --git a/scala/src/test/scala/quanto/cosy/test/SimplificationProcedureSpec.scala b/scala/src/test/scala/quanto/cosy/test/SimplificationProcedureSpec.scala index 1cc9b9ac..d559bb60 100644 --- a/scala/src/test/scala/quanto/cosy/test/SimplificationProcedureSpec.scala +++ b/scala/src/test/scala/quanto/cosy/test/SimplificationProcedureSpec.scala @@ -43,7 +43,7 @@ class SimplificationProcedureSpec extends FlatSpec { 3, None) val simplificationProcedure = new SimplificationProcedure[State]( - (new Derivation(rg, target), None), + (new Derivation(target), None), initialState, step, progress, @@ -77,7 +77,7 @@ class SimplificationProcedureSpec extends FlatSpec { val t1 = raw"\beta" val targetString: String = t1.replaceAll(raw"\\", raw"\\\\") val replacementString: String = raw"\pi".replaceAll(raw"\\", raw"\\\\") - val initialDerivation = graphToDerivation(target, rg) + val initialDerivation = graphToDerivation(target) import SimplificationProcedure.Evaluation._ if (targetString.length > 0) { @@ -125,7 +125,7 @@ class SimplificationProcedureSpec extends FlatSpec { new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"), Graph.fromJson(_, rg) ) - val initialDerivation: DerivationWithHead = graphToDerivation(targetGraph, rg) + val initialDerivation: DerivationWithHead = graphToDerivation(targetGraph) val graph = Derivation.derivationHeadPairToGraph(initialDerivation) val boundaries = graph.verts.filter(v => graph.vdata(v).isBoundary) import SimplificationProcedure.LTEByWeight._ @@ -175,7 +175,7 @@ class SimplificationProcedureSpec extends FlatSpec { new File(examplesDirectory + "ZX_errors/ErrorGate.qgraph"), Graph.fromJson(_, rg) ) - val initialDerivation: DerivationWithHead = graphToDerivation(targetGraph, rg) + val initialDerivation: DerivationWithHead = graphToDerivation(targetGraph) val graph = Derivation.derivationHeadPairToGraph(initialDerivation) val boundaries = graph.verts.filter(v => graph.vdata(v).isBoundary) import SimplificationProcedure.PullErrors._ @@ -213,7 +213,7 @@ class SimplificationProcedureSpec extends FlatSpec { case Success(d) => println("Success") println(data.Derivation.derivationHeadPairToGraph(d).vdata) - assert(errorsDistance(targets.toSet)(d, GraphAnalysis.detectErrors(d)).get < 1) + assert(errorsDistance(targets.toSet)(d, GraphAnalysis.detectPiNodes(d)).get < 1) case Failure(_) => assert(false) } @@ -221,61 +221,6 @@ class SimplificationProcedureSpec extends FlatSpec { } } - // AK: This one takes too long to run as a unit test. - // TODO: find a better solution for one-off tests like this. - - ignore should "perform large error pull simplifications" in { - val allowedRules = ZXErrorRules - val targetGraph = quanto.util.FileHelper.readFile[Graph]( - new File(examplesDirectory + "ZX_errors/Huge_With_Error.qgraph"), - Graph.fromJson(_, rg) - ) - val initialDerivation: DerivationWithHead = graphToDerivation(targetGraph, rg) - val graph = Derivation.derivationHeadPairToGraph(initialDerivation) - val boundaries = graph.verts.filter(v => graph.vdata(v).isBoundary) - import SimplificationProcedure.PullErrors._ - val targets = boundaries.filter(t => t.toString.matches(raw"b(1[1-9]|2\d)")).toList - if (targets.nonEmpty) { - val initialState: State = State( - allowedRules.filterNot(r => r.name.matches(raw".*g_ann")), - 0, - None, - new Random(), - weightFunction = errorsDistance(targets.toSet), - Some(ZXErrorRules.filter(r => r.name.matches(raw".*g_ann"))), - None, - heldVertices = None, - None - ) - val simplificationProcedure = new SimplificationProcedure[State]( - initialDerivation, - initialState, - step, - progress, - (_, state) => state.currentStep == state.maxSteps.getOrElse(-1) || state.currentDistance.getOrElse(2.0) < 1 - ) - var returningDerivation = simplificationProcedure.initialDerivation - val backgroundDerivation: Future[DerivationWithHead] = Future[DerivationWithHead] { - //println("future started") - while (!simplificationProcedure.stopped) { - //println("futured loop") - simplificationProcedure.step() - returningDerivation = simplificationProcedure.current - } - simplificationProcedure.current - } - backgroundDerivation onComplete { - case Success(_) => - println("Success") - println(simplificationProcedure.state.currentDistance) - println(data.Derivation.derivationHeadPairToGraph(returningDerivation).vdata) - assert(true) - case Failure(_) => - assert(false) - } - Await.result(backgroundDerivation, Duration(waitTime, "seconds")) - } - } } \ No newline at end of file diff --git a/scala/src/test/scala/quanto/cosy/test/TensorSpec.scala b/scala/src/test/scala/quanto/cosy/test/TensorSpec.scala index f13938a8..71c7f4a5 100644 --- a/scala/src/test/scala/quanto/cosy/test/TensorSpec.scala +++ b/scala/src/test/scala/quanto/cosy/test/TensorSpec.scala @@ -80,13 +80,16 @@ class TensorSpec extends FlatSpec { assert((p2 o l1.t) == l2.t) } - it should "construct swap matrices" in { + it should "construct 2-wire swap matrices" in { var s1 = Tensor.swap(2, x => 1 - x) assert(s1.toString == "1 0 0 0\n0 0 1 0\n0 1 0 0\n0 0 0 1") - var s2 = Tensor.swap(List(0, 2, 1)) - assert(s2.toStringSparse == - "1 . . . . . . .\n. . 1 . . . . .\n. 1 . . . . . .\n. . . 1 . . . ." + - "\n. . . . 1 . . .\n. . . . . . 1 .\n. . . . . 1 . .\n. . . . . . . 1") + } + + it should "make swaps the right way up" in { + val id = Tensor.idWires(1) + val s1 = Tensor.swap(List(1,0)) + val s2 = Tensor.swap(List(2,0,1)) + assert(((id x s1) o (s1 x id)) == s2) } it should "plug tensors into other tensors" in { @@ -119,8 +122,8 @@ class TensorSpec extends FlatSpec { } it should "create direct sums" in { - var t1 = Tensor.diagonal(Array[Complex](1,2)) - var t2 = Tensor(Array(Array(3,4))) + var t1 = Tensor.diagonal(Array[Complex](1, 2)) + var t2 = Tensor(Array(Array(3, 4))) assert((t1 sum t2) == Tensor(Array(Array(1, 0, 0, 0), Array(0, 2, 0, 0), Array(0, 0, 3, 4)))) assert((t1 sum t2.transpose) == Tensor(Array(Array(1, 0, 0), Array(0, 2, 0), Array(0, 0, 3), Array(0, 0, 4)))) } @@ -161,4 +164,18 @@ class TensorSpec extends FlatSpec { var t3 = new Tensor(Array(Array(Complex(0, 1), zero), Array(zero, Complex(1, 0)))) assert(!t1.isRoughlyUpToScalar(t3)) } + + + it should "compare zeroes" in { + var t1 = new Tensor(Array(Array(zero, zero))) + var t2 = Tensor(Array(Array(zero,Complex(0.000001,0)))) + assert(!t1.isRoughlyUpToScalar(t2)) + assert(t1.isRoughlyUpToScalar(t2.scaled(Complex(0.00000000001,0)))) + } + + it should "compare zero and one" in { + var t1 = new Tensor(Array(Array(Complex(0, 0.00000000000000001)))) + var t2 = Tensor(Array(Array(one))) + assert(!t1.isRoughlyUpToScalar(t2)) + } } diff --git a/scala/src/test/scala/quanto/data/test/AngleExpressionSpec.scala b/scala/src/test/scala/quanto/data/test/AngleExpressionSpec.scala deleted file mode 100644 index c5cfd6af..00000000 --- a/scala/src/test/scala/quanto/data/test/AngleExpressionSpec.scala +++ /dev/null @@ -1,166 +0,0 @@ -package quanto.data.test - -import org.scalatest._ -import quanto.data._ -import AngleExpression._ -import quanto.util.Rational - -class AngleExpressionSpec extends FlatSpec { - behavior of "A rational number" - - it should "add correctly" in { - assert(Rational(1,2) + Rational(2,3) === Rational(7,6)) - } - - behavior of "An angle expression" - - def testReparse(e : AngleExpression) { - assert(e === parse(e.toString)) - } - - it should "compare expressions" in { - val a = AngleExpression(Rational(0), Map("a" -> Rational(1))) - val b = AngleExpression(Rational(0), Map("b" -> Rational(1))) - assert(ZERO === AngleExpression(Rational(0))) - assert(ONE_PI === AngleExpression(Rational(1))) - assert(a + b === b + a) - assert(ZERO !== ONE_PI) - assert(ONE_PI + (a * Rational(1,2)) + (b * 4) === (a * Rational(1,2)) + ONE_PI + (b * 4)) - assert((a * Rational(1,2)) + (a * Rational(2,3)) === (a * Rational(7,6))) - } - - it should "parse '0'" in { - testReparse(ZERO) - assert(parse("") === ZERO) - assert(parse("0") === ZERO) - } - - it should "parse 'PI'" in { - testReparse(ONE_PI) - assert(parse("\\pi") === ONE_PI) - assert(parse("1\\pi") === ONE_PI) - assert(parse("1*\\pi") === ONE_PI) - assert(parse("1/1\\pi") === ONE_PI) - assert(parse("1/1*\\pi") === ONE_PI) - assert(parse("1\\pi/1") === ONE_PI) - assert(parse("1*\\pi/1") === ONE_PI) - assert(parse("\\pi/1") === ONE_PI) - - assert(parse("pi") === ONE_PI) - assert(parse("1pi") === ONE_PI) - assert(parse("1*pi") === ONE_PI) - assert(parse("1/1pi") === ONE_PI) - assert(parse("1/1*pi") === ONE_PI) - assert(parse("1pi/1") === ONE_PI) - assert(parse("1*pi/1") === ONE_PI) - assert(parse("pi/1") === ONE_PI) - - assert(parse("PI") === ONE_PI) - assert(parse("1PI") === ONE_PI) - assert(parse("1*PI") === ONE_PI) - assert(parse("1/1PI") === ONE_PI) - assert(parse("1/1*PI") === ONE_PI) - assert(parse("1PI/1") === ONE_PI) - assert(parse("1*PI/1") === ONE_PI) - assert(parse("PI/1") === ONE_PI) - } - - it should "parse 'a'" in { - val a = AngleExpression(Rational(0), Map("a" -> Rational(1))) - testReparse(a) - assert(parse("a") === a) - assert(parse("1a") === a) - assert(parse("1*a") === a) - assert(parse("1/1a") === a) - assert(parse("1/1*a") === a) - assert(parse("1a/1") === a) - assert(parse("1*a/1") === a) - assert(parse("a/1") === a) - } - - it should "parse '-PI'" in { - val minusPI = ONE_PI * -1 - testReparse(minusPI) - assert(parse("-pi") === minusPI) - assert(parse("-1pi") === minusPI) - assert(parse("-1*pi") === minusPI) - assert(parse("-1/1pi") === minusPI) - assert(parse("-1/1*pi") === minusPI) - assert(parse("-1pi/1") === minusPI) - assert(parse("-1*pi/1") === minusPI) - assert(parse("-pi/1") === minusPI) - } - - it should "parse '-a'" in { - val minusA = AngleExpression(Rational(0), Map("a" -> Rational(-1))) - testReparse(minusA) - assert(parse("-a") === minusA) - assert(parse("-1a") === minusA) - assert(parse("-1*a") === minusA) - assert(parse("-1/1a") === minusA) - assert(parse("-1/1*a") === minusA) - assert(parse("-1a/1") === minusA) - assert(parse("-1*a/1") === minusA) - assert(parse("-a/1") === minusA) - } - - it should "parse '+-3/4 PI'" in { - val tfPI = AngleExpression(Rational(3,4)) - testReparse(tfPI) - assert(parse("3/4") === tfPI) - assert(parse("3/4pi") === tfPI) - assert(parse("3/4*pi") === tfPI) - assert(parse("3pi/4") === tfPI) - assert(parse("3*pi/4") === tfPI) - - val mtfPI = AngleExpression(Rational(-3,4)) - testReparse(mtfPI) - assert(parse("-3/4") === mtfPI) - assert(parse("-3/4pi") === mtfPI) - assert(parse("-3/4*pi") === mtfPI) - assert(parse("-3pi/4") === mtfPI) - assert(parse("-3*pi/4") === mtfPI) - } - - it should "parse '+-1/4 PI' and '+-1/4 a'" in { - val fPI = AngleExpression(Rational(1,4)) - val mfPI = fPI * -1 - val fA = AngleExpression(Rational(0), Map("a" -> Rational(1,4))) - val mfA = fA * -1 - assert(parse("pi/4") === fPI) - assert(parse("-pi/4") === mfPI) - assert(parse("a/4") === fA) - assert(parse("-a/4") === mfA) - } - - it should "parse addition and subtraction correctly" in { - val a = AngleExpression(Rational(0), Map("a" -> Rational(1))) - val b = AngleExpression(Rational(0), Map("b" -> Rational(1))) - testReparse(a + b) - testReparse(a - b) - testReparse((a * -1) - b) - assert(parse("a + b") === a + b) - assert(parse("a - b") === a - b) - assert(parse("-a + b") === b - a) - assert(parse("- a - b") === (a * -1) - b) - assert(parse("-(a + b)") === (a * -1) - b) - assert(parse("-(a - b)") === b - a) - } - - it should "throw an exception on failed parse" in { - intercept[AngleParseException] { parse("x + ") } - intercept[AngleParseException] { parse("%") } - } - - it should "do substitutions correctly" in { - val e1 = parse("x - 2 y") - val e2 = parse("a + b - c") - assert(e1.subst("x", e2) === parse("a + b - c - 2y")) - assert(e1.subst("y", e2) === parse("x - 2a - 2b + 2c")) - } - - it should "evaluate a polynomial" in { - val e1 = parse("2x + 3/4") - assert(Math.abs(e1.evaluate(Map("x" -> 1.0/8.0)) - 1) < 1e-15) - } -} diff --git a/scala/src/test/scala/quanto/data/test/BooleanExpressionSpec.scala b/scala/src/test/scala/quanto/data/test/BooleanExpressionSpec.scala new file mode 100644 index 00000000..552061df --- /dev/null +++ b/scala/src/test/scala/quanto/data/test/BooleanExpressionSpec.scala @@ -0,0 +1,122 @@ +package quanto.data.test + +import org.scalatest._ +import quanto.data.Theory.ValueType +import quanto.data.{PhaseParseException, PhaseExpression} +import quanto.util.Rational + +class BooleanExpressionSpec extends FlatSpec { + behavior of "A boolean expression" + + def BooleanExpression(constant: Rational) = PhaseExpression(constant, Map(), ValueType.Boolean) + + def BooleanExpression(constant: Int, coefficients: Map[String, Rational]) = + PhaseExpression(constant, coefficients, ValueType.Boolean) + + def BOOL_FALSE = PhaseExpression.zero(ValueType.Boolean) + + def BOOL_TRUE = PhaseExpression.one(ValueType.Boolean) + + def testReparse(e: PhaseExpression) { + assert(e === parse(e.toString)) + } + + def parse(s: String): PhaseExpression = PhaseExpression.parse(s, ValueType.Boolean) + + + val a = BooleanExpression(0,Map("a"-> 1)) + val b = BooleanExpression(0,Map("b"-> 1)) + + it should "output to string" in { + assert(BOOL_FALSE.toString == "\\False") + assert(BOOL_TRUE.toString == "\\True") + assert(BooleanExpression(0,Map("a"-> 1)).toString == "a") + assert(BooleanExpression(1,Map("a"-> 1)).toString == "\\True + a") + assert(BooleanExpression(1,Map("a"-> 1, "b"-> 1)).toString == "\\True + a + b") + assert(BooleanExpression(0,Map("a"-> 1, "b"-> 1)).toString == "a + b") + } + + it should "compare expressions" in { + assert(BOOL_FALSE === BooleanExpression(0)) + assert(BOOL_TRUE === BooleanExpression(1)) + assert(a + b === b + a) + assert(BOOL_FALSE !== BOOL_TRUE) + assert(BOOL_TRUE + (a * 1) + (b * 4) === (a * 1) + BOOL_TRUE + (b * 4)) + assert((a * 1) + (a * 1) === (a * 0)) + } + + it should "parse '0'" in { + assert(BOOL_FALSE.toString == "\\False") + testReparse(BOOL_FALSE) + assert(parse("") === BOOL_FALSE) + assert(parse("0") === BOOL_FALSE) + } + + it should "parse 't'" in { + testReparse(BOOL_TRUE) + assert(parse("\\t") === BOOL_TRUE) + assert(parse("1*\\t") === BOOL_TRUE) + + assert(parse("t") === BOOL_TRUE) + assert(parse("1*t") === BOOL_TRUE) + + assert(parse("True") === BOOL_TRUE) + assert(parse("1*True") === BOOL_TRUE) + } + + it should "parse 'a'" in { + testReparse(a) + assert(parse("a") === a) + assert(parse("1*a") === a) + } + + it should "parse '-T'" in { + val minusTrue = BOOL_TRUE * -1 + assert(minusTrue === BOOL_TRUE) + testReparse(minusTrue) + assert(parse("-t") === minusTrue) + assert(parse("-1*t") === minusTrue) + } + + it should "parse '-a'" in { + val minusA = a * -1 + testReparse(minusA) + assert(parse("-a") === minusA) + assert(parse("-1*a") === minusA) + } + + it should "parse '+-1'" in { + val t = BooleanExpression(1) + testReparse(t) + assert(parse("-1") === t) + assert(parse("-1*true") === t) + assert(parse("+-1") === t) + } + + it should "parse addition and subtraction correctly" in { + testReparse(a + b) + testReparse(a - b) + testReparse((a * -1) - b) + assert(parse("a + b") === a + b) + assert(parse("a - b") === a - b) + assert(parse("-a + b") === b - a) + assert(parse("- a - b") === (a * -1) - b) + assert(parse("-(a + b)") === (a * -1) - b) + assert(parse("-(a - b)") === b - a) + } + + + it should "do substitutions correctly" in { + val e1 = parse("x - 2*y") + val e2 = parse("a + b - c") + assert(e1.subst("x", e2) === parse("a + b - c - 2y")) + assert(e1.subst("y", e2) === parse("x - 2a - 2b + 2c")) + } + + it should "evaluate a sum" in { + val e1 = parse("x + 1") + assert(e1.evaluate(Map("x" -> 1)) == 0) + val e2 = parse("x + y + z") + assert(e2.evaluate(Map("x" -> 1, "y" ->1, "z" ->0)) == 0) + } +} diff --git a/scala/src/test/scala/quanto/data/test/CompositeExpressionSpec.scala b/scala/src/test/scala/quanto/data/test/CompositeExpressionSpec.scala new file mode 100644 index 00000000..3aad80c1 --- /dev/null +++ b/scala/src/test/scala/quanto/data/test/CompositeExpressionSpec.scala @@ -0,0 +1,41 @@ +package quanto.data.test + +import org.scalatest._ +import quanto.data.CompositeExpression._ +import quanto.data.Theory.ValueType +import quanto.data._ +import quanto.util.Rational + +class CompositeExpressionSpec extends FlatSpec { + + behavior of "Type Parsing" + + val vs : List[ValueType] = ValueType.values.toList + + it should "parse singletons" in { + assert(parseTypes("Angle") == Vector(ValueType.AngleExpr)) + assert(parseTypes("string") == Vector(ValueType.String)) + assert(parseTypes("String") == Vector(ValueType.String)) + assert(parseTypes("angle_expr") == Vector(ValueType.AngleExpr)) + assert(parseTypes("long") == Vector(ValueType.Long)) + assert(parseTypes("Empty") == Vector(ValueType.Empty)) + assert(parseTypes("empty") == Vector(ValueType.Empty)) + } + + it should "parse pairs" in { + var pairs = vs.flatMap(v => vs.map(w => (v,w))) + pairs.foreach(p => { + var v1 = p._1 + var v2 = p._2 + var combined = s"$v1, $v2" + assert(parseTypes(combined) === Vector(v1, v2)) + } + ) + } + + it should "parse all in a row" in { + var all = vs.mkString("(", ", ", ")") + assert(parseTypes(all) === vs.toVector) + } + +} \ No newline at end of file diff --git a/scala/src/test/scala/quanto/data/test/GraphAdjmatSpec.scala b/scala/src/test/scala/quanto/data/test/GraphAdjmatSpec.scala index 019b6744..248ef8e9 100644 --- a/scala/src/test/scala/quanto/data/test/GraphAdjmatSpec.scala +++ b/scala/src/test/scala/quanto/data/test/GraphAdjmatSpec.scala @@ -65,9 +65,13 @@ class GraphAdjMatSpec extends FlatSpec { |} """.stripMargin), rg) - assert(g.isBoundary("v0")) - assert(g.isBoundary("v1")) - assert(g === g1) + assert(g.isTerminalWire("v0")) + assert(g.isTerminalWire("v1")) + var g2 = g1.copy() + g1.verts.foreach(vn => g2 = g2.updateVData(vn) { vd => vd.withCoord(0, 0) }) + var g3 = g.copy() + g.verts.foreach(vn => g3 = g3.updateVData(vn) { vd => vd.withCoord(0, 0) }) + assert(g3 === g2) // LAYOUT A GRAPH LIKE THIS: // val layoutProc = new ForceLayout diff --git a/scala/src/test/scala/quanto/data/test/GraphSpec.scala b/scala/src/test/scala/quanto/data/test/GraphSpec.scala index 76b721a9..92f9adda 100644 --- a/scala/src/test/scala/quanto/data/test/GraphSpec.scala +++ b/scala/src/test/scala/quanto/data/test/GraphSpec.scala @@ -3,12 +3,15 @@ package quanto.data.test import org.scalatest._ import quanto.data._ import quanto.data.Names._ +import quanto.data.Theory.{EdgeDesc, EdgeStyleDesc, ValueDesc, ValueType} import quanto.util.json._ + import scala.collection.immutable.TreeSet class GraphSpec extends FlatSpec with GivenWhenThen { - val rg = Theory.fromFile("red_green") + val rg : Theory = Theory.fromFile("red_green") + val composite_thy : Theory = Theory.fromFile("composite") behavior of "A graph" var g : Graph = _ @@ -70,9 +73,34 @@ class GraphSpec extends FlatSpec with GivenWhenThen { it should "be equal to its copy" in { val g1 = g.copy() + assert(g1 != null) assert(g1 === g) } + + it should "store and retrieve non-default edge types" in { + + val eDesc = EdgeDesc( + value = ValueDesc(typ = Vector(ValueType.Empty)), + style = EdgeStyleDesc(), + defaultData = JsonObject("type" -> "recorded") + ) + + val TwoWireTheory = Theory.DefaultTheory.mixin(Map(), Map("recorded" -> eDesc), Some("TwoWireTheory")) + + val eData = UndirEdge(eDesc.defaultData, JsonObject(), TwoWireTheory) + + val g = new Graph() + .addVertex(VName("v1"), NodeV()) + .addVertex(VName("v2"), NodeV()) + .addEdge(EName("e"), eData, ("v1", "v2")) + .addEdge(EName("f"), UndirEdge(), ("v1", "v2")) + val json = g.toJson(TwoWireTheory) + assert (! (json / "undir_edges" / "e").isEmpty) + assert (! (json / "undir_edges" / "e" /"data").isEmpty) + assert (((json / "undir_edges" / "f") ? "data").isEmpty) + } + behavior of "Another graph" var otherG : Graph = _ @@ -160,6 +188,96 @@ class GraphSpec extends FlatSpec with GivenWhenThen { assert(jsonGraphShouldBe === Graph.fromJson(Graph.toJson(jsonGraphShouldBe))) } + behavior of "Normalisation" + + it should "Normalise a single wire vertex to a single wire vertex" in { + val g1 = Graph.fromJson(Json.parse( + """ + |{ + | "wire_vertices": ["w0"], + | "undir_edges": { + | } + |} + """.stripMargin)) + assert(g1.normalise === g1) + } + + + it should "Normalise K2 to itself" in { + val g1 = Graph.fromJson(Json.parse( + """ + |{ + | "wire_vertices": ["w0", "w1"], + | "undir_edges": { + | "e0": {"src": "w0", "tgt": "w1"} + | } + |} + """.stripMargin)) + assert(g1.normalise === g1) + } + + it should "Normalise P3 to itself" in { + val g1 = Graph.fromJson(Json.parse( + """ + |{ + | "wire_vertices": ["w0", "w1", "w2"], + | "undir_edges": { + | "e0": {"src": "w0", "tgt": "w1"}, + | "e1": {"src": "w1", "tgt": "w2"} + | } + |} + """.stripMargin)) + assert(g1.normalise === g1) + } + + + it should "Normalise P4 to P3" in { + val g1 = Graph.fromJson(Json.parse( + """ + |{ + | "wire_vertices": ["w0", "w1", "w2","w3"], + | "undir_edges": { + | "e0": {"src": "w0", "tgt": "w1"}, + | "e1": {"src": "w1", "tgt": "w2"}, + | "e2": {"src": "w2", "tgt": "w3"} + | } + |} + """.stripMargin)) + val g2 = Graph.fromJson(Json.parse( + """ + |{ + | "wire_vertices": ["w0", "w1", "w3"], + | "undir_edges": { + | "e0": {"src": "w0", "tgt": "w1"}, + | "e2": {"src": "w1", "tgt": "w3"} + | } + |} + """.stripMargin)) + assert(g1.normalise === g2) + } + + + it should "Not normalise self loop" in { + val g1 = Graph.fromJson(Json.parse( + """ + |{ + | "wire_vertices": ["w0"], + | "undir_edges": { + | "e0": {"src": "w0", "tgt": "w0"} + | } + |} + """.stripMargin)) + val g2 = Graph.fromJson(Json.parse( + """ + |{ + | "wire_vertices": ["w0"], + | "undir_edges": { + | } + |} + """.stripMargin)) + assert(g1.normalise !== g2) + } + behavior of "Some more graphs" it should "normalise" in { @@ -301,6 +419,18 @@ class GraphSpec extends FlatSpec with GivenWhenThen { assert(twobb1.inBBox.domf(v1) === bs) } + it should "rename correctly" in { + var renamed = twobb.addEdge(twobb.edges.fresh, UndirEdge(), "v0" -> "v0").rename(Map("v0" -> "v1")) + assert(renamed.verts.contains("v1")) + assert(renamed.inBBox.domf("v1").nonEmpty) + } + + it should "add correctly" in { + var added = twobb.appendGraph(twobb.renameAvoiding(twobb)) + assert(added.verts.size == 2) + println(added) + } + behavior of "A graph with angles" it should "return the free variables" in { @@ -314,10 +444,114 @@ class GraphSpec extends FlatSpec with GivenWhenThen { |} """.stripMargin), thy = rg) - assert(g.freeVars === Set("x", "y", "z")) + assert(g.freeVars === Set( + (ValueType.AngleExpr, "x"), + (ValueType.AngleExpr, "y"), + (ValueType.AngleExpr, "z"))) + } + + behavior of "A graph with composite values" + + it should "accept a composite-valued graph" in { + val g = Graph.fromJson(Json.parse( + """ + |{ + | "node_vertices": { + | "v0": {"data": {"type": "Z", "value": "x + y, true"}}, + | "v1": {"data": {"type": "X", "value": "z + pi, false"}} + | } + |} + """.stripMargin), thy = composite_thy) + var pi_commute = Graph.fromJson(Json.parse( + """ + |{"wire_vertices":{ + | "b1":{"annotation":{"boundary":true,"coord":[-0.75,2.0]}}, + | "b0":{"annotation":{"boundary":true,"coord":[-11.5,1.5]}} + | }, + |"node_vertices":{ + | "v1":{"data":{"type":"Z","value":"\\pi/4,false"},"annotation":{"coord":[-3.0,0.5]}}, + | "v0":{"data":{"type":"Z","value":"\\pi/2,false"},"annotation":{"coord":[-8.5,1.75]}}, + | "v2":{"data":{"type":"X","value":"\\pi, \\true"},"annotation":{"coord":[-5.5,2.5]}}}, + | "undir_edges":{ + | "e0":{"src":"b0","tgt":"v0"}, + | "e1":{"src":"v0","tgt":"v2"}, + | "e2":{"src":"v2","tgt":"v1"}, + | "e3":{"src":"v1","tgt":"b1"} + | } + |}""".stripMargin + ), thy = composite_thy) + + assert(pi_commute.freeVars === Set()) + assert(pi_commute.vdata("v2").asInstanceOf[NodeV].phaseData === + new CompositeExpression( + Vector(ValueType.AngleExpr, ValueType.Boolean), + Vector(PhaseExpression(1, Map(), ValueType.AngleExpr), PhaseExpression(1, Map(), ValueType.Boolean))) + ) } // it should "support Graph.Flavor clipboard flavor" in { // assert(Graph().isDataFlavorSupported(Graph.Flavor)) // } + + behavior of "vertex cutting" + // Create a graph with left, middle and right vertices + // Join middle and right, + // Experiment with cutting out the left vertex + // Note that we are normalising the graphs, which can cause unintuitive behaviour + var cut_g = new Graph() + cut_g = cut_g.addVertex("v0", NodeV()). + addVertex("v1", NodeV()). + addVertex("v2", NodeV()). + addEdge("e", DirEdge(), "v1" -> "v2") + + it should "count arities" in { + val g0 = cut_g.normalise + assert(g0.arity("v0") == 0) + assert(g0.arity("v1") == 1) + assert(g0.arity("v2") == 1) + } + + + it should "find end-points of wires" in { + val g0 = cut_g.normalise + assert(g0.edgeEndPoints("e0")._1 == Set("v1", "v2").map(VName)) + } + + it should "find no end-points in a loop" in { + val g0 = new Graph(). + addVertex("v0", WireV()). + addVertex("v1", WireV()). + addEdge("e0", DirEdge(), "v0"->"v1"). + addEdge("e1", DirEdge(), "v1"->"v0") + assert(g0.edgeEndPoints("e0") == (Set(), Set("e0","e1").map(EName), Set("v0","v1").map(VName))) + } + + it should "cut out a lone vertex" in { + val(g0,nb, rb) = cut_g. + normalise.cutVertex("v0") + assert(nb.isEmpty) + assert(g0.nodesThatAreNotWires.size == 2) + assert(g0.edges.size == 2) //double expected because normalisation + } + + it should "cut out a vertex attached to another vertex" in { + val (g0,nb, rb) = cut_g. + normalise.addEdge("ee0", DirEdge(), "v0" -> "v1").cutVertex("v0") + assert(nb.size == 1) + assert(g0.nodesThatAreNotWires.size == 2) + assert(g0.edges.size == 3) // One going left to boundary, two between v1 and v2 + } + + + it should "cut out a vertex attached to another vertex and a boundary" in { + val (g0,nb, rb) = cut_g. + normalise. + addEdge("ee0", DirEdge(), "v0" -> "v1"). + addVertex("b0", WireV()). + addEdge("ee1", DirEdge(), "b0" -> "v0"). + cutVertex("v0") + assert(nb.size == 1) + assert(g0.nodesThatAreNotWires.size == 2) + assert(g0.edges.size == 3) // One going left to boundary, two between v1 and v2 + } } diff --git a/scala/src/test/scala/quanto/data/test/PhaseExpressionSpec.scala b/scala/src/test/scala/quanto/data/test/PhaseExpressionSpec.scala new file mode 100644 index 00000000..181486fe --- /dev/null +++ b/scala/src/test/scala/quanto/data/test/PhaseExpressionSpec.scala @@ -0,0 +1,224 @@ +package quanto.data.test + +import org.scalatest._ +import quanto.data.Theory.ValueType +import quanto.data._ +import quanto.util.Rational + +class PhaseExpressionSpec extends FlatSpec { + behavior of "A rational number" + + it should "add correctly" in { + assert(Rational(1, 2) + Rational(2, 3) === Rational(7, 6)) + } + + behavior of "An angle expression" + + + def AngleExpression(constant: Rational) = PhaseExpression(constant, Map(), ValueType.AngleExpr) + + def AngleExpression(constant: Rational, coefficients: Map[String, Rational]) = + PhaseExpression(constant, coefficients, ValueType.AngleExpr) + + def zero = PhaseExpression.zero(ValueType.AngleExpr) + + def one = PhaseExpression.one(ValueType.AngleExpr) + + def testReparse(e: PhaseExpression) { + assert(e === parse(e.toString)) + } + + def parse(s: String): PhaseExpression = PhaseExpression.parse(s, ValueType.AngleExpr) + + it should "create expressions" in { + val a1 = AngleExpression(Rational(1,2), Map("a" -> Rational(2,1))) + assert(a1.constant == Rational(1,2)) + assert(a1.coefficients.size == 1) + testReparse(a1) + } + + + it should "compare expressions" in { + val a = AngleExpression(Rational(0), Map("a" -> Rational(1))) + val b = AngleExpression(Rational(0), Map("b" -> Rational(1))) + assert(zero === AngleExpression(Rational(0))) + assert(one === AngleExpression(Rational(1))) + assert(a + b === b + a) + assert(zero !== one) + assert(one + (a * Rational(1, 2)) + (b * 4) === (a * Rational(1, 2)) + one + (b * 4)) + assert((a * Rational(1, 2)) + (a * Rational(2, 3)) === (a * Rational(7, 6))) + } + + it should "parse '0'" in { + testReparse(zero) + assert(parse("") === zero) + assert(parse("0") === zero) + } + + it should "parse 'PI'" in { + testReparse(one) + assert(parse("\\pi") === one) + assert(parse("1\\pi") === one) + assert(parse("1*\\pi") === one) + assert(parse("1/1\\pi") === one) + assert(parse("1/1*\\pi") === one) + assert(parse("1\\pi/1") === one) + assert(parse("1*\\pi/1") === one) + assert(parse("\\pi/1") === one) + + assert(parse("pi") === one) + assert(parse("1pi") === one) + assert(parse("1*pi") === one) + assert(parse("1/1pi") === one) + assert(parse("1/1*pi") === one) + assert(parse("1pi/1") === one) + assert(parse("1*pi/1") === one) + assert(parse("pi/1") === one) + + assert(parse("PI") === one) + assert(parse("1PI") === one) + assert(parse("1*PI") === one) + assert(parse("1/1PI") === one) + assert(parse("1/1*PI") === one) + assert(parse("1PI/1") === one) + assert(parse("1*PI/1") === one) + assert(parse("PI/1") === one) + } + + it should "parse 'a'" in { + val a = AngleExpression(Rational(0), Map("a" -> Rational(1))) + testReparse(a) + assert(parse("a") === a) + assert(parse("1a") === a) + assert(parse("1*a") === a) + assert(parse("1/1a") === a) + assert(parse("1/1*a") === a) + assert(parse("1a/1") === a) + assert(parse("1*a/1") === a) + assert(parse("a/1") === a) + } + + it should "parse '-PI'" in { + val minusPI = one * -1 + testReparse(minusPI) + assert(parse("-pi") === minusPI) + assert(parse("-1pi") === minusPI) + assert(parse("-1*pi") === minusPI) + assert(parse("-1/1pi") === minusPI) + assert(parse("-1/1*pi") === minusPI) + assert(parse("-1pi/1") === minusPI) + assert(parse("-1*pi/1") === minusPI) + assert(parse("-pi/1") === minusPI) + } + + it should "parse '-a'" in { + val minusA = AngleExpression(Rational(0), Map("a" -> Rational(-1))) + testReparse(minusA) + assert(parse("-a") === minusA) + assert(parse("-1a") === minusA) + assert(parse("-1*a") === minusA) + assert(parse("-1/1a") === minusA) + assert(parse("-1/1*a") === minusA) + assert(parse("-1a/1") === minusA) + assert(parse("-1*a/1") === minusA) + assert(parse("-a/1") === minusA) + } + + it should "parse \\pi/2" in { + val ohPI = AngleExpression(Rational(1, 2)) + testReparse(ohPI) + assert(parse("1/2") === ohPI) + assert(parse("\\pi/2") === ohPI) + } + + it should "parse '+-3/4 PI'" in { + val tfPI = AngleExpression(Rational(3, 4)) + testReparse(tfPI) + assert(parse("3/4") === tfPI) + assert(parse("3/4pi") === tfPI) + assert(parse("3/4*pi") === tfPI) + assert(parse("3pi/4") === tfPI) + assert(parse("3*pi/4") === tfPI) + + val mtfPI = AngleExpression(Rational(-3, 4)) + testReparse(mtfPI) + assert(parse("-3/4") === mtfPI) + assert(parse("-3/4pi") === mtfPI) + assert(parse("-3/4*pi") === mtfPI) + assert(parse("-3pi/4") === mtfPI) + assert(parse("-3*pi/4") === mtfPI) + } + + it should "parse '+-1/4 PI' and '+-1/4 a'" in { + val fPI = AngleExpression(Rational(1, 4)) + val mfPI = fPI * -1 + val fA = AngleExpression(Rational(0), Map("a" -> Rational(1, 4))) + val mfA = fA * -1 + assert(parse("pi/4") === fPI) + assert(parse("-pi/4") === mfPI) + assert(parse("a/4") === fA) + assert(parse("-a/4") === mfA) + } + + it should "parse addition and subtraction correctly" in { + val a = AngleExpression(Rational(0), Map("a" -> Rational(1))) + val b = AngleExpression(Rational(0), Map("b" -> Rational(1))) + testReparse(a + b) + testReparse(a - b) + testReparse((a * -1) - b) + assert(parse("a + b") === a + b) + assert(parse("a - b") === a - b) + assert(parse("-a + b") === b - a) + assert(parse("- a - b") === (a * -1) - b) + assert(parse("-(a + b)") === (a * -1) - b) + assert(parse("-(a - b)") === b - a) + } + + + it should "do substitutions correctly" in { + val e1 = parse("x - 2 y") + val e2 = parse("a + b - c") + assert(e1.subst("x", e2) === parse("a + b - c - 2y")) + assert(e1.subst("y", e2) === parse("x - 2a - 2b + 2c")) + } + + it should "evaluate a polynomial" in { + val e1 = parse("2x + 3/4") + assert(Math.abs(e1.evaluate(Map("x" -> 1.0/8.0)) - 1) < 1e-15) + } + + it should "evaluate a string" in { + val s1 = PhaseExpression.parse("Hi", ValueType.String) + assert(s1.toString == "Hi") + } + + it should "evaluate an empty" in { + val e1 = PhaseExpression.parse("Hi", ValueType.Empty) + assert(e1.toString == "") + } + + it should "print a rational" in { + val r1 = PhaseExpression.parse("3/27", ValueType.Rational) + assert(r1.toString == "1/9") + val r2 = PhaseExpression.parse("3/27 + a/4", ValueType.Rational) + assert(r2.toString == "1/9 + 1/4 a") + } + + it should "parse alpha'" in { + val e = parse("alpha' + alpha") + assert(e.coefficients.keys.size == 2) + } + + it should "substitute beta" in { + val e = parse("-beta") + val f = e.subst("beta", parse("0")) + assert(f == zero) + } + + it should "substitute composite beta" in { + val e = parse("-beta") + val ee = CompositeExpression.wrap(e) + val f = ee.substSubVariables(Map((ValueType.AngleExpr, "beta") -> "0")) + assert(f.values.head == zero) + } +} diff --git a/scala/src/test/scala/quanto/data/test/TheorySpec.scala b/scala/src/test/scala/quanto/data/test/TheorySpec.scala index 9650ab5b..c0c6b528 100644 --- a/scala/src/test/scala/quanto/data/test/TheorySpec.scala +++ b/scala/src/test/scala/quanto/data/test/TheorySpec.scala @@ -9,13 +9,13 @@ class TheorySpec extends FlatSpec { behavior of "A theory" val rgValueDesc = Theory.ValueDesc( - typ = Theory.ValueType.String, + typ = Vector(Theory.ValueType.String), latexConstants = true, validateWithCore = true ) val hValueDesc = Theory.ValueDesc( - typ = Theory.ValueType.Empty, + typ = Vector(Theory.ValueType.Empty), latexConstants = false, validateWithCore = false ) @@ -68,9 +68,9 @@ class TheorySpec extends FlatSpec { | "vertex_types" : { | "red" : { | "value" : { - | "validate_with_core" : true, + | "type" : "string", | "latex_constants" : true, - | "type" : "string" + | "validate_with_core" : true | }, | "style" : { | "label" : { @@ -79,7 +79,8 @@ class TheorySpec extends FlatSpec { | }, | "stroke_color" : [ 0.0, 0.0, 0.0 ], | "fill_color" : [ 1.0, 0.0, 0.0 ], - | "shape" : "circle" + | "shape" : "circle", + | "stroke_width" : 1 | }, | "default_data" : { | "type" : "red", @@ -91,9 +92,9 @@ class TheorySpec extends FlatSpec { | }, | "green" : { | "value" : { - | "validate_with_core" : true, + | "type" : "string", | "latex_constants" : true, - | "type" : "string" + | "validate_with_core" : true | }, | "style" : { | "label" : { @@ -102,7 +103,8 @@ class TheorySpec extends FlatSpec { | }, | "stroke_color" : [ 0.0, 0.0, 0.0 ], | "fill_color" : [ 0.0, 1.0, 0.0 ], - | "shape" : "circle" + | "shape" : "circle", + | "stroke_width" : 1 | }, | "default_data" : { | "type" : "green", @@ -114,9 +116,9 @@ class TheorySpec extends FlatSpec { | }, | "hadamard" : { | "value" : { - | "validate_with_core" : false, + | "type" : "empty", | "latex_constants" : false, - | "type" : "empty" + | "validate_with_core" : false | }, | "style" : { | "label" : { @@ -125,19 +127,22 @@ class TheorySpec extends FlatSpec { | }, | "stroke_color" : [ 0.0, 0.0, 0.0 ], | "fill_color" : [ 1.0, 1.0, 0.0 ], - | "shape" : "rectangle" + | "shape" : "rectangle", + | "stroke_width" : 1 | }, | "default_data" : { | "type" : "hadamard" | } | } | }, + | "default_vertex_type" : "red", + | "default_edge_type" : "plain", | "edge_types" : { | "plain" : { | "value" : { - | "validate_with_core" : false, + | "type" : "empty", | "latex_constants" : false, - | "type" : "empty" + | "validate_with_core" : false | }, | "style" : { | "stroke_color" : [ 0.0, 0.0, 0.0 ], @@ -151,9 +156,7 @@ class TheorySpec extends FlatSpec { | "type" : "plain" | } | } - | }, - | "default_vertex_type" : "red", - | "default_edge_type" : "plain" + | } |} """.stripMargin) @@ -164,6 +167,27 @@ class TheorySpec extends FlatSpec { } it should "load from JSON" in { - assert(Theory.fromJson(thyJson) === thy) + var loaded : Theory = Theory.fromJson(thyJson) + assert(loaded.vertexTypes === thy.vertexTypes) + print(loaded.vertexTypes) + assert(loaded === thy) + } + + behavior of "mixing theories" + + it should "mix with itself" in { + assert(thy.mixin(thy, None) == thy) + } + + val plain = Theory.fromFile("plain") + val rg = Theory.fromFile("red_green") + + it should "mix with others" in { + val mixedPlainRG = plain.mixin(rg, Some("plain with red_green")) + assert(mixedPlainRG.vertexTypes.keySet == Set("var", "hadamard", "Z", "X")) + } + + it should "mix with fragments" in { + assert(plain.mixin(newVertexTypes = rg.vertexTypes.filter(_._1 == "Z")).vertexTypes.keySet == Set("var", "Z")) } } diff --git a/scala/src/test/scala/quanto/rewrite/test/AngleExpressionMatcherSpec.scala b/scala/src/test/scala/quanto/rewrite/test/AngleExpressionMatcherSpec.scala index 64372e4f..fb222fe7 100644 --- a/scala/src/test/scala/quanto/rewrite/test/AngleExpressionMatcherSpec.scala +++ b/scala/src/test/scala/quanto/rewrite/test/AngleExpressionMatcherSpec.scala @@ -1,11 +1,30 @@ package quanto.rewrite.test import org.scalatest._ +import quanto.data.Theory.ValueType import quanto.rewrite._ import quanto.data._ +import quanto.util.Rational class AngleExpressionMatcherSpec extends FlatSpec { - import AngleExpression.parse + + def AngleExpressionMatcher(pVars: Vector[String], tVars: Vector[String]) = PhaseExpressionMatcher(pVars, tVars, Some(2)) + + def AngleExpression(constant: Rational) = PhaseExpression(constant, Map(), ValueType.AngleExpr) + + def AngleExpression(constant: Rational, coefficients: Map[String, Rational]) = + PhaseExpression(constant, coefficients, ValueType.AngleExpr) + + def zero = PhaseExpression.zero(ValueType.AngleExpr) + + def one = PhaseExpression.one(ValueType.AngleExpr) + + def testReparse(e: PhaseExpression) { + assert(e === parse(e.toString)) + } + + def parse(s: String): PhaseExpression = PhaseExpression.parse(s, ValueType.AngleExpr) + behavior of "An angle expression matcher" it should "handle single-variable matches" in { @@ -15,7 +34,7 @@ class AngleExpressionMatcherSpec extends FlatSpec { val mp = m.toMap // check we got correct map - assert(m.toMap === Map("a" -> parse("x + 2 y"), "b" -> parse("z + pi"))) + assert(m.toMap.mapValues(_.as(ValueType.AngleExpr)) === Map("a" -> parse("x + 2 y"), "b" -> parse("z + pi"))) // check substitutions into pattern yield target assert(parse("a").subst(mp) === parse("x + 2 y")) @@ -27,7 +46,7 @@ class AngleExpressionMatcherSpec extends FlatSpec { m = m.addMatch(parse("a + 2 b"), parse("x + 2 y")).get m = m.addMatch(parse("b + c"), parse("z + pi")).get m = m.addMatch(parse("a - c"), parse("4 x")).get - val mp = m.toMap + val mp = m.toMap(ValueType.AngleExpr) // check we got correct map assert(mp === Map( diff --git a/scala/src/test/scala/quanto/rewrite/test/BBoxMatcherSpec.scala b/scala/src/test/scala/quanto/rewrite/test/BBoxMatcherSpec.scala index 3a0542b7..111594f7 100644 --- a/scala/src/test/scala/quanto/rewrite/test/BBoxMatcherSpec.scala +++ b/scala/src/test/scala/quanto/rewrite/test/BBoxMatcherSpec.scala @@ -842,4 +842,43 @@ class BBoxMatcherSpec extends FlatSpec { // should match once as empty graph (killing both bboxes), and once using two copies of each bbox assert(matches.size === 2) } + + it should "match spider-law with multiple connecting edges (and no external ones)" in { + val g1 = Graph.fromJson(Json.parse( + """ + |{ + | "wire_vertices": ["w1"], + | "node_vertices": { + | "v0": {"data": {"type": "Z"}}, + | "v1": {"data": {"type": "Z"}} + | }, + | "undir_edges": { + | "e3": {"src": "v0", "tgt": "w1"}, + | "e4": {"src": "v1", "tgt": "w1"} + | }, + | "bang_boxes": { + | "bb2": {"contents": ["w1"]} + | } + |} + """.stripMargin), thy = rg) + + + val g2 = Graph.fromJson(Json.parse( + """ + |{ + | "node_vertices": { + | "v0": {"data": {"type": "Z"}}, + | "v1": {"data": {"type": "Z"}} + | }, + | "undir_edges": { + | "e0": {"src": "v0", "tgt": "v1"}, + | "e1": {"src": "v0", "tgt": "v1"} + | } + |} + """.stripMargin), thy = rg) + + val matches = Matcher.findMatches(g1, g1) + + assert(matches.size === 2) + } } diff --git a/scala/src/test/scala/quanto/rewrite/test/CompositeExpressionMatcherSpec.scala b/scala/src/test/scala/quanto/rewrite/test/CompositeExpressionMatcherSpec.scala new file mode 100644 index 00000000..05995aeb --- /dev/null +++ b/scala/src/test/scala/quanto/rewrite/test/CompositeExpressionMatcherSpec.scala @@ -0,0 +1,74 @@ +package quanto.rewrite.test + +import org.scalatest._ +import quanto.data.Theory.ValueType +import quanto.data._ +import quanto.rewrite._ +import quanto.util.Rational + + +class CompositeExpressionMatcherSpec extends FlatSpec { + + def AngleExpression(constant: Rational) = PhaseExpression(constant, Map(), ValueType.AngleExpr) + + def AngleExpression(constant: Rational, coefficients: Map[String, Rational]) = + PhaseExpression(constant, coefficients, ValueType.AngleExpr) + + def parse(types: String, values: String) = CompositeExpression.parse(types, values) + + def composite(p: PhaseExpression): CompositeExpression = CompositeExpression.wrap(p) + + def zero = PhaseExpression.zero(ValueType.AngleExpr) + + def one = PhaseExpression.one(ValueType.AngleExpr) + + def testReparse(e: PhaseExpression) { + assert(e === parse(e.toString)) + } + + def parse(s: String): PhaseExpression = PhaseExpression.parse(s, ValueType.AngleExpr) + + behavior of "An angle expression matcher" + + it should "handle single-variable matches" in { + + var m = CompositeExpressionMatcher() + m = m.addMatch(parse("angle", "a"), parse("angle", "x + 2 y")).get + m = m.addMatch(parse("angle", "b"), parse("angle", "z + pi")).get + val mp = m.toMap + + // check we got correct map + assert(m.toMap.values.flatten.toMap === + Map("a" -> parse("angle", "x + 2 y").firstOrError(ValueType.AngleExpr), "b" -> parse("angle", "z + pi").firstOrError(ValueType.AngleExpr))) + + // check substitutions into pattern yield target + assert(parse("angle", "a").substSubValues(mp.values.flatten.toMap) === parse("angle", "x + 2 y")) + assert(parse("angle", "b").substSubValues(mp.values.flatten.toMap) === parse("angle", "z + pi")) + } + + it should "handle expression matches" in { + var m = CompositeExpressionMatcher() + m = m.addMatch(parse("angle, boolean", "a + 2 b, d"), parse("angle, boolean", "x + 2 y, true")).get + m = m.addMatch(parse("angle, boolean", "b + c, false"), parse("angle, boolean", "z + pi, false")).get + m = m.addMatch(parse("angle, boolean", "a - c,false"), parse("angle, boolean", "4 x, false")).get + val mp = m.toMap + + // check we got correct map + assert(mp(ValueType.AngleExpr) === Map( + "a" -> parse("angle", "7 x - 2 y + 2 z").firstOrError(ValueType.AngleExpr), + "b" -> parse("angle", "-pi - 3 x + 2 y - z").firstOrError(ValueType.AngleExpr), + "c" -> parse("angle", "3 x - 2 y + 2 z").firstOrError(ValueType.AngleExpr) + )) + assert(mp(ValueType.Boolean) === Map( + "d" -> parse("boolean", "true").firstOrError(ValueType.Boolean) + )) + + // check substitutions into pattern yield target + assert(parse("angle, boolean", "a + 2 b, d").substSubValues(mp.values.flatten.toMap) === + parse("angle,boolean","x + 2 y, true")) + assert(parse("angle, boolean", "b + c").substSubValues(mp.values.flatten.toMap) === + parse("angle, boolean","z + pi")) + assert(parse("angle, boolean", "a - c").substSubValues(mp.values.flatten.toMap) === + parse("angle, boolean","4 x")) + } +} diff --git a/scala/src/test/scala/quanto/rewrite/test/MatcherSpec.scala b/scala/src/test/scala/quanto/rewrite/test/MatcherSpec.scala index b0e5a842..8137c296 100644 --- a/scala/src/test/scala/quanto/rewrite/test/MatcherSpec.scala +++ b/scala/src/test/scala/quanto/rewrite/test/MatcherSpec.scala @@ -3,6 +3,7 @@ package quanto.rewrite.test import quanto.rewrite._ import quanto.data._ import org.scalatest._ +import quanto.data.Theory.{EdgeDesc, ValueType} import quanto.util.json.Json class MatcherSpec extends FlatSpec { @@ -189,7 +190,8 @@ class MatcherSpec extends FlatSpec { """.stripMargin), thy = rg) val matches = Matcher.findMatches(g1, g2) assert(matches.size === 1) - assert(matches.head.subst === Map("x" -> AngleExpression.parse("(1/2) \\pi"))) + assert(matches.head.subst(ValueType.AngleExpr).mapValues(_.as(ValueType.AngleExpr)) === + Map("x" -> PhaseExpression.parse("(1/2) \\pi", ValueType.AngleExpr))) } it should "match a graph with one wire on itself" in { @@ -635,6 +637,81 @@ class MatcherSpec extends FlatSpec { assert(matches.forall { _.isHomomorphism }) } + + val stringWire = rg.edgeTypes("string") + val twoWire : Theory = rg.mixin(newEdgeTypes = Map("q" -> stringWire.copy())) + + it should "fail to match edges of different types" in { + val g = Graph.fromJson(Json.parse( + """ + |{ + | "node_vertices": { + | "v0": {"data": {"type": "Z", "value": ""}}, + | "v1": {"data": {"type": "X", "value": ""}} + | }, + | "undir_edges": { + | "e0": {"data": {"type": "string"},"src": "v0", "tgt": "v1"}, + | "e1": {"data": {"type": "q"},"src": "v0", "tgt": "v1"} + | } + |} + """.stripMargin), thy = twoWire) + val matches = Matcher.findMatches(g, g) + assert(matches.size === 1) + } + + + it should "fail to match bare edges of different types" ignore { + val g = Graph.fromJson(Json.parse( + """ + |{ + | "node_vertices": { + | "v0": {"data": {"type": "Z", "value": ""}}, + | "v1": {"data": {"type": "Z", "value": ""}} + | }, + | "wire_vertices": ["i0"], + | "undir_edges": { + | "e0": {"data": {"type": "string"},"src": "v0", "tgt": "v1"}, + | "e1": {"data": {"type": "string"},"src": "v0", "tgt": "i0"} + | } + |} + """.stripMargin), thy = twoWire) + + val g2 = Graph.fromJson(Json.parse( + """ + |{ + | "wire_vertices": ["i0", "o0"], + | "undir_edges": { + | "e0": {"data": {"type": "q"},"src": "i0", "tgt": "o0"} + | } + |} + """.stripMargin), thy = twoWire) + val matches = Matcher.findMatches(g2, g) + assert(matches.size === 0) + } + + + it should "find 2x2 matches from different wire types" in { + val g = Graph.fromJson(Json.parse( + """ + |{ + | "node_vertices": { + | "v0": {"data": {"type": "Z", "value": ""}}, + | "v1": {"data": {"type": "X", "value": ""}} + | }, + | "undir_edges": { + | "e0": {"data": {"type": "string"},"src": "v0", "tgt": "v1"}, + | "e1": {"data": {"type": "q"},"src": "v0", "tgt": "v1"}, + | "e2": {"data": {"type": "string"},"src": "v0", "tgt": "v1"}, + | "e3": {"data": {"type": "q"},"src": "v0", "tgt": "v1"} + | } + |} + """.stripMargin), thy = rg) + val matches = Matcher.findMatches(g, g) + assert(matches.size === 4) + } + + + it should "match a graph with 2 components on itself" in { val g1 = Graph.fromJson(Json.parse( """ diff --git a/scala/src/test/scala/quanto/util/test/RationalMatrixSpec.scala b/scala/src/test/scala/quanto/util/test/RationalMatrixSpec.scala index 9a39deaf..34075e80 100644 --- a/scala/src/test/scala/quanto/util/test/RationalMatrixSpec.scala +++ b/scala/src/test/scala/quanto/util/test/RationalMatrixSpec.scala @@ -9,13 +9,13 @@ class RationalMatrixSpec extends FlatSpec { behavior of "A rational matrix" it should "be constructable" in { - val m = new RationalMatrix(Vector(Vector(1,2,3), Vector(4,5,6), Vector(7,8,9)), 3) + val m = new RationalMatrix(Vector(Vector(1,2,3), Vector(4,5,6), Vector(7,8,9)), 3, Some(2)) } it should "perform gaussian elimination" in { - val m1 = new RationalMatrix(Vector(Vector(1,2,3,4), Vector(2,2,2,2), Vector(2,1,3,1)), 3) - val m2 = new RationalMatrix(Vector(Vector(1,2), Vector(2,2)), 2) - val m3 = new RationalMatrix(Vector(Vector(Rational(1,5),2,3,4), Vector(2,2,2,2)), 3) + val m1 = new RationalMatrix(Vector(Vector(1,2,3,4), Vector(2,2,2,2), Vector(2,1,3,1)), 3, Some(2)) + val m2 = new RationalMatrix(Vector(Vector(1,2), Vector(2,2)), 2, Some(2)) + val m3 = new RationalMatrix(Vector(Vector(Rational(1,5),2,3,4), Vector(2,2,2,2)), 3, Some(2)) assert(!m1.isReduced) assert(m1.gauss.get.isReduced)