diff --git a/src/main/scala/fhetest/Checker/Utils.scala b/src/main/scala/fhetest/Checker/Utils.scala index 31e3723..cbbe6bd 100644 --- a/src/main/scala/fhetest/Checker/Utils.scala +++ b/src/main/scala/fhetest/Checker/Utils.scala @@ -1,8 +1,11 @@ package fhetest.Checker +import fhetest.Utils.* + import java.io.{File, PrintWriter} import java.nio.file.{Files, Path, Paths, StandardCopyOption} import spray.json._ +import spray.json.DefaultJsonProtocol._ // file writer def getPrintWriter(filename: String): PrintWriter = @@ -21,6 +24,60 @@ def dumpJson[T](data: T, filename: String)(implicit ): Unit = dumpFile(data.toJson.prettyPrint, filename) +def dumpResult( + program: T2Program, + i: Int, + res: CheckResult, +): Unit = { + val pgm_info = Map( + ("programId" -> JsString(i.toString)), + ("program" -> JsString(program.content)), + ) + res match { + case Same(res) => { + val (expectedLst, obtainedLst) = res.partition(_.backend == "CLEAR") + val expected_res = expectedLst.apply(0).result + val result = pgm_info ++ Map( + "result" -> JsString("Success"), + "failedLibraires" -> JsString("0"), + "failures" -> JsArray(), + "expected" -> JsString(expected_res.toString), + ) + val succFilename = s"$succDir/$i.json" + dumpJson(result, succFilename) + } + case Diff(res) => { + val (expectedLst, obtainedLst) = res.partition(_.backend == "CLEAR") + val expected = expectedLst.apply(0) + val diffResults = obtainedLst.filter(isDiff(expected, _)) + val failures = diffResults.map(r => + Map( + ("library" -> r.backend), + ("failedResult" -> r.result.toString), + ), + ) + val result = pgm_info ++ Map( + "result" -> JsString("Fail"), + "failedLibraires" -> JsString(diffResults.size.toString), + "failures" -> failures.toJson, + "expected" -> JsString(expected._2.toString), + ) + val failFilename = s"$failDir/$i.json" + dumpJson(result, failFilename) + } + case ParserError(_) => { + val result = pgm_info ++ Map( + "result" -> JsString("ParseError"), + "failedLibraires" -> JsString("NaN"), + "failures" -> JsArray(), + "expected" -> JsString(""), + ) + val psrErrFilename = s"$psrErrDir/$i.json" + dumpJson(result, psrErrFilename) + } + } +} + val TEST_DIR = fhetest.TEST_DIR val succDir = s"$TEST_DIR/succ" val failDir = s"$TEST_DIR/fail" diff --git a/src/main/scala/fhetest/Command.scala b/src/main/scala/fhetest/Command.scala index edf1f59..7b245c6 100644 --- a/src/main/scala/fhetest/Command.scala +++ b/src/main/scala/fhetest/Command.scala @@ -1,6 +1,7 @@ package fhetest import fhetest.Utils.* +import fhetest.Generate.* import fhetest.Phase.{Parse, Interp, Print, Execute, Generate, Check} sealed abstract class Command( @@ -15,15 +16,21 @@ sealed abstract class Command( /** help message */ def examples: List[String] + /** run command with parsed arguments */ + def runJob(config: Config): Unit + /** run command with command-line arguments */ - def apply(args: List[String]): Unit + def apply(args: List[String]) = { + val config = Config(args) + runJob(config) + } } /** base command */ case object CmdBase extends Command("") { val help = "does nothing." val examples = Nil - def apply(args: List[String]) = () + def runJob(config: Config) = () } /** `help` command */ @@ -32,7 +39,7 @@ case object CmdHelp extends Command("help") { val examples = List( "fhetest help", ) - def apply(args: List[String]): Unit = { + def runJob(config: Config): Unit = { println("Usage: fhetest [options]") println("Commands:") for cmd <- FHETest.commands do println(s" ${cmd.name}\n\t${cmd.help}") @@ -43,185 +50,140 @@ case object CmdHelp extends Command("help") { case object CmdInterp extends Command("interp") { val help = "Interp a T2 file." val examples = List( - "fhetest interp tmp.t2", - "fhetest interp tmp.t2 -n 4096 -m 40961", + "fhetest interp -file:tmp.t2", + "fhetest interp -file:tmp.t2 -n:4096 -m:40961", ) - def apply(args: List[String]): Unit = args match { - case file :: Nil => { - val (ast, _, _) = Parse(file) - val result = Interp(ast, 32768, 65537) - print(result) - } - case file :: remainArgs => { - val (ast, _, _) = Parse(file) - val (_, encParams) = parseWordSizeAndEncParams(remainArgs) - val result = Interp(ast, encParams.ringDim, encParams.plainMod) - print(result) - } - case Nil => println("No T2 file given.") - } + def runJob(config: Config): Unit = + val fname = config.fileName.getOrElseThrow("No T2 file given.") + val (ast, _, _) = Parse(fname) + val ringDim: Int = config.encParams.map(_.ringDim).getOrElse(32768) + val plainMod: Int = config.encParams.map(_.plainMod).getOrElse(65537) + val result = Interp(ast, ringDim, plainMod) + print(result) } /** `run` command */ case object CmdRun extends Command("run") { val help = "Run the given T2 program." val examples = List( - "fhetest run tmp.t2 --SEAL", - "fhetest run tmp.t2 --OpenFHE", - "fhetest run tmp.t2 --SEAL -w 4 -n 4096 -d 5 -m 40961", - "fhetest run tmp.t2", + "fhetest run -file:tmp.t2 -b:SEAL", + "fhetest run -file:tmp.t2 -b:OpenFHE", + "fhetest run -file:tmp.t2 -b:SEAL -w:4 -n:4096 -d:5 -m:40961", + "fhetest run -file:tmp.t2", ) - // TODO: Refactor this function: parseWordSizeAndEncParams - def apply(args: List[String]): Unit = args match { - case file :: backendString :: remainArgs => - parseBackend(backendString) match { - case Some(backend) => - given DirName = getWorkspaceDir(backend) - val (ast, symbolTable, encType) = Parse(file) - val (wordSizeOpt, encParams) = parseWordSizeAndEncParams(remainArgs) - Print( - ast, - symbolTable, - encType, - backend, - wordSizeOpt, - Some(encParams), - ) - val result = Execute(backend) - print(result) - case None => println("Argument parsing error: Invalid backend.") - } - case file :: Nil => - val (ast, _, _) = Parse(file) - val result = Interp(ast, 32768, 65537) - print(result) - case Nil => println("No T2 file given.") - } + def runJob(config: Config): Unit = + val fname = config.fileName.getOrElseThrow("No T2 file given.") + val encParams: EncParams = + config.encParams.getOrElse(EncParams(32768, 5, 65537)) + val wordSizeOpt: Option[Int] = config.wordSize + val plainMod: Int = config.encParams.map(_.plainMod).getOrElse(65537) + config.backend match { + case Some(backend) => + given DirName = getWorkspaceDir(backend) + val (ast, symbolTable, encType) = Parse(fname) + Print( + ast, + symbolTable, + encType, + backend, + wordSizeOpt, + Some(encParams), + ) + case None => + val (ast, _, _) = Parse(fname) + val result = Interp(ast, 32768, 65537) + print(result) + } } /** `compile` command */ case object CmdCompile extends Command("compile") { val help = "compiles a T2 file to the given backend." val examples = List( - "fhetest compile tmp.t2 --SEAL", - "fhetest compile tmp.t2 --OpenFHE", + "fhetest compile -file:tmp.t2 -b:SEAL", + "fhetest compile -file:tmp.t2 -b:OpenFHE", ) - def apply(args: List[String]): Unit = args match { - case file :: backendString :: remain => - parseBackend(backendString) match { - case Some(backend) => - given DirName = getWorkspaceDir(backend) - val (ast, symbolTable, encType) = Parse(file) - remain match { - case params :: _ => ??? - case Nil => Print(ast, symbolTable, encType, backend) - } - case None => println("Argument parsing error: Invalid backend.") - } - case _ :: Nil => println("No backend given.") - case Nil => println("No T2 file given.") - } - + def runJob(config: Config): Unit = + val fname = config.fileName.getOrElseThrow("No T2 file given.") + val backend = config.backend.getOrElseThrow("No backend given.") + given DirName = getWorkspaceDir(backend) + val (ast, symbolTable, encType) = Parse(fname) + Print(ast, symbolTable, encType, backend) } /** `execute` command */ case object CmdExecute extends Command("execute") { val help = "Execute the compiled code in the given backend." val examples = List( - "fhetest execute --SEAL", - "fhetest execute --OpenFHE", + "fhetest execute -b:SEAL", + "fhetest execute -b:OpenFHE", ) - def apply(args: List[String]): Unit = args match { - case backendString :: _ => - parseBackend(backendString) match { - case Some(backend) => - given DirName = getWorkspaceDir(backend) - val output = Execute(backend) - println(output) - case None => println("Argument parsing error: Invalid backend.") - } - case Nil => println("No backend given.") - } + def runJob(config: Config): Unit = + val backend = config.backend.getOrElseThrow("No backend given.") + given DirName = getWorkspaceDir(backend) + val output = Execute(backend) + println(output) } -// TODO : Get Strategy from the command line /** `gen` command */ case object CmdGen extends Command("gen") { val help = "Generate random T2 programs." val examples = List( - "fhetest gen --INT 10", - "fhetest gen --DOUBLE 10", + "fhetest gen -type:int -c:10", + "fhetest gen -type:double -c:10", + "fhetest gen -type:int -stg:exhaust -c:10", + "fhetest gen -type:double -stg:random -c:10", ) - def apply(args: List[String]): Unit = args match { - case Nil => println("No argument given.") - case encTypeString :: remain => - val encType = parseEncType(encTypeString) - val generator = Generate(encType) - val n = remain match { - case nString :: Nil => - nString.toInt - case _ => 10 // default value - } - generator.show(List(Backend.SEAL, Backend.OpenFHE), n) - } + def runJob(config: Config): Unit = + val encType = config.encType.getOrElseThrow("No encType given.") + val genCount = config.genCount.getOrElse(10) + val generator = Generate(encType) + generator.show(List(Backend.SEAL, Backend.OpenFHE), genCount) } /** `check` command */ case object CmdCheck extends Command("check") { val help = "Check results of T2 program execution." val examples = List( - "fhetest check tmp --SEAL --OpenFHE", + "fhetest check -dir:tmp -json:true", ) // TODO: json option 추가 - def apply(args: List[String]): Unit = args match { - case dir :: backendStrings => { - val backendList = backendStrings.flatMap(parseBackend(_)) - if (backendStrings.size == backendList.size) { - // TODO: temporary encParams. Fix after having parameter genernation. - val encParams = EncParams(32768, 5, 65537) - val outputs = Check(dir, backendList, encParams) - for output <- outputs do { - println(output) - } - } else { println("Argument parsing error: Invalid backend.") } + def runJob(config: Config): Unit = + val dir = config.dirName.getOrElseThrow("No directory given.") + val encParams: EncParams = + config.encParams.getOrElse(EncParams(32768, 5, 65537)) + val backends = List(Backend.SEAL, Backend.OpenFHE) + val toJson = config.toJson + val outputs = Check(dir, backends, encParams, toJson) + for output <- outputs do { + println(output) } - case _ => println("Invalid arguments") - } } /** `test` command */ case object CmdTest extends Command("test") { val help = "Check after Generate random T2 programs." val examples = List( - "fhetest test --INT --random", - "fhetest test --INT --random 10", - "fhetest test --DOUBLE --exhaust 10", + "fhetest test -type:int -stg:random", + "fhetest test -type:int -stg:random -count:10", + "fhetest test -type:double -stg:exhaust -count:10", ) // TODO: json option 추가 - def apply(args: List[String]): Unit = args match { - case Nil => println("No argument given.") - case encTypeString :: stgString :: remain => { - val nOpt = remain match { - case nString :: _ => - Some(nString.toInt) - case _ => None - } - val encType = parseEncType(encTypeString) - val strategy = parseStrategy(stgString) - val generator = Generate(encType, strategy) - val programs = generator(nOpt).map(T2Program(_)) - val backendList = List(Backend.SEAL, Backend.OpenFHE) - // TODO: temporary encParams. Fix after having parameter genernation. - val encParams = EncParams(32768, 5, 65537) - val outputs = Check(programs, backendList, encParams) - for (program, output) <- outputs do { - println("=" * 80) - println("Program : " + program.content) - println("-" * 80) - println(output) - println("=" * 80) - } + def runJob(config: Config): Unit = + val encType = config.encType.getOrElseThrow("No encType given.") + val genStrategy = config.genStrategy.getOrElse(Strategy.Random) + val genCount = config.genCount + val generator = Generate(encType, genStrategy) + val programs = generator(genCount).map(T2Program(_)) + val backendList = List(Backend.SEAL, Backend.OpenFHE) + val encParams = config.encParams.getOrElse(EncParams(32768, 5, 65537)) + val toJson = config.toJson + val outputs = Check(programs, backendList, encParams, toJson) + for (program, output) <- outputs do { + println("=" * 80) + println("Program : " + program.content) + println("-" * 80) + println(output) + println("=" * 80) } - case _ => println("EncType and Strategy are required.") - } } diff --git a/src/main/scala/fhetest/Config.scala b/src/main/scala/fhetest/Config.scala new file mode 100644 index 0000000..ab71881 --- /dev/null +++ b/src/main/scala/fhetest/Config.scala @@ -0,0 +1,58 @@ +package fhetest + +import fhetest.Utils.* +import fhetest.Generate.Strategy + +class Config( + var fileName: Option[String] = None, + var dirName: Option[String] = None, + var backend: Option[Backend] = None, + var wordSize: Option[Int] = None, + var encParams: Option[EncParams] = None, + var encType: Option[ENC_TYPE] = None, + var genStrategy: Option[Strategy] = None, + var genCount: Option[Int] = None, + var toJson: Boolean = false, +) + +object Config { + // each argument is formatted as "-key1:value1 -key2:value2 ..." + def apply(args: List[String]): Config = { + val config = new Config() + args.foreach { + case arg if arg.startsWith("-") => + val Array(key, value) = arg.drop(1).split(":", 2) + key.toLowerCase() match { + case "file" => config.fileName = Some(value) + case "dir" => config.dirName = Some(value) + case "b" => config.backend = parseBackend(value) + case "w" => config.wordSize = Some(value.toInt) + case "n" => + config.encParams = Some( + config.encParams + .getOrElse(EncParams(0, 0, 0)) + .copy(ringDim = value.toInt), + ) + case "d" => + config.encParams = Some( + config.encParams + .getOrElse(EncParams(0, 0, 0)) + .copy(mulDepth = value.toInt), + ) + case "m" => + config.encParams = Some( + config.encParams + .getOrElse(EncParams(0, 0, 0)) + .copy(plainMod = value.toInt), + ) + case "type" => config.encType = parseEncType(value) + case "stg" => config.genStrategy = parseStrategy(value) + case "count" => config.genCount = Some(value.toInt) + case "json" => config.toJson = value.toBoolean + case _ => throw new Error(s"Unknown option: $key") + } + case _ => // 잘못된 형식의 인자 처리 + } + config + } +} diff --git a/src/main/scala/fhetest/Phase/Check.scala b/src/main/scala/fhetest/Phase/Check.scala index 2e10e47..9e6e0ad 100644 --- a/src/main/scala/fhetest/Phase/Check.scala +++ b/src/main/scala/fhetest/Phase/Check.scala @@ -9,7 +9,6 @@ import java.nio.file.{Files, Paths}; import java.io.{File, InputStream, ByteArrayInputStream} import scala.jdk.CollectionConverters._ import spray.json._ -import spray.json.DefaultJsonProtocol._ case object Check { def apply( @@ -34,37 +33,38 @@ case object Check { programs: LazyList[T2Program], backends: List[Backend], encParams: EncParams, + toJson: Boolean, ): LazyList[(T2Program, CheckResult)] = { setTestDir() - var i = 0 val checkResults = for { - program <- programs + (program, i) <- programs.zipWithIndex } yield { - val checkResult = getAndWriteResult(i, program, backends, encParams) - i = i + 1 + val checkResult = apply(program, backends, encParams) + if (toJson) dumpResult(program, i, checkResult) (program, checkResult) } checkResults } + // TODO: Do we need this function? def apply( directory: String, backends: List[Backend], encParams: EncParams, + toJson: Boolean, ): LazyList[String] = { val dir = new File(directory) if (dir.exists() && dir.isDirectory) { val files = Files.list(Paths.get(directory)) val fileList = files.iterator().asScala.toList setTestDir() - var i = 0 val checkResults = for { - filePath <- fileList.to(LazyList) + (filePath, i) <- fileList.to(LazyList).zipWithIndex } yield { val fileStr = Files.readAllLines(filePath).asScala.mkString("") val program = T2Program(fileStr) - val checkResult = getAndWriteResult(i, program, backends, encParams) - i = i + 1 + val checkResult = apply(program, backends, encParams) + if (toJson) dumpResult(program, i, checkResult) val pgmStr = "-" * 10 + " Program " + "-" * 10 + "\n" + fileStr + "\n" val reportStr = checkResult.toString + "\n" pgmStr + reportStr @@ -140,60 +140,4 @@ case object Check { }, ) - def getAndWriteResult( - i: Int, - program: T2Program, - backends: List[Backend], - encParams: EncParams, - ): CheckResult = { - val pgm_info = Map( - ("programId" -> JsString(i.toString)), - ("program" -> JsString(program.content)), - ) - val checkResult = apply(program, backends, encParams) - checkResult match { - case Same(res) => { - val (expectedLst, obtainedLst) = res.partition(_.backend == "CLEAR") - val expected_res = expectedLst.apply(0).result - val result = pgm_info ++ Map( - "result" -> JsString("Success"), - "failedLibraires" -> JsString("0"), - "failures" -> JsArray(), - "expected" -> JsString(expected_res.toString), - ) - val succFilename = s"$succDir/$i.json" - dumpJson(result, succFilename) - } - case Diff(res) => { - val (expectedLst, obtainedLst) = res.partition(_.backend == "CLEAR") - val expected = expectedLst.apply(0) - val diffResults = obtainedLst.filter(isDiff(expected, _)) - val failures = diffResults.map(r => - Map( - ("library" -> r.backend), - ("failedResult" -> r.result.toString), - ), - ) - val result = pgm_info ++ Map( - "result" -> JsString("Fail"), - "failedLibraires" -> JsString(diffResults.size.toString), - "failures" -> failures.toJson, - "expected" -> JsString(expected._2.toString), - ) - val failFilename = s"$failDir/$i.json" - dumpJson(result, failFilename) - } - case ParserError(_) => { - val result = pgm_info ++ Map( - "result" -> JsString("ParseError"), - "failedLibraires" -> JsString("NaN"), - "failures" -> JsArray(), - "expected" -> JsString(""), - ) - val psrErrFilename = s"$psrErrDir/$i.json" - dumpJson(result, psrErrFilename) - } - } - checkResult - } } diff --git a/src/main/scala/fhetest/Phase/Generate.scala b/src/main/scala/fhetest/Phase/Generate.scala index 1f74877..120feaa 100644 --- a/src/main/scala/fhetest/Phase/Generate.scala +++ b/src/main/scala/fhetest/Phase/Generate.scala @@ -25,7 +25,6 @@ case class Generate( "int main(void) { EncInt x, y; int yP; int c; " case ENC_TYPE.ENC_DOUBLE => "int main(void) { EncDouble x, y; double yP; int c; " - case ENC_TYPE.None => throw new Exception("encType is not set") } val baseStrBack = " print_batched (x, 20); return 0; } " val baseStr = baseStrFront + baseStrBack diff --git a/src/main/scala/fhetest/Utils/Utils.scala b/src/main/scala/fhetest/Utils/Utils.scala index 06ef22f..c810483 100644 --- a/src/main/scala/fhetest/Utils/Utils.scala +++ b/src/main/scala/fhetest/Utils/Utils.scala @@ -18,10 +18,16 @@ enum Backend(val name: String): case OpenFHE extends Backend("OpenFHE") enum ENC_TYPE: - case None, ENC_INT, ENC_DOUBLE + case ENC_INT, ENC_DOUBLE type DirName = String +extension [T](opt: Option[T]) + def getOrElseThrow(message: => String): T = opt match { + case Some(value) => value + case None => throw new Exception(message) + } + case class EncParams(ringDim: Int, mulDepth: Int, plainMod: Int) def translateT2EncType(enc_type: T2ENC_TYPE): ENC_TYPE = enc_type match @@ -41,33 +47,17 @@ def parseBackend(backendString: String): Option[Backend] = case _ => None } -def parseEncType(encTypeString: String): ENC_TYPE = - parsePrefixedArg(encTypeString) match - case Some("INT") => ENC_TYPE.ENC_INT - case Some("DOUBLE") => ENC_TYPE.ENC_DOUBLE - case _ => ENC_TYPE.None - -// TODO: Refactor this function -def parseWordSizeAndEncParams(args: List[String]): (Option[Int], EncParams) = { - val argMap = args - .grouped(2) - .collect { - case List(key, value) => key -> value - } - .toMap - - val wordSizeOpt = argMap.get("-w").map(_.toInt) - val ringDim = argMap.get("-n").map(_.toInt).getOrElse(0) - val mulDepth = argMap.get("-d").map(_.toInt).getOrElse(0) - val plainMod = argMap.get("-m").map(_.toInt).getOrElse(0) - - (wordSizeOpt, EncParams(ringDim, mulDepth, plainMod)) -} -def parseStrategy(sString: String): Strategy = - parsePrefixedArg(sString) match - case Some("exhaust") => Strategy.Exhaustive - case Some("random") => Strategy.Random - case _ => throw new Exception("Invalid strategy") +def parseEncType(encTypeString: String): Option[ENC_TYPE] = + encTypeString.toLowerCase() match + case "int" => Some(ENC_TYPE.ENC_INT) + case "double" => Some(ENC_TYPE.ENC_DOUBLE) + case _ => None + +def parseStrategy(sString: String): Option[Strategy] = + sString.toLowerCase() match + case "exhaust" => Some(Strategy.Exhaustive) + case "random" => Some(Strategy.Random) + case _ => None def getWorkspaceDir(backend: Backend): String = backend match case Backend.SEAL => fhetest.SEAL_DIR