Skip to content

Commit

Permalink
Initial implementationn for testing (
Browse files Browse the repository at this point in the history
  • Loading branch information
hyerinshelly committed Feb 15, 2024
1 parent bbac00c commit 406c234
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 1 deletion.
32 changes: 31 additions & 1 deletion src/main/scala/fhetest/Command.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package fhetest

import fhetest.Utils.*
import fhetest.Phase.{Parse, Interp, Print, Execute, Generate}
import fhetest.Phase.{Parse, Interp, Print, Execute, Generate, Check}

sealed abstract class Command(
/** command name */
Expand Down Expand Up @@ -159,3 +159,33 @@ case object CmdGen extends Command("gen") {
}
}
}

/** `check` command */
case object CmdCheck extends Command("check") {
val help = "Check results of T2 program execution."
val examples = List(
"fhetest check tmp.t2 --SEAL --OpenFHE",
)
def apply(args: List[String]): Unit = args match {
// TODO: dir instead of a single file?
case file :: backendStrings => {
val backendOptList = backendStrings.map(parseBackend(_))
val validBackends = backendOptList.forall(_ match {
case Some(backend) => true
case None => false
})
if (validBackends) {
val backendList = backendOptList.map(_ match {
case Some(backend) => backend
})
val (ast, symbolTable, encType) = Parse(file)
// TODO: temporary encParams. Fix after having parameter genernation.
val encParams = EncParams(32768, 5, 65537)
val output =
Check(List((ast, encParams, encType, symbolTable)), backendList)
println(output)
} else { println("Argument parsing error: Invalid backend.") }
}
case _ => println("Invalid arguments")
}
}
108 changes: 108 additions & 0 deletions src/main/scala/fhetest/Phase/Checker.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package fhetest.Phase

import fhetest.Utils.*
import org.twc.terminator.t2dsl_compiler.T2DSLsyntaxtree.*;
import org.twc.terminator.SymbolTable;

import java.nio.file.{Files, Paths};

case object Check {

def apply(
programs: List[(Goal, EncParams, ENC_TYPE, SymbolTable)],
backends: List[Backend],
): String = {
val bugs =
programs.foldLeft(List[(Goal, List[(String, String)])]())((acc, pgm) => {
val (interpResult, executeResults) = getResults(pgm, backends)
interpResult._2 match {
case Normal(expected) => {
val diffs = executeResults.filter(interpResult._2 != _._2)
val lst = diffs.map((backend, res) =>
res match {
case PrintError => ("- " + backend.toString, ": PrintFail")
case LibraryError(m) =>
("- " + backend.toString, ": [Exception] " + m)
case Normal(r) => ("- " + backend.toString, ": \n" + r)
},
)
if (lst.isEmpty) acc
else { acc :+ (pgm._1, ("", "[Expected] \n" + expected) :: lst) }
}
case InterpError => acc :+ (pgm._1, List(("CLEAR", "InterpFail")))
}
})
bugs.foldLeft("")((str, bug) => {
// TODO: print ast
// val ast = bug._1
val report = bug._2
str + report.foldLeft("")((str2, r) => str2 + r._1 + r._2 + "\n")
})
}

trait ExecuteResult
case class Normal(res: String) extends ExecuteResult
// case object ParseError extends ExecuteResult //TODO: if receive T2 program string instead of AST
case object InterpError extends ExecuteResult
case object PrintError extends ExecuteResult
case class LibraryError(msg: String) extends ExecuteResult
// case object TimeoutError extends ExecuteResult //TODO: development
case object Throw extends ExecuteResult

def getResults(
program: (Goal, EncParams, ENC_TYPE, SymbolTable),
backends: List[Backend],
): ((String, ExecuteResult), List[(String, ExecuteResult)]) = {
val ast = program._1
val encParams = program._2
val encType = program._3
val symbolTable = program._4
val interpResult = (
"CLEAR",
try {
val res = Interp(ast, encParams.ringDim, encParams.plainMod)
Normal(res.trim)
} catch { case _ => InterpError },
)
val executeResults: List[(String, ExecuteResult)] =
backends.map(backend =>
(
backend.toString,
execute(backend, ast, encParams, encType, symbolTable),
),
)
(interpResult, executeResults)
}

def execute(
backend: Backend,
ast: Goal,
encParams: EncParams,
encType: ENC_TYPE,
symbolTable: SymbolTable,
): ExecuteResult =
withBackendTempDir(
backend,
{ workspaceDir =>
given DirName = workspaceDir
try {
Print(
ast,
symbolTable,
encType,
backend,
encParamsOpt = Some(encParams),
)
try {
val res = Execute(backend)
Normal(res.trim)
} catch {
// TODO?: classify exception related with parmeters?
case ex: Exception => LibraryError(ex.getMessage)
}
} catch {
case _ => PrintError
}
},
)
}
2 changes: 2 additions & 0 deletions src/main/scala/fhetest/fhetest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ object FHETest {
CmdRun,
// Generate random T2 programs
CmdGen,
// Check results of T2 program execution
CmdCheck,
)
val cmdMap = commands.foldLeft[Map[String, Command]](Map()) {
case (map, cmd) => map + (cmd.name -> cmd)
Expand Down

0 comments on commit 406c234

Please sign in to comment.