From fd90c75e8b07ba92bafd4c9047ac10903788f3d2 Mon Sep 17 00:00:00 2001 From: Jaeho Choi Date: Thu, 15 Feb 2024 08:39:30 +0000 Subject: [PATCH] Add functionality for generator to produce a list of T2 program as string (#11) --- src/main/scala/fhetest/Phase/Generate.scala | 45 +++++++++++---------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/src/main/scala/fhetest/Phase/Generate.scala b/src/main/scala/fhetest/Phase/Generate.scala index f31dae0..b056515 100644 --- a/src/main/scala/fhetest/Phase/Generate.scala +++ b/src/main/scala/fhetest/Phase/Generate.scala @@ -15,8 +15,28 @@ import scala.jdk.CollectionConverters._ import scala.util.Random case class Generate(encType: ENC_TYPE) { + // TODO : This boilerplate code is really ugly. But I cannot find a better way to do this. + val baseStrFront = encType match { + case ENC_TYPE.ENC_INT => + "int main(void) { EncInt x, y; int c; " + case ENC_TYPE.ENC_DOUBLE => + "int main(void) { EncDouble x, y; int c; " + case ENC_TYPE.None => throw new Exception("encType is not set") + } + val baseStrBack = " print_batched (x, 10); return 0; } " + val baseStr = baseStrFront + baseStrBack + val symbolTable = boilerplate()._2 + def apply(n: Int): List[String] = { + for { + template <- allTempletes.take(n).toList + } yield { + val concretized = concretizeTemplate(template) + toStringWithBaseStr(concretized) + } + } + // This is for testing purpose def show(backends: List[Backend], n: Int) = for { @@ -69,27 +89,6 @@ case class Generate(encType: ENC_TYPE) { case _ => throw new Exception("encType is not set") def boilerplate(): (Goal, SymbolTable, _) = - val baseStr = encType match { - case ENC_TYPE.ENC_INT => - """ - int main(void) { - EncInt x, y; - int c; - print_batched (x, 10); - return 0; - } - """ - case ENC_TYPE.ENC_DOUBLE => - """ - int main(void) { - EncDouble x, y; - int c; - print_batched (x, 10); - return 0; - } - """ - case ENC_TYPE.None => throw new Exception("encType is not set") - } val baseStream = new ByteArrayInputStream(baseStr.getBytes("UTF-8")) Parse(baseStream) @@ -160,7 +159,9 @@ case class Generate(encType: ENC_TYPE) { def concretize(s: Stmt) = parseStmt(toString(s)) type Templete = List[Stmt] - def toString(t: Templete): String = t.map(toString).mkString("\n") + def toString(t: Templete): String = t.map(toString).mkString("") + def toStringWithBaseStr(t: Templete): String = + baseStrFront + toString(t) + baseStrBack def getMulDepth(t: Templete): Int = t.count { case Mul(_, _) => true; case _ => false }