diff --git a/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala b/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala index 4333635c..6c4bc7c9 100644 --- a/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala +++ b/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala @@ -54,15 +54,15 @@ private[clazz] class Utils(using val quotes: Quotes): def resolveAndOrTypeParamRefs: TypeRepr = tpe match { - case AndType(left: ParamRef, right: ParamRef) => + case AndType(left @ (_: ParamRef | _: AppliedType), right @ (_: ParamRef | _: AppliedType)) => TypeRepr.of[Any] - case AndType(left: ParamRef, right) => + case AndType(left @ (_: ParamRef | _: AppliedType), right) => right.resolveAndOrTypeParamRefs - case AndType(left, right: ParamRef) => + case AndType(left, right @ (_: ParamRef | _: AppliedType)) => left.resolveAndOrTypeParamRefs - case OrType(_: ParamRef, _) => + case OrType(_: ParamRef | _: AppliedType, _) => TypeRepr.of[Any] - case OrType(_, _: ParamRef) => + case OrType(_, _: ParamRef | _: AppliedType) => TypeRepr.of[Any] case other => other @@ -77,6 +77,12 @@ private[clazz] class Utils(using val quotes: Quotes): case pr@ParamRef(bindings, idx) if bindings == baseBindings => methodArgs.head(idx).asInstanceOf[TypeTree].tpe + case AndType(left, right) => + AndType(loop(left), loop(right)) + + case OrType(left, right) => + OrType(loop(left), loop(right)) + case AppliedType(tycon, args) => AppliedType(loop(tycon), args.map(arg => loop(arg))) diff --git a/shared/src/test/scala-3/com/paulbutcher/test/Scala3Spec.scala b/shared/src/test/scala-3/com/paulbutcher/test/Scala3Spec.scala index 67412e04..8d7b7d08 100644 --- a/shared/src/test/scala-3/com/paulbutcher/test/Scala3Spec.scala +++ b/shared/src/test/scala-3/com/paulbutcher/test/Scala3Spec.scala @@ -195,6 +195,75 @@ class Scala3Spec extends AnyFunSpec with MockFactory with Matchers { m.methodWithGenericUnion(obj2) } + it("mock union return type") { + + trait A + + trait B + + trait TraitWithUnionReturnType { + + def methodWithUnionReturnType[T](): T | A + } + + val m = mock[TraitWithUnionReturnType] + + val obj = new B {} + + (() => m.methodWithUnionReturnType[B]()).expects().returns(obj) + + m.methodWithUnionReturnType[B]() shouldBe obj + } + + it("mock intersection return type") { + + trait A + + trait B + + trait TraitWithIntersectionReturnType { + + def methodWithIntersectionReturnType[T](): A & T + } + + val m = mock[TraitWithIntersectionReturnType] + + val obj = new A with B {} + + (() => m.methodWithIntersectionReturnType[B]()).expects().returns(obj) + + m.methodWithIntersectionReturnType[B]() shouldBe obj + } + + it("mock intersection|union types with type constructors") { + + trait A[T] + + trait B + + trait C + + trait ComplexUnionIntersectionCases { + + def complexMethod1[T](x: A[T] & T): A[T] & T + def complexMethod2[T](x: A[A[T]] | T): A[T] | T + def complexMethod3[F[_], T](x: F[A[T] & F[T]] | T & A[F[T]]): F[T] & T + def complexMethod4[T](x: A[B & C] ): A[B & C] + def complexMethod5[T](x: A[B | A[C]]): A[B | C] + } + + val m = mock[ComplexUnionIntersectionCases] + + val obj = new A[B] with B {} + val obj2 = new A[A[B]] with B {} + + (m.complexMethod1[B] _).expects(obj).returns(obj) + (m.complexMethod2[B] _).expects(obj2).returns(new A[B] {}) + + m.complexMethod1[B](obj) + m.complexMethod2[B](obj2) + } + it("mock methods returning function") { trait Test { def method(x: Int): Int => String