diff --git a/src/main/scala/fhetest/Command.scala b/src/main/scala/fhetest/Command.scala index a9ba2af..16b5cce 100644 --- a/src/main/scala/fhetest/Command.scala +++ b/src/main/scala/fhetest/Command.scala @@ -148,15 +148,20 @@ case object CmdExecute extends Command("execute") { case object CmdGen extends Command("gen") { val help = "Generate random T2 programs." val examples = List( - "fhetest gen", - "fhetest gen -n 10", + "fhetest gen --INT 10", + "fhetest gen --DOUBLE 10", ) def apply(args: List[String]): Unit = args match { case Nil => println("No argument given.") - case n :: _ => { - val num = n.toInt - Generate(List(Backend.SEAL, Backend.OpenFHE), ENC_TYPE.ENC_INT, num) - } + case encTypeString :: remain => + val encType = parseEncType(encTypeString) + val generator = Generate(encType) + val n = remain match { + case nString :: Nil => + nString.toInt + case _ => 10 // default value + } + generator.show(List(Backend.SEAL, Backend.OpenFHE), n) } } diff --git a/src/main/scala/fhetest/Phase/Generate.scala b/src/main/scala/fhetest/Phase/Generate.scala index f5cf7b2..f31dae0 100644 --- a/src/main/scala/fhetest/Phase/Generate.scala +++ b/src/main/scala/fhetest/Phase/Generate.scala @@ -14,9 +14,11 @@ import javax.print.attribute.EnumSyntax import scala.jdk.CollectionConverters._ import scala.util.Random -case object Generate { +case class Generate(encType: ENC_TYPE) { + val symbolTable = boilerplate()._2 - def apply(backends: List[Backend], encType: ENC_TYPE, n: Int) = + // This is for testing purpose + def show(backends: List[Backend], n: Int) = for { template <- allTempletes.take(n) } { @@ -31,7 +33,7 @@ case object Generate { print(toString(concretized) + "\n") print("-" * 80 + "\n") - val ast = buildTemplate(concretized) + val ast: Goal = buildTemplate(concretized) val result = Interp(ast, 32768, 65537) print("CLEAR" + " : " + result + "\n") for { @@ -61,64 +63,75 @@ case object Generate { } // TODO: current length = 100, it can be changed to larger value - def concretizeTemplate(template: Templete): Templete = - return assignRandIntValues(template, 100, 100) - - // TODO: This is just a temporary solution for making symbolTable and encType available - val (_: Goal, symbolTable, encType) = - val baseStr = """ + def concretizeTemplate(template: Templete): Templete = encType match + case ENC_TYPE.ENC_INT => assignRandValues(template, 100, 100) + case ENC_TYPE.ENC_DOUBLE => assignRandValues(template, 100, 2.0) + case _ => throw new Exception("encType is not set") + + def boilerplate(): (Goal, SymbolTable, _) = + val baseStr = encType match { + case ENC_TYPE.ENC_INT => + """ int main(void) { EncInt x, y; int c; - print (x); + print_batched (x, 10); + return 0; + } + """ + case ENC_TYPE.ENC_DOUBLE => + """ + int main(void) { + EncDouble x, y; + int c; + print_batched (x, 10); return 0; } """ + case ENC_TYPE.None => throw new Exception("encType is not set") + } val baseStream = new ByteArrayInputStream(baseStr.getBytes("UTF-8")) Parse(baseStream) // TODO: current print only 10 values, it can be changed to larger value - def createNewBaseTemplate(): Goal = { - val baseStr = """ - int main(void) { - EncInt x, y; - int c; - print_batched (x, 10); - return 0; - } - """ - val baseStream = new ByteArrayInputStream(baseStr.getBytes("UTF-8")) - Parse(baseStream)._1 - } + def createNewBaseTemplate(): Goal = boilerplate()._1 def assignIntValue(template: Templete, vx: Int, vy: Int): Templete = val assignments = List(Assign("x", vx), Assign("y", vy)) return assignments ++ template // vxs = [1, 2, 3], vys = [4, 5, 6] => x = { 1, 2, 3 }; y = { 4, 5, 6 }; - def assignIntValues( + def assignValues( template: Templete, - vxs: List[Int], - vys: List[Int], + vxs: List[Int | Double], + vys: List[Int | Double], vc: Int, ): Templete = val assignments = List(AssignVec("x", vxs), AssignVec("y", vys), Assign("c", vc)) return assignments ++ template - def assignRandIntValue(template: Templete, bound: Int): Templete = + def assignRandValue(template: Templete, bound: Int): Templete = val vx = Random.between(0, bound) val vy = Random.between(0, bound) return assignIntValue(template, vx, vy) // TODO: Currently, T2 DSL does not support negative numbers - def assignRandIntValues(template: Templete, len: Int, bound: Int): Templete = + def assignRandValues(template: Templete, len: Int, bound: Int): Templete = + val l = Random.between(1, len) + val vxs = List.fill(l)(Random.between(0, bound)) + val vys = List.fill(l)(Random.between(0, bound)) + // TODO: Currently, c is bounded by 10. It can be changed to larger value + val vc = Random.between(0, 10) + return assignValues(template, vxs, vys, vc) + + def assignRandValues(template: Templete, len: Int, bound: Double): Templete = val l = Random.between(1, len) val vxs = List.fill(l)(Random.between(0, bound)) val vys = List.fill(l)(Random.between(0, bound)) // TODO: Currently, c is bounded by 10. It can be changed to larger value val vc = Random.between(0, 10) - return assignIntValues(template, vxs, vys, vc) + return assignValues(template, vxs, vys, vc) def parseStmt(stmtStr: String): Statement = val input_stream: InputStream = new ByteArrayInputStream( @@ -128,8 +141,8 @@ case object Generate { trait Stmt case class Var() - case class Assign(l: String, r: Int) extends Stmt - case class AssignVec(l: String, r: List[Int]) extends Stmt + case class Assign(l: String, r: (Int | Double)) extends Stmt + case class AssignVec(l: String, r: List[Int | Double]) extends Stmt case class Add(l: Var, r: Var) extends Stmt case class Sub(l: Var, r: Var) extends Stmt case class Mul(l: Var, r: Var) extends Stmt diff --git a/src/main/scala/fhetest/Utils/Utils.scala b/src/main/scala/fhetest/Utils/Utils.scala index 65beb9f..3dd273e 100644 --- a/src/main/scala/fhetest/Utils/Utils.scala +++ b/src/main/scala/fhetest/Utils/Utils.scala @@ -35,6 +35,12 @@ def parseBackend(backendString: String): Option[Backend] = case _ => None } +def parseEncType(encTypeString: String): ENC_TYPE = + parsePrefixedArg(encTypeString) match + case Some("INT") => ENC_TYPE.ENC_INT + case Some("DOUBLE") => ENC_TYPE.ENC_DOUBLE + case _ => ENC_TYPE.None + // TODO: Refactor this function def parseWordSizeAndEncParams(args: List[String]): (Option[Int], EncParams) = { val argMap = args