Skip to content

Commit

Permalink
Discriminator failure is now an internal error
Browse files Browse the repository at this point in the history
  • Loading branch information
milessabin committed Aug 17, 2024
1 parent bd8c15e commit 81d8c77
Show file tree
Hide file tree
Showing 15 changed files with 140 additions and 165 deletions.
16 changes: 9 additions & 7 deletions modules/circe/src/main/scala/circemapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,22 +152,24 @@ trait CirceMappingLike[F[_]] extends Mapping[F] {
case _ => Result.internalError(s"Expected Nullable type, found $focus for $tpe")
}

def narrowsTo(subtpe: TypeRef): Boolean =
subtpe <:< tpe &&
def narrowsTo(subtpe: TypeRef): Result[Boolean] =
(subtpe <:< tpe &&
((subtpe.dealias, focus.asObject) match {
case (nt: TypeWithFields, Some(obj)) =>
nt.fields.forall { f =>
f.tpe.isNullable || obj.contains(f.name)
} && obj.keys.forall(nt.hasField)

case _ => false
})
})).success

def narrow(subtpe: TypeRef): Result[Cursor] =
if (narrowsTo(subtpe))
mkChild(context.asType(subtpe)).success
else
Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
narrowsTo(subtpe).flatMap { n =>
if (n)
mkChild(context.asType(subtpe)).success
else
Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
}

