diff --git a/cli/src/swam/cli/Main.scala b/cli/src/swam/cli/Main.scala index e8a1027f..afaf1ac0 100644 --- a/cli/src/swam/cli/Main.scala +++ b/cli/src/swam/cli/Main.scala @@ -13,6 +13,7 @@ import io.odin.formatter.Formatter import io.odin.formatter.options.ThrowableFormat import io.odin.{Logger, consoleLogger} import fs2._ +import swam.ValType.{F32, F64, I32, I64} import swam.binary.ModuleStream import swam.decompilation._ import swam.code_analysis.coverage.{CoverageListener, CoverageReporter} @@ -21,6 +22,8 @@ import swam.runtime.trace._ import swam.runtime.wasi.Wasi import swam.runtime.{Engine, Function, Module, Value} import swam.text.Compiler +import swam.binary.custom.{FunctionNames, ModuleName} +import swam.runtime.internals.compiler.CompiledFunction private object NoTimestampFormatter extends JFormatter { override def format(x: LogRecord): String = @@ -38,6 +41,9 @@ object Main extends CommandIOApp(name = "swam-cli", header = "Swam from the comm val wasmFile = Opts.argument[Path](metavar = "wasm") + val func_name = + Opts.argument[String](metavar = "functionName") + // Arguments that get passed to the WASM code you execute. They are available through WASI args_get. val restArguments = Opts.arguments[String](metavar = "args").orEmpty @@ -128,7 +134,7 @@ object Main extends CommandIOApp(name = "swam-cli", header = "Swam from the comm debug, wasmFile, restArguments, - covOut, + covOut, covfilter, wasmArgTypes).mapN { (main, wat, wasi, time, dirs, trace, traceFile, filter, debug, wasm, args, covOut, covfilter, wasmArgTypes) => @@ -165,21 +171,8 @@ object Main extends CommandIOApp(name = "swam-cli", header = "Swam from the comm restArguments, covfilter, wasmArgTypes) - .mapN { - (main, wat, wasi, time, dirs, trace, traceFile, filter, debug, wasm, args, covfilter, wasmArgTypes) => - RunServer(wasm, - args, - main, - wat, - wasi, - time, - trace, - filter, - traceFile, - dirs, - debug, - covfilter, - wasmArgTypes) + .mapN { (main, wat, wasi, time, dirs, trace, traceFile, filter, debug, wasm, args, covfilter, wasmArgTypes) => + RunServer(wasm, args, main, wat, wasi, time, trace, filter, traceFile, dirs, debug, covfilter, wasmArgTypes) } } @@ -187,6 +180,11 @@ object Main extends CommandIOApp(name = "swam-cli", header = "Swam from the comm (textual, wasmFile, out.orNone).mapN { (textual, wasm, out) => Decompile(wasm, textual, out) } } + val inferOpts: Opts[Options] = + Opts.subcommand("infer", "Get the parameters type for functions file in Wasm module.") { + (wasmFile, wat, wasi, func_name).mapN { (wasm, wat, wasi, func_name) => Infer(wasm, wat, wasi, func_name) } + } + val validateOpts: Opts[Options] = Opts.subcommand("validate", "Validate a wasm file") { (wasmFile, wat, dev).mapN(Validate(_, _, _)) } @@ -198,8 +196,14 @@ object Main extends CommandIOApp(name = "swam-cli", header = "Swam from the comm val outFileOptions = List(StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING) def main: Opts[IO[ExitCode]] = - runOpts.orElse(serverOpts).orElse(covOpts).orElse(decompileOpts).orElse(validateOpts).orElse(compileOpts).map { - opts => + runOpts + .orElse(serverOpts) + .orElse(covOpts) + .orElse(inferOpts) + .orElse(decompileOpts) + .orElse(validateOpts) + .orElse(compileOpts) + .map { opts => Blocker[IO].use { blocker => opts match { case Run(file, args, main, wat, wasi, time, trace, filter, tracef, dirs, debug, wasmArgTypes) => @@ -221,7 +225,7 @@ object Main extends CommandIOApp(name = "swam-cli", header = "Swam from the comm _ <- IO(executeFunction(IO(preparedFunction), argsParsed, time)) } yield ExitCode.Success - // TODO: Remove this and instead to coverage flag in Run(...) + // TODO: Remove this and instead to coverage flag in Run(...) case WasmCov(file, args, main, @@ -284,8 +288,9 @@ object Main extends CommandIOApp(name = "swam-cli", header = "Swam from the comm module = if (wat) tcompiler.stream(file, debug, blocker) else engine.sections(file, blocker) compiled <- engine.compile(module) preparedFunction <- prepareFunction(compiled, main, dirs, args, wasi, blocker) - _ <- IO(Server - .listen(IO(preparedFunction), wasmArgTypes, time, file, coverageListener)) + _ <- IO( + Server + .listen(IO(preparedFunction), wasmArgTypes, time, file, coverageListener)) } yield ExitCode.Success case Decompile(file, textual, out) => @@ -318,7 +323,49 @@ object Main extends CommandIOApp(name = "swam-cli", header = "Swam from the comm res <- engine.validate(module).attempt _ <- res.fold(t => logger.error("Module is invalid", t), _ => logger.info("Module is valid")) } yield ExitCode.Success + case Infer(file, wat, wasi, func_name) => + for { + engine <- Engine[IO](blocker) + tcompiler <- swam.text.Compiler[IO](blocker) + module = if (wat) tcompiler.stream(file, false, blocker) else engine.sections(file, blocker) + compiled <- engine.compile(module) + names <- IO(compiled.names.flatMap(_.subsections.collectFirst { case FunctionNames(n) => n })) + exitCode <- IO( + names match { + case Some(x) => { + val func = x.filter { case (idx, name) => func_name == name } + + if (func.nonEmpty) { + + if (func.size > 1) { + + System.err.println(s"Warning $func_name has more than one definition, taking the first one") + } + + val tpeidx = func.collectFirst { case (tid, _) => tid }.get + + // There is always one at this point + val tpe = compiled.functions.filter(f => f.idx == tpeidx)(0).tpe + + val params = tpe.params.map { + case I32 => "Int32" + case I64 => "Int64" + case F32 => "Float32" + case F64 => "Float64" + } + println(params.mkString(",")) + ExitCode.Success + } else { + System.err.println(s"Function '$func_name' does not exist") + ExitCode.Error + } + } + case None => ExitCode.Error + } + ) + + } yield exitCode case Compile(file, out, debug) => for { tcompiler <- Compiler[IO](blocker) @@ -333,7 +380,7 @@ object Main extends CommandIOApp(name = "swam-cli", header = "Swam from the comm } } - } + } def prepareFunction(module: Module[IO], functionName: String, diff --git a/cli/src/swam/cli/options.scala b/cli/src/swam/cli/options.scala index 6df43fea..301ebb7b 100644 --- a/cli/src/swam/cli/options.scala +++ b/cli/src/swam/cli/options.scala @@ -71,3 +71,10 @@ case class WasmCov(file: Path, filter: Boolean, wasmArgTypes: List[String]) extends Options + +case class Infer(file: Path, + wat: Boolean, + wasi: Boolean, + functionName:String + )extends Options + diff --git a/runtime/src/swam/runtime/Module.scala b/runtime/src/swam/runtime/Module.scala index 58500ddf..00bd9740 100644 --- a/runtime/src/swam/runtime/Module.scala +++ b/runtime/src/swam/runtime/Module.scala @@ -45,12 +45,12 @@ class Module[F[_]] private[runtime] ( private[runtime] val tables: Vector[TableType], private[runtime] val memories: Vector[MemType], private[runtime] val start: Option[Int], - private[runtime] val functions: Vector[CompiledFunction[F]], + val functions: Vector[CompiledFunction[F]], private[runtime] val elems: Vector[CompiledElem[F]], private[runtime] val data: Vector[CompiledData[F]])(implicit F: MonadError[F, Throwable]) { self => - private[runtime] lazy val names = { + lazy val names = { val sec = customs.collectFirst { case Custom("name", payload) => payload } diff --git a/runtime/src/swam/runtime/internals/compiler/CompiledFunction.scala b/runtime/src/swam/runtime/internals/compiler/CompiledFunction.scala index 54cd2af8..eff2d374 100644 --- a/runtime/src/swam/runtime/internals/compiler/CompiledFunction.scala +++ b/runtime/src/swam/runtime/internals/compiler/CompiledFunction.scala @@ -23,7 +23,7 @@ import cfg._ import cats.effect.IO import swam.runtime.internals.interpreter.AsmInst -private[runtime] case class CompiledFunction[F[_]](idx: Int, +case class CompiledFunction[F[_]](idx: Int, tpe: FuncType, locals: Vector[ValType], code: Array[AsmInst[F]])