diff --git a/modules/cats/src/main/scala/derevo/cats/eqv.scala b/modules/cats/src/main/scala/derevo/cats/eqv.scala index 37fab761..53abcb18 100644 --- a/modules/cats/src/main/scala/derevo/cats/eqv.scala +++ b/modules/cats/src/main/scala/derevo/cats/eqv.scala @@ -1,9 +1,10 @@ package derevo.cats import cats.Eq +import derevo.{Derivation, NewTypeDerivation} import magnolia.{CaseClass, Magnolia, SealedTrait} -import derevo.Derivation -import derevo.NewTypeDerivation + +import scala.reflect.macros.blackbox object eqv extends Derivation[Eq] with NewTypeDerivation[Eq] { type Typeclass[T] = Eq[T] @@ -25,6 +26,31 @@ object eqv extends Derivation[Eq] with NewTypeDerivation[Eq] { implicit def instance[T]: Eq[T] = macro Magnolia.gen[T] + def apply[T](eqFields: String*): Eq[T] = macro eqExtendedImpl[T] + + def eqExtendedImpl[T: c.WeakTypeTag](c: blackbox.Context)(eqFields: c.Expr[String]*): c.Tree = { + import c.universe._ + + val T = weakTypeOf[T] + val eqFieldsSet = eqFields.map(_.tree).collect { + case Literal(Constant(field: String)) => field + }.toSet + + val comparisons = T.decls.collect { + case m: MethodSymbol if m.isCaseAccessor && eqFieldsSet.contains(m.name.toString) => + val name = m.name.toTermName + q"Eq[${m.typeSignature}].eqv(x.$name, y.$name)" + } + + q""" + new Eq[$T] { + def eqv(x: $T, y: $T): Boolean = { + ..$comparisons + } + } + """ + } + object universal extends Derivation[Eq] { implicit def instance[T]: Eq[T] = Eq.fromUniversalEquals[T] } diff --git a/modules/cats/src/test/scala/derevo/cats/EqSpec.scala b/modules/cats/src/test/scala/derevo/cats/EqSpec.scala index 200927a9..6a1483ca 100644 --- a/modules/cats/src/test/scala/derevo/cats/EqSpec.scala +++ b/modules/cats/src/test/scala/derevo/cats/EqSpec.scala @@ -34,6 +34,15 @@ class EqSpec extends AnyFreeSpec { assert(Eq[Qux].eqv(Qux(1), Qux(1))) assert(Eq[Qux].neqv(Qux(1), Qux(-1))) } + + "with ignoring fields" in { + @derive(eqv("bar")) + case class Foo(bar: Int, baz: Int) + + assert(Eq[Foo].eqv(Foo(1, 2), Foo(1, 5))) + assert(Eq[Foo].neqv(Foo(2, 2), Foo(3, 2))) + assert(Eq[Foo].neqv(Foo(2, 2), Foo(3, 1))) + } } } }