Skip to content

Commit

Permalink
Make wireWith work better with scala3
Browse files Browse the repository at this point in the history
  • Loading branch information
jilen committed Oct 11, 2024
1 parent b99be6b commit c5083e0
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,28 @@ object MacwireMacros {

val dependencyResolver = DependencyResolver.throwErrorOnResolutionFailure[q.type, T](log)

val (params, fun) = factory.asTerm match {
case Inlined(_, _, Block(List(DefDef(_, List(p), _, Some(Apply(f, _)))), _)) => (p.params, f)
case _ => report.errorAndAbort(s"Not supported factory type: [$factory]")
def functionParamTypes(t: TypeRepr): List[List[TypeRepr]] = {
if (t.isFunctionType) {
// Handle curried function
t.typeArgs.init :: functionParamTypes(t.typeArgs.last)
} else {
Nil
}
}

val values = params.map {
// case vd@ValDef(_, name, tpt, rhs) => dependencyResolver.resolve(vd.symbol, typeCheckIfNeeded(tpt))
case vd @ ValDef(name, tpt, rhs) => dependencyResolver.resolve(vd.symbol, tpt.tpe)
// Implicit params are pre-applied while passing to wireWith
val values = functionParamTypes(factory.asTerm.tpe).zipWithIndex.map { case (paramList, i) =>
paramList.zipWithIndex.map { case (tpe, j) =>
// Resolve require a symbol, create a fake symbol here
val fakeSymbol = Symbol.newVal(Symbol.noSymbol, s"p_${i}_${j}", tpe, Flags.Param, Symbol.noSymbol)
dependencyResolver.resolve(fakeSymbol, tpe)
}
}

val code = Apply(fun, values).asExprOf[T]
val funApply: Term = Select.unique(factory.asTerm, "apply")
val code = values
.foldLeft(funApply) { (fun, args) =>
Apply(fun, args)
}
.asExprOf[T]
log(s"Generated code: ${code.show}")
code
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ object Test {
val b = wire[B[IO]]
}

require(Test.lb.la == Test.la)
require(Test.b.a == Test.a)
require(Test.lb.la eq Test.la)
require(Test.b.a eq Test.a)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
case class A(s: String)

object A {
def create() = {
val s = "foo"
wire[A]
}
}
require(A.create().s == "foo")
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ case class A()
case class B()

object Test {
case class C(a: A, b: B, s: String)
case class C(a: A, b: B)

object C {
def factory(a: A, b: B): C = {
val s = "hey!"
wire[C]
}
}
Expand All @@ -17,6 +16,5 @@ object Test {
lazy val c: C = wireWith(C.factory _)
}

require(Test.c.s == "hey!")
require(Test.c.a eq Test.a)
require(Test.c.b eq Test.b)
require(Test.c.b eq Test.b)

0 comments on commit c5083e0

Please sign in to comment.