Skip to content

Commit

Permalink
Refactor Generate.scala to support different encryption types (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
Maokami committed Feb 15, 2024
1 parent 406c234 commit c489cb3
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 36 deletions.
17 changes: 11 additions & 6 deletions src/main/scala/fhetest/Command.scala
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,20 @@ case object CmdExecute extends Command("execute") {
case object CmdGen extends Command("gen") {
val help = "Generate random T2 programs."
val examples = List(
"fhetest gen",
"fhetest gen -n 10",
"fhetest gen --INT 10",
"fhetest gen --DOUBLE 10",
)
def apply(args: List[String]): Unit = args match {
case Nil => println("No argument given.")
case n :: _ => {
val num = n.toInt
Generate(List(Backend.SEAL, Backend.OpenFHE), ENC_TYPE.ENC_INT, num)
}
case encTypeString :: remain =>
val encType = parseEncType(encTypeString)
val generator = Generate(encType)
val n = remain match {
case nString :: Nil =>
nString.toInt
case _ => 10 // default value
}
generator.show(List(Backend.SEAL, Backend.OpenFHE), n)
}
}

Expand Down
73 changes: 43 additions & 30 deletions src/main/scala/fhetest/Phase/Generate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import javax.print.attribute.EnumSyntax
import scala.jdk.CollectionConverters._
import scala.util.Random

case object Generate {
case class Generate(encType: ENC_TYPE) {
val symbolTable = boilerplate()._2

def apply(backends: List[Backend], encType: ENC_TYPE, n: Int) =
// This is for testing purpose
def show(backends: List[Backend], n: Int) =
for {
template <- allTempletes.take(n)
} {
Expand All @@ -31,7 +33,7 @@ case object Generate {
print(toString(concretized) + "\n")
print("-" * 80 + "\n")

val ast = buildTemplate(concretized)
val ast: Goal = buildTemplate(concretized)
val result = Interp(ast, 32768, 65537)
print("CLEAR" + " : " + result + "\n")
for {
Expand Down Expand Up @@ -61,64 +63,75 @@ case object Generate {
}

// TODO: current length = 100, it can be changed to larger value
def concretizeTemplate(template: Templete): Templete =
return assignRandIntValues(template, 100, 100)

// TODO: This is just a temporary solution for making symbolTable and encType available
val (_: Goal, symbolTable, encType) =
val baseStr = """
def concretizeTemplate(template: Templete): Templete = encType match
case ENC_TYPE.ENC_INT => assignRandValues(template, 100, 100)
case ENC_TYPE.ENC_DOUBLE => assignRandValues(template, 100, 2.0)
case _ => throw new Exception("encType is not set")

def boilerplate(): (Goal, SymbolTable, _) =
val baseStr = encType match {
case ENC_TYPE.ENC_INT =>
"""
int main(void) {
EncInt x, y;
int c;
print (x);
print_batched (x, 10);
return 0;
}
"""
case ENC_TYPE.ENC_DOUBLE =>
"""
int main(void) {
EncDouble x, y;
int c;
print_batched (x, 10);
return 0;
}
"""
case ENC_TYPE.None => throw new Exception("encType is not set")
}
val baseStream = new ByteArrayInputStream(baseStr.getBytes("UTF-8"))
Parse(baseStream)

// TODO: current print only 10 values, it can be changed to larger value
def createNewBaseTemplate(): Goal = {
val baseStr = """
int main(void) {
EncInt x, y;
int c;
print_batched (x, 10);
return 0;
}
"""
val baseStream = new ByteArrayInputStream(baseStr.getBytes("UTF-8"))
Parse(baseStream)._1
}
def createNewBaseTemplate(): Goal = boilerplate()._1

def assignIntValue(template: Templete, vx: Int, vy: Int): Templete =
val assignments = List(Assign("x", vx), Assign("y", vy))
return assignments ++ template

// vxs = [1, 2, 3], vys = [4, 5, 6] => x = { 1, 2, 3 }; y = { 4, 5, 6 };
def assignIntValues(
def assignValues(
template: Templete,
vxs: List[Int],
vys: List[Int],
vxs: List[Int | Double],
vys: List[Int | Double],
vc: Int,
): Templete =
val assignments =
List(AssignVec("x", vxs), AssignVec("y", vys), Assign("c", vc))
return assignments ++ template

def assignRandIntValue(template: Templete, bound: Int): Templete =
def assignRandValue(template: Templete, bound: Int): Templete =
val vx = Random.between(0, bound)
val vy = Random.between(0, bound)
return assignIntValue(template, vx, vy)

// TODO: Currently, T2 DSL does not support negative numbers
def assignRandIntValues(template: Templete, len: Int, bound: Int): Templete =
def assignRandValues(template: Templete, len: Int, bound: Int): Templete =
val l = Random.between(1, len)
val vxs = List.fill(l)(Random.between(0, bound))
val vys = List.fill(l)(Random.between(0, bound))
// TODO: Currently, c is bounded by 10. It can be changed to larger value
val vc = Random.between(0, 10)
return assignValues(template, vxs, vys, vc)

def assignRandValues(template: Templete, len: Int, bound: Double): Templete =
val l = Random.between(1, len)
val vxs = List.fill(l)(Random.between(0, bound))
val vys = List.fill(l)(Random.between(0, bound))
// TODO: Currently, c is bounded by 10. It can be changed to larger value
val vc = Random.between(0, 10)
return assignIntValues(template, vxs, vys, vc)
return assignValues(template, vxs, vys, vc)

def parseStmt(stmtStr: String): Statement =
val input_stream: InputStream = new ByteArrayInputStream(
Expand All @@ -128,8 +141,8 @@ case object Generate {

trait Stmt
case class Var()
case class Assign(l: String, r: Int) extends Stmt
case class AssignVec(l: String, r: List[Int]) extends Stmt
case class Assign(l: String, r: (Int | Double)) extends Stmt
case class AssignVec(l: String, r: List[Int | Double]) extends Stmt
case class Add(l: Var, r: Var) extends Stmt
case class Sub(l: Var, r: Var) extends Stmt
case class Mul(l: Var, r: Var) extends Stmt
Expand Down
6 changes: 6 additions & 0 deletions src/main/scala/fhetest/Utils/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def parseBackend(backendString: String): Option[Backend] =
case _ => None
}

def parseEncType(encTypeString: String): ENC_TYPE =
parsePrefixedArg(encTypeString) match
case Some("INT") => ENC_TYPE.ENC_INT
case Some("DOUBLE") => ENC_TYPE.ENC_DOUBLE
case _ => ENC_TYPE.None

// TODO: Refactor this function
def parseWordSizeAndEncParams(args: List[String]): (Option[Int], EncParams) = {
val argMap = args
Expand Down

0 comments on commit c489cb3

Please sign in to comment.