Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WX-1531] Struct Literal Type Checking #7402

Merged
merged 13 commits into from
Apr 10, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import wom.types.WomType

@typeclass
trait TypeEvaluator[A] {
def evaluateType(a: A, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(implicit
def evaluateType(a: A,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType]
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package wdl.transforms.biscayne.linking.expression.types

import cats.data.Validated.{Invalid, Valid}
import cats.implicits.{catsSyntaxTuple2Semigroupal, catsSyntaxTuple3Semigroupal}
import cats.syntax.validated._
import common.validation.ErrorOr._
Expand All @@ -12,7 +13,10 @@ import wom.types._

object BiscayneTypeEvaluators {
implicit val keysFunctionEvaluator: TypeEvaluator[Keys] = new TypeEvaluator[Keys] {
override def evaluateType(a: Keys, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(implicit
override def evaluateType(a: Keys,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] =
validateParamType(a.param, linkedValues, WomMapType(WomAnyType, WomAnyType)) flatMap {
Expand All @@ -22,7 +26,10 @@ object BiscayneTypeEvaluators {
}

implicit val asMapFunctionEvaluator: TypeEvaluator[AsMap] = new TypeEvaluator[AsMap] {
override def evaluateType(a: AsMap, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(implicit
override def evaluateType(a: AsMap,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] =
validateParamType(a.param, linkedValues, WomArrayType(WomPairType(WomAnyType, WomAnyType))) flatMap {
Expand All @@ -34,7 +41,10 @@ object BiscayneTypeEvaluators {
}

implicit val asPairsFunctionEvaluator: TypeEvaluator[AsPairs] = new TypeEvaluator[AsPairs] {
override def evaluateType(a: AsPairs, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(implicit
override def evaluateType(a: AsPairs,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] =
validateParamType(a.param, linkedValues, WomMapType(WomAnyType, WomAnyType)) flatMap {
Expand All @@ -44,8 +54,11 @@ object BiscayneTypeEvaluators {
}

implicit val collectByKeyFunctionEvaluator: TypeEvaluator[CollectByKey] = new TypeEvaluator[CollectByKey] {
override def evaluateType(a: CollectByKey, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(
implicit expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
override def evaluateType(a: CollectByKey,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] =
validateParamType(a.param, linkedValues, WomArrayType(WomPairType(WomAnyType, WomAnyType))) flatMap {
case WomArrayType(WomPairType(x: WomPrimitiveType, y)) => WomMapType(x, WomArrayType(y)).validNel
Expand All @@ -67,29 +80,38 @@ object BiscayneTypeEvaluators {
}

implicit val minFunctionEvaluator: TypeEvaluator[Min] = new TypeEvaluator[Min] {
override def evaluateType(a: Min, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(implicit
override def evaluateType(a: Min,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] = {
val type1 = expressionTypeEvaluator.evaluateType(a.arg1, linkedValues)
val type2 = expressionTypeEvaluator.evaluateType(a.arg1, linkedValues)
val type1 = expressionTypeEvaluator.evaluateType(a.arg1, linkedValues, typeAliases)
val type2 = expressionTypeEvaluator.evaluateType(a.arg1, linkedValues, typeAliases)

(type1, type2) flatMapN resultTypeOfIntVsFloat("min")
}
}

implicit val maxFunctionEvaluator: TypeEvaluator[Max] = new TypeEvaluator[Max] {
override def evaluateType(a: Max, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(implicit
override def evaluateType(a: Max,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] = {
val type1 = expressionTypeEvaluator.evaluateType(a.arg1, linkedValues)
val type2 = expressionTypeEvaluator.evaluateType(a.arg1, linkedValues)
val type1 = expressionTypeEvaluator.evaluateType(a.arg1, linkedValues, typeAliases)
val type2 = expressionTypeEvaluator.evaluateType(a.arg1, linkedValues, typeAliases)

(type1, type2) flatMapN resultTypeOfIntVsFloat("max")
}
}

implicit val sepFunctionEvaluator: TypeEvaluator[Sep] = new TypeEvaluator[Sep] {
override def evaluateType(a: Sep, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(implicit
override def evaluateType(a: Sep,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] =
validateParamType(a.arg2, linkedValues, WomArrayType(WomAnyType)) flatMap {
Expand All @@ -103,7 +125,10 @@ object BiscayneTypeEvaluators {
}

implicit val subPosixFunctionEvaluator: TypeEvaluator[SubPosix] = new TypeEvaluator[SubPosix] {
override def evaluateType(a: SubPosix, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(implicit
override def evaluateType(a: SubPosix,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] =
(validateParamType(a.input, linkedValues, WomSingleFileType),
Expand All @@ -113,7 +138,10 @@ object BiscayneTypeEvaluators {
}

implicit val suffixFunctionEvaluator: TypeEvaluator[Suffix] = new TypeEvaluator[Suffix] {
override def evaluateType(a: Suffix, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(implicit
override def evaluateType(a: Suffix,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] =
(validateParamType(a.suffix, linkedValues, WomStringType),
Expand All @@ -122,7 +150,10 @@ object BiscayneTypeEvaluators {
}

implicit val quoteFunctionEvaluator: TypeEvaluator[Quote] = new TypeEvaluator[Quote] {
override def evaluateType(a: Quote, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(implicit
override def evaluateType(a: Quote,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] =
validateParamType(a.param, linkedValues, WomArrayType(WomAnyType)) flatMap {
Expand All @@ -136,7 +167,10 @@ object BiscayneTypeEvaluators {
}

implicit val sQuoteFunctionEvaluator: TypeEvaluator[SQuote] = new TypeEvaluator[SQuote] {
override def evaluateType(a: SQuote, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(implicit
override def evaluateType(a: SQuote,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] =
validateParamType(a.param, linkedValues, WomArrayType(WomAnyType)) flatMap {
Expand All @@ -150,7 +184,10 @@ object BiscayneTypeEvaluators {
}

implicit val unzipFunctionEvaluator: TypeEvaluator[Unzip] = new TypeEvaluator[Unzip] {
override def evaluateType(a: Unzip, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(implicit
override def evaluateType(a: Unzip,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] =
validateParamType(a.param, linkedValues, WomArrayType(WomPairType(WomAnyType, WomAnyType))) flatMap {
Expand All @@ -161,14 +198,101 @@ object BiscayneTypeEvaluators {
}

implicit val structLiteralTypeEvaluator: TypeEvaluator[StructLiteral] = new TypeEvaluator[StructLiteral] {
override def evaluateType(a: StructLiteral, linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle])(
implicit expressionTypeEvaluator: TypeEvaluator[ExpressionElement]

// Is the evaluated type allowed to be assigned to the expectedType?
def areTypesAssignable(evaluatedType: WomType, expectedType: WomType): Boolean =
// NB: This check is a little looser than we'd like it to be.
// For example, String is coercible to Int (Int i = "1" is OK)
// It's not until we actually evaluate the value of the string that we can know if that coercion succeeded or not. (Int i = "orange" will fail)
// We don't know whether the user has provided "1" or "orange" at this stage.
// This is OK as-is because the value evaluators do the coercing and throw meaningful errors if the coercion fails.
THWiseman marked this conversation as resolved.
Show resolved Hide resolved
expectedType.isCoerceableFrom(evaluatedType)

// Helper method to check something (maybe) found in the struct literal to something (maybe) found in the struct definition.
def checkIfMemberIsValid(typeName: String,
memberName: String,
evaluatedType: Option[WomType],
expectedType: Option[WomType]
): ErrorOr[WomType] =
// This works fine, but is not yet a strong enough type check for the WDL 1.1 spec
// (i.e. users are able to instantiate struct literals with k/v pairs that aren't actually in the struct definition, without an error being thrown.)
// We want to add extra validation here, and return a WomCompositeType that matches the struct definition of everything is OK.
// Note that users are allowed to omit optional k/v pairs in their literal.
// This requires some extra work to be done in a subsequent PR.
WomObjectType.validNel
evaluatedType match {
case Some(evaluated) =>
expectedType match {
case Some(expected) =>
if (areTypesAssignable(evaluated, expected)) evaluated.validNel
else
s"$typeName.$memberName expected to be ${expected.friendlyName}. Found ${evaluated.friendlyName}.".invalidNel
case None => s"Type $typeName does not have a member called $memberName.".invalidNel
}
case None => s"Error evaluating the type of ${typeName}.${memberName}.".invalidNel
}

// For each member in the literal, check that it exists in the struct definition and is the expected type.
def checkMembersAgainstDefinition(a: StructLiteral,
jgainerdewar marked this conversation as resolved.
Show resolved Hide resolved
structDefinition: WomCompositeType,
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
): ErrorOr[WomCompositeType] = {
val checkedMembers: Map[String, ErrorOr[WomType]] = a.elements.map { case (memberKey, memberExpressionElement) =>
val evaluatedType =
expressionTypeEvaluator.evaluateType(memberExpressionElement, linkedValues, typeAliases).toOption
val expectedType = structDefinition.typeMap.get(memberKey)
(memberKey, checkIfMemberIsValid(a.structTypeName, memberKey, evaluatedType, expectedType))
}

val (errors, validatedTypes) = checkedMembers.partition { case (_, errorOr) =>
errorOr match {
case Invalid(_) => true
case Valid(_) => false
}
}

errors.headOption match {
case Some((_, Invalid(_))) =>
THWiseman marked this conversation as resolved.
Show resolved Hide resolved
errors.collect { case (_, Invalid(e)) => e.toList.mkString }.toList.mkString(",").invalidNel
THWiseman marked this conversation as resolved.
Show resolved Hide resolved
case _ =>
val types = validatedTypes.collect { case (key, Valid(v)) => (key, v) }
WomCompositeType(types, Some(a.structTypeName)).validNel
}
}

// For every member in the definition, if that member isn't optional, confirm that it is also in the struct literal.
THWiseman marked this conversation as resolved.
Show resolved Hide resolved
def checkForMissingMembers(foundMembers: Map[String, WomType],
structDefinition: WomCompositeType
): Option[String] = {
THWiseman marked this conversation as resolved.
Show resolved Hide resolved
val errors: Iterable[String] = structDefinition.typeMap flatMap { case (memberName, memberType) =>
memberType match {
case WomOptionalType(_) => None
case _ =>
if (!foundMembers.contains(memberName)) Some(s"Expected member ${memberName} not found. ")
else None
}
}
errors match {
case Nil => None
case _ => Some(errors.mkString)
}
}
override def evaluateType(a: StructLiteral,
jgainerdewar marked this conversation as resolved.
Show resolved Hide resolved
linkedValues: Map[UnlinkedConsumedValueHook, GeneratedValueHandle],
typeAliases: Map[String, WomType]
)(implicit
expressionTypeEvaluator: TypeEvaluator[ExpressionElement]
): ErrorOr[WomType] = {
val structDefinition = typeAliases.get(a.structTypeName)
THWiseman marked this conversation as resolved.
Show resolved Hide resolved
structDefinition match {
case Some(definition) =>
definition match {
case compositeType: WomCompositeType =>
checkMembersAgainstDefinition(a, compositeType, linkedValues, typeAliases).flatMap { foundMembers =>
checkForMissingMembers(foundMembers.typeMap, compositeType) match {
case Some(error) => error.invalidNel
case _ => compositeType.validNel
}
}
case _ => s"Unexpected error while parsing ${a.structTypeName}".invalidNel
THWiseman marked this conversation as resolved.
Show resolved Hide resolved
}
case None => s"Could not find Struct Definition for type ${a.structTypeName}".invalidNel
}
}
}
}
Loading
Loading