From 63a23453470fe6b2dd2cd5a20baf78ea9856217a Mon Sep 17 00:00:00 2001 From: Hyerin Park Date: Thu, 28 Mar 2024 08:32:51 +0000 Subject: [PATCH] Add Count command --- src/main/scala/fhetest/Checker/Utils.scala | 18 ++++++++ src/main/scala/fhetest/Command.scala | 50 +++++++++++++++++++++- src/main/scala/fhetest/fhetest.scala | 3 ++ 3 files changed, 70 insertions(+), 1 deletion(-) diff --git a/src/main/scala/fhetest/Checker/Utils.scala b/src/main/scala/fhetest/Checker/Utils.scala index 4d6f183..97757b8 100644 --- a/src/main/scala/fhetest/Checker/Utils.scala +++ b/src/main/scala/fhetest/Checker/Utils.scala @@ -204,6 +204,17 @@ implicit val encodeIntOrDouble: Encoder[Int | Double] = Encoder.instance { case d: Double => Json.fromDoubleOrNull(d) } +implicit val mapEncoder: Encoder[Map[List[Int], Int]] = + new Encoder[Map[List[Int], Int]] { + override def apply(a: Map[List[Int], Int]): Json = { + val encodedPairs = a.map { + case (key, value) => + key.toString() -> value.asJson + } + Json.obj(encodedPairs.toSeq: _*) + } + } + object DumpUtil { def dumpFile(data: String, filename: String): Unit = { val writer = new PrintWriter(filename) @@ -285,6 +296,13 @@ object DumpUtil { } } + def dumpCount(dir: String, countMap: Map[List[Int], Int]): Unit = { + val jsonString = + countMap.asJson.spaces2 + val outputFileName = s"$dir/count.json" + dumpFile(jsonString, outputFileName) + } + def readResult(filePath: String): ResultInfo = { val fileContents = readFile(filePath) val resultInfo = diff --git a/src/main/scala/fhetest/Command.scala b/src/main/scala/fhetest/Command.scala index a531680..ee1d9bf 100644 --- a/src/main/scala/fhetest/Command.scala +++ b/src/main/scala/fhetest/Command.scala @@ -2,9 +2,14 @@ package fhetest import fhetest.Utils.* import fhetest.Generate.* +import fhetest.Generate.Utils.combinations import fhetest.Phase.{Parse, Interp, Print, Execute, Generate, Check} import fhetest.Checker.DumpUtil +import java.nio.file.{Files, Paths}; +import java.io.File +import scala.jdk.CollectionConverters._ + sealed abstract class Command( /** command name */ val name: String, @@ -119,7 +124,7 @@ case object CmdRun extends BackendCommand("run") { /** `compile` command */ case object CmdCompile extends Command("compile") { - val help = "compiles a T2 file to the given backend." + val help = "Compile a T2 file to the given backend." val examples = List( "fhetest compile -file:tmp.t2 -b:SEAL", "fhetest compile -file:tmp.t2 -b:OpenFHE", @@ -299,3 +304,46 @@ case object CmdReplay extends Command("replay") { print(result) } } + +case object CmdCount extends Command("count") { + val help = + "Count the number of programs tested for each combination of valid filters" + val examples = List( + "fhetest count -dir:logs/test-invalid", + ) + def runJob(config: Config): Unit = + val dirString = config.dirName.getOrElseThrow("No directory given.") + if (dirString contains "invalid") { + val dir = new File(dirString) + if (dir.exists() && dir.isDirectory) { + val numOfValidFilters = + classOf[ValidFilter].getDeclaredClasses.toList.filter { cls => + classOf[ValidFilter] + .isAssignableFrom(cls) && cls != classOf[ValidFilter] + }.length + val allCombinations = (1 to numOfValidFilters).toList.flatMap( + combinations(_, numOfValidFilters), + ) + var countMap: Map[List[Int], Int] = + allCombinations.foldLeft(Map.empty[List[Int], Int]) { + case (acc, comb) => + acc + (comb -> 0) + } + val files = Files.list(Paths.get(dirString)) + val fileList = files.iterator().asScala.toList + for { + filePath <- fileList + fileName = filePath.toString() + } yield { + val resultInfo = DumpUtil.readResult(fileName) + val t2Program = resultInfo.program + val invalidFilterIdxList = t2Program.invalidFilterIdxList + countMap = countMap.updatedWith(invalidFilterIdxList) { + case Some(cnt) => Some(cnt + 1) + case None => Some(1) // unreachable + } + } + DumpUtil.dumpCount(dirString, countMap) + } + } else println("Directrory contains test cases of VALID programs") +} diff --git a/src/main/scala/fhetest/fhetest.scala b/src/main/scala/fhetest/fhetest.scala index 810a2b9..ad029c2 100644 --- a/src/main/scala/fhetest/fhetest.scala +++ b/src/main/scala/fhetest/fhetest.scala @@ -33,6 +33,9 @@ object FHETest { CmdTest, // Replay the given json CmdReplay, + // Make a json report of invalid program testing + // Count the number of programs tested for each combination of valid filters + CmdCount, ) val cmdMap = commands.foldLeft[Map[String, Command]](Map()) { case (map, cmd) => map + (cmd.name -> cmd)