def field(fieldName: String, resultName: Option[String]): Result[Cursor] = {
val localField =
Expand Down
6 changes: 3 additions & 3 deletions modules/core/src/main/scala/cursor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ trait Cursor {
def isDefined: Result[Boolean]

/** Is the value at this `Cursor` narrowable to `subtpe`? */
def narrowsTo(subtpe: TypeRef): Boolean
def narrowsTo(subtpe: TypeRef): Result[Boolean]

/**
* Yield a `Cursor` corresponding to the value at this `Cursor` narrowed to
Expand Down Expand Up @@ -251,7 +251,7 @@ object Cursor {
def isDefined: Result[Boolean] =
Result.internalError(s"Expected Nullable type, found $focus for $tpe")

def narrowsTo(subtpe: TypeRef): Boolean = false
def narrowsTo(subtpe: TypeRef): Result[Boolean] = false.success

def narrow(subtpe: TypeRef): Result[Cursor] =
Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
Expand Down Expand Up @@ -290,7 +290,7 @@ object Cursor {

def isDefined: Result[Boolean] = underlying.isDefined

def narrowsTo(subtpe: TypeRef): Boolean = underlying.narrowsTo(subtpe)
def narrowsTo(subtpe: TypeRef): Result[Boolean] = underlying.narrowsTo(subtpe)

def narrow(subtpe: TypeRef): Result[Cursor] = underlying.narrow(subtpe)

Expand Down
42 changes: 23 additions & 19 deletions modules/core/src/main/scala/mapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ abstract class Mapping[F[_]] {
*/
protected final def mkCursorForField(parent: Cursor, fieldName: String, resultName: Option[String]): Result[Cursor] = {
typeMappings.fieldMapping(parent, fieldName).
toResultOrError(s"No mapping for field '$fieldName' for type ${parent.tpe}").
flatMap(_.toResultOrError(s"No mapping for field '$fieldName' for type ${parent.tpe}")).
flatMap {
case (np, fm) =>
val fieldContext = np.context.forFieldOrAttribute(fieldName, resultName)
Expand Down Expand Up @@ -201,15 +201,17 @@ abstract class Mapping[F[_]] {
* Yields the `FieldMapping` associated with `fieldName` in the runtime context
* determined by the given `Cursor`, if any.
*/
def fieldMapping(parent: Cursor, fieldName: String): Option[(Cursor, FieldMapping)] = {
// TODO: Might simplify things if we drop the Option here
def fieldMapping(parent: Cursor, fieldName: String): Result[Option[(Cursor, FieldMapping)]] = {
val context = parent.context
fieldIndex(context).flatMap(_.get(fieldName)).flatMap {
case ifm: InheritedFieldMapping =>
ifm.select(parent.context).map((parent, _))
case pfm: PolymorphicFieldMapping =>
fieldIndex(context).flatMap(_.get(fieldName)) match {
case Some(ifm: InheritedFieldMapping) =>
ifm.select(parent.context).map((parent, _)).success
case Some(pfm: PolymorphicFieldMapping) =>
pfm.select(parent)
case fm =>
Some((parent, fm))
case Some(fm) =>
Option((parent, fm)).success
case None => None.success
}
}

Expand Down Expand Up @@ -539,18 +541,20 @@ abstract class Mapping[F[_]] {
def hidden: Boolean = false
def subtree: Boolean = false

def select(cursor: Cursor): Option[(Cursor, FieldMapping)] = {
def select(cursor: Cursor): Result[Option[(Cursor, FieldMapping)]] = {
val applicable =
candidates.mapFilter {
case (pred, fm) if cursor.narrowsTo(schema.uncheckedRef(pred.tpe)) =>
for {
nc <- cursor.narrow(schema.uncheckedRef(pred.tpe)).toOption
prio <- pred(nc.context)
} yield (prio, (nc, fm))
case _ =>
None
candidates.traverseFilter {
case (pred, fm) =>
cursor.narrowsTo(schema.uncheckedRef(pred.tpe)).flatMap { narrows =>
if (narrows)
for {
nc <- cursor.narrow(schema.uncheckedRef(pred.tpe))
} yield pred(nc.context).map(prio => (prio, (nc, fm)))
else
None.success
}
}
applicable.maxByOption(_._1).map(_._2)
applicable.map(_.maxByOption(_._1).map(_._2))
}

def select(context: Context): Option[FieldMapping] = {
Expand Down Expand Up @@ -1282,7 +1286,7 @@ abstract class Mapping[F[_]] {
case _ => Result.internalError(s"Not nullable at ${context.path}")
}

def narrowsTo(subtpe: TypeRef): Boolean = false
def narrowsTo(subtpe: TypeRef): Result[Boolean] = false.success
def narrow(subtpe: TypeRef): Result[Cursor] =
Result.failure(s"Cannot narrow $tpe to $subtpe")

Expand Down
45 changes: 24 additions & 21 deletions modules/core/src/main/scala/queryinterpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,22 +201,23 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) {
siblings.flatTraverse(query => runFields(query, tpe, cursor))

case Introspect(schema, s@Select("__typename", _, Empty)) if tpe.isNamed =>
(tpe.dealias match {
case o: ObjectType => Some(o.name)
val fail = Result.failure(s"'__typename' cannot be applied to non-selectable type '$tpe'")
def mkTypeNameFields(name: String) =
List((s.resultName, ProtoJson.fromJson(Json.fromString(name)))).success
def mkTypeNameFieldsOrFail(name: Option[String]) =
name.map(mkTypeNameFields).getOrElse(fail)

tpe.dealias match {
case o: ObjectType => mkTypeNameFields(o.name)
case i: InterfaceType =>
(schema.implementations(i).collectFirst {
case o if cursor.narrowsTo(schema.uncheckedRef(o)) => o.name
})
schema.implementations(i).collectFirstSomeM { o =>
cursor.narrowsTo(schema.uncheckedRef(o)).ifF(Some(o.name), None)
}.flatMap(mkTypeNameFieldsOrFail)
case u: UnionType =>
(u.members.map(_.dealias).collectFirst {
case nt: NamedType if cursor.narrowsTo(schema.uncheckedRef(nt)) => nt.name
})
case _ => None
}) match {
case Some(name) =>
List((s.resultName, ProtoJson.fromJson(Json.fromString(name)))).success
case None =>
Result.failure(s"'__typename' cannot be applied to non-selectable type '$tpe'")
u.members.map(_.dealias).collectFirstSomeM { nt =>
cursor.narrowsTo(schema.uncheckedRef(nt)).ifF(Some(nt.name), None)
}.flatMap(mkTypeNameFieldsOrFail)
case _ => fail
}

case sel: Select if tpe.isNullable =>
Expand Down Expand Up @@ -250,13 +251,15 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) {
value <- runValue(child, fieldTpe, c)
} yield List((sel.resultName, value))

case Narrow(tp1, child) if cursor.narrowsTo(tp1) =>
for {
c <- cursor.narrow(tp1)
fields <- runFields(child, tp1, c)
} yield fields

case _: Narrow => Nil.success
case Narrow(tp1, child) =>
cursor.narrowsTo(tp1).flatMap { n =>
if (!n) Nil.success
else
for {
c <- cursor.narrow(tp1)
fields <- runFields(child, tp1, c)
} yield fields
}

case c@Component(_, _, cont) =>
for {
Expand Down
16 changes: 9 additions & 7 deletions modules/core/src/main/scala/valuemapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,20 +175,22 @@ trait ValueMappingLike[F[_]] extends Mapping[F] {
case _ => Result.internalError(s"Expected Nullable type, found $focus for $tpe")
}

def narrowsTo(subtpe: TypeRef): Boolean =
subtpe <:< tpe &&
def narrowsTo(subtpe: TypeRef): Result[Boolean] =
(subtpe <:< tpe &&
objectMapping(context.asType(subtpe)).exists {
case ValueObjectMapping(_, _, classTag) =>
classTag.runtimeClass.isInstance(focus)
case _ => false
}
}).success


def narrow(subtpe: TypeRef): Result[Cursor] =
if (narrowsTo(subtpe))
mkChild(context.asType(subtpe)).success
else
Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
narrowsTo(subtpe).flatMap { n =>
if(n)
mkChild(context.asType(subtpe)).success
else
Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
}

def field(fieldName: String, resultName: Option[String]): Result[Cursor] =
mkCursorForField(this, fieldName, resultName)
Expand Down
10 changes: 6 additions & 4 deletions modules/generic/src/main/scala-2/genericmapping2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,14 @@ trait ScalaVersionSpecificGenericMappingLike[F[_]] extends Mapping[F] { self: Ge
override def field(fieldName: String, resultName: Option[String]): Result[Cursor] =
cursor.field(fieldName, resultName) orElse mkCursorForField(this, fieldName, resultName)

override def narrowsTo(subtpe: TypeRef): Boolean =
subtpe <:< tpe && rtpe <:< subtpe
override def narrowsTo(subtpe: TypeRef): Result[Boolean] =
(subtpe <:< tpe && rtpe <:< subtpe).success

override def narrow(subtpe: TypeRef): Result[Cursor] =
if (narrowsTo(subtpe)) copy(tpe0 = subtpe).success
else Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
narrowsTo(subtpe).flatMap { n =>
if (n) copy(tpe0 = subtpe).success
else Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
}
}
}
}
Expand Down
10 changes: 6 additions & 4 deletions modules/generic/src/main/scala-3/genericmapping3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,14 @@ trait ScalaVersionSpecificGenericMappingLike[F[_]] extends Mapping[F] { self: Ge
override def field(fieldName: String, resultName: Option[String]): Result[Cursor] =
cursor.field(fieldName, resultName) orElse mkCursorForField(this, fieldName, resultName)

override def narrowsTo(subtpe: TypeRef): Boolean =
subtpe <:< tpe && rtpe <:< subtpe
override def narrowsTo(subtpe: TypeRef): Result[Boolean] =
(subtpe <:< tpe && rtpe <:< subtpe).success

override def narrow(subtpe: TypeRef): Result[Cursor] =
if (narrowsTo(subtpe)) copy(tpe0 = subtpe).success
else Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
narrowsTo(subtpe).flatMap { n =>
if (n) copy(tpe0 = subtpe).success
else Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
}
}
}
}
66 changes: 32 additions & 34 deletions modules/sql/shared/src/main/scala/SqlMapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ trait SqlMappingLike[F[_]] extends CirceMappingLike[F] with SqlModule[F] { self
/** Discriminator for the branches of an interface/union */
trait SqlDiscriminator {
/** yield a predicate suitable for filtering row corresponding to the supplied type */
def narrowPredicate(tpe: Type): Option[Predicate]
def narrowPredicate(tpe: Type): Result[Predicate]

/** compute the concrete type of the value at the cursor */
def discriminate(cursor: Cursor): Result[Type]
Expand Down Expand Up @@ -3066,11 +3066,6 @@ trait SqlMappingLike[F[_]] extends CirceMappingLike[F] with SqlModule[F] { self
assert(supertpe.underlying.isInterface || supertpe.underlying.isUnion || (subtpes.sizeCompare(1) == 0 && subtpes.head =:= supertpe))
subtpes.foreach(subtpe => assert(subtpe <:< supertpe))

val discriminator = discriminatorForType(context)
val narrowPredicates = subtpes.map { subtpe =>
(subtpe, discriminator.flatMap(_.discriminator.narrowPredicate(subtpe)))
}

val exhaustive = schema.exhaustive(supertpe, subtpes)
val exclusive = default == Empty
val allSimple = narrows.forall(narrow => isSimple(narrow.child))
Expand All @@ -3097,24 +3092,24 @@ trait SqlMappingLike[F[_]] extends CirceMappingLike[F] with SqlModule[F] { self

val dquery =
if(exhaustive) EmptySqlQuery(context).success
else {
val allPreds = narrowPredicates.collect {
case (_, Some(pred)) => pred
}
if (exclusive) {
for {
parentTable <- parentTableForType(context)
allPreds0 <- allPreds.traverse(pred => SqlQuery.contextualiseWhereTerms(context, parentTable, pred).map(Not(_)))
} yield {
else
discriminatorForType(context).map { disc =>
subtpes.traverse(disc.discriminator.narrowPredicate)
}.getOrElse(Nil.success).flatMap { allPreds =>
if (exclusive) {
for {
parentTable <- parentTableForType(context)
allPreds0 <- allPreds.traverse(pred => SqlQuery.contextualiseWhereTerms(context, parentTable, pred).map(Not(_)))
} yield {
val defaultPredicate = And.combineAll(allPreds0)
SqlSelect(context, Nil, parentTable, extraCols, Nil, defaultPredicate :: Nil, Nil, None, None, Nil, true, false)
}
} else {
val allPreds0 = allPreds.map(Not(_))
val defaultPredicate = And.combineAll(allPreds0)
SqlSelect(context, Nil, parentTable, extraCols, Nil, defaultPredicate :: Nil, Nil, None, None, Nil, true, false)
loop(Filter(defaultPredicate, default), context, parentConstraints, exposeJoins).flatMap(_.withContext(context, extraCols, Nil))
}
} else {
val allPreds0 = allPreds.map(Not(_))
val defaultPredicate = And.combineAll(allPreds0)
loop(Filter(defaultPredicate, default), context, parentConstraints, exposeJoins).flatMap(_.withContext(context, extraCols, Nil))
}
}

for {
dquery0 <- dquery
Expand Down Expand Up @@ -3543,30 +3538,33 @@ trait SqlMappingLike[F[_]] extends CirceMappingLike[F] with SqlModule[F] { self
case _ => Result.internalError(s"Not nullable at ${context.path}")
}

def narrowsTo(subtpe: TypeRef): Boolean = {
def narrowsTo(subtpe: TypeRef): Result[Boolean] = {
def check(ctpe: Type): Boolean =
if (ctpe =:= tpe) asTable.map(table => mapped.narrowsTo(context.asType(subtpe), table)).toOption.getOrElse(false)
else ctpe <:< subtpe

(subtpe <:< tpe) &&
(discriminatorForType(context) match {
case Some(disc) => disc.discriminator.discriminate(this).map(check).getOrElse(false)
case _ => check(tpe)
})
if (!(subtpe <:< tpe)) false.success
else
discriminatorForType(context) match {
case Some(disc) => disc.discriminator.discriminate(this).map(check)
case _ => check(tpe).success
}
}

def narrow(subtpe: TypeRef): Result[Cursor] = {
if (narrowsTo(subtpe)) {
val narrowedContext = context.asType(subtpe)
asTable.map { table =>
mkChild(context = narrowedContext, focus = mapped.narrow(narrowedContext, table))
}
} else Result.internalError(s"Cannot narrow $tpe to $subtpe")
narrowsTo(subtpe).flatMap { n =>
if (n) {
val narrowedContext = context.asType(subtpe)
asTable.map { table =>
mkChild(context = narrowedContext, focus = mapped.narrow(narrowedContext, table))
}
} else Result.internalError(s"Cannot narrow $tpe to $subtpe")
}
}

def field(fieldName: String, resultName: Option[String]): Result[Cursor] = {
val localField =
typeMappings.fieldMapping(this, fieldName) match {
typeMappings.fieldMapping(this, fieldName).flatMap {
case Some((np, _: SqlJson)) =>
val fieldContext = np.context.forFieldOrAttribute(fieldName, resultName)
val fieldTpe = fieldContext.tpe
Expand Down
8 changes: 4 additions & 4 deletions modules/sql/shared/src/test/scala/SqlInterfacesMapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,14 @@ trait SqlInterfacesMapping[F[_]] extends SqlTestMapping[F] { self =>
}
}

def narrowPredicate(subtpe: Type): Option[Predicate] = {
def mkPredicate(tpe: EntityType): Option[Predicate] =
Some(Eql(EType / "entityType", Const(tpe)))
def narrowPredicate(subtpe: Type): Result[Predicate] = {
def mkPredicate(tpe: EntityType): Result[Predicate] =
Eql(EType / "entityType", Const(tpe)).success

subtpe match {
case FilmType => mkPredicate(EntityType.Film)
case SeriesType => mkPredicate(EntityType.Series)
case _ => None
case _ => Result.internalError(s"Invalid discriminator: $subtpe")
}
}
}
Expand Down
Loading

0 comments on commit 81d8c77

Please sign in to comment.