Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

W-14608042: parse root types as types #1915

Merged
merged 7 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1285,8 +1285,8 @@ object APIRawValidations extends CommonValidationDefinitions {
),
AMFValidation(
uri = amfParser("reserved-endpoints"),
owlClass = apiContract("EndPoint"),
owlProperty = apiContract("path"),
owlClass = apiContract("WebAPI"),
owlProperty = apiContract("EndPoints"),
constraint = shape("reservedEndpoints"),
message = "Endpoint is reserved by Federation"
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package amf.apicontract.internal.validation.shacl

import amf.apicontract.client.scala.model.domain.{EndPoint, Request}
import amf.apicontract.client.scala.model.domain.api.Api
import amf.apicontract.client.scala.model.domain.api.{Api, WebApi}
import amf.apicontract.client.scala.model.domain.security.{OAuth2Settings, OpenIdConnectSettings, SecurityScheme}
import amf.apicontract.internal.metamodel.domain._
import amf.apicontract.internal.metamodel.domain.api.BaseApiModel
Expand Down Expand Up @@ -88,28 +88,31 @@ object APICustomShaclFunctions extends BaseCustomShaclFunctions {
new CustomShaclFunction {
override val name: String = "reservedEndpoints"
override def run(element: AmfObject, validate: Option[ValidationInfo] => Unit): Unit = {
val reserved = Set("_service", "_entities")
val endpoint = element.asInstanceOf[EndPoint]
endpoint.path
.option()
.map(_.stripPrefix("/query/").stripPrefix("/mutation/").stripPrefix("/subscription/"))
.flatMap {
case path if reserved.contains(path) =>
val rootKind = {
val name = endpoint.name.value()
val end = name.indexOf(".")
name.substring(0, end)
}
Some(
ValidationInfo(
EndPointModel.Path,
Some(s"Cannot declare field '$path' in type $rootKind since it is reserved by Federation"),
Some(element.annotations)
val reserved = Set("_service", "_entities")
val api = element.asInstanceOf[WebApi]
val endpoints = api.endPoints
endpoints.foreach { endpoint =>
endpoint.path
.option()
.map(_.stripPrefix("/query/").stripPrefix("/mutation/").stripPrefix("/subscription/"))
.flatMap {
case path if reserved.contains(path) =>
val rootKind = {
val name = endpoint.name.value()
val end = name.indexOf(".")
name.substring(0, end)
}
Some(
ValidationInfo(
EndPointModel.Path,
Some(s"Cannot declare field '$path' in type $rootKind since it is reserved by Federation"),
Some(element.annotations)
)
)
)
case _ => None
}
.foreach(res => validate(Some(res)))
case _ => None
}
.foreach(res => validate(Some(res)))
}
}
},
new CustomShaclFunction {
Expand Down Expand Up @@ -468,9 +471,12 @@ object APICustomShaclFunctions extends BaseCustomShaclFunctions {
element match {
case d: CustomDomainProperty =>
if (hasIntrospectionName(d)) validate(Some(ValidationInfo(CustomDomainPropertyModel.Name)))
case t: Shape => if (hasIntrospectionName(t)) validate(Some(ValidationInfo(AnyShapeModel.Name)))
case t: Shape =>
if (hasIntrospectionName(t)) validate(Some(ValidationInfo(AnyShapeModel.Name)))
case n: NamedDomainElement =>
if (hasIntrospectionName(n)) validate(Some(ValidationInfo(NameFieldSchema.Name)))
if (hasIntrospectionName(n)) {
validate(Some(ValidationInfo(NameFieldSchema.Name)))
}
case _ => // ignore
}
}
Expand Down Expand Up @@ -764,7 +770,7 @@ object APICustomShaclFunctions extends BaseCustomShaclFunctions {

// Obtained from the BNF in: https://tools.ietf.org/html/rfc7230#section-3.2
private def isInvalidHttpHeaderName(name: String): Boolean =
!name.matches("^[!#$%&'*\\+\\-\\.^\\_\\`\\|\\~0-9a-zA-Z]+$")
!name.matches("^[!#$%&'*+\\-.^_`|~0-9a-zA-Z]+$")

private def hasIntrospectionName(element: NamedDomainElement): Boolean =
element.name.nonNull && element.name.value().startsWith("__")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@ object GraphQLDirectiveLocationValidator {
result match {
case Some((actual, false)) =>
val message = buildErrorMessage(directiveApplication, element, actual.name)
Some(ValidationInfo(DomainElementModel.CustomDomainProperties, Some(message), Some(directiveApplication.annotations)))
Some(
ValidationInfo(
DomainElementModel.CustomDomainProperties,
Some(message),
Some(directiveApplication.annotations)
)
)
case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package amf.apicontract.internal.validation.shacl.graphql

import amf.apicontract.internal.metamodel.domain.EndPointModel
import amf.apicontract.internal.metamodel.domain.api.WebApiModel
import amf.core.client.scala.model.document.FieldsFilter
import amf.core.client.scala.model.domain.AmfElement
import amf.core.internal.metamodel.document.DocumentModel
import amf.core.internal.parser.domain.{FieldEntry, Fields}

/** Scope does not include external references (like FieldsFilter.Local) and also removes endpoints to avoid validating
* them twice in graphql (because they are also parsed as types)
*/
object GraphQLFieldsFilter extends FieldsFilter {

override def filter(fields: Fields): List[AmfElement] =
fields
.fields()
.filter(_.field != DocumentModel.References) // remove external refs
.filter(_.field != WebApiModel.EndPoints) // remove endpoints
.map(_.element)
.toList
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package amf.apicontract.internal.validation.shacl.graphql

import amf.core.client.scala.model.domain.AmfElement
import amf.core.client.scala.traversal.iterator.{
AmfIterator,
DomainElementIterator,
IdCollector,
IteratorStrategy,
VisitedCollector
}

object GraphQLIteratorStrategy extends IteratorStrategy {
override def iterator(elements: List[AmfElement], visited: VisitedCollector = IdCollector()): AmfIterator =
DomainElementIterator.withFilter(elements, visited, GraphQLFieldsFilter)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package amf.apicontract.internal.validation.shacl.graphql

import amf.apicontract.internal.validation.plugin.BaseApiValidationPlugin
import amf.apicontract.internal.validation.shacl.APICustomShaclFunctions
import amf.core.client.common.validation.{ProfileName, ProfileNames}
import amf.core.client.common.{HighPriority, PluginPriority}
import amf.core.client.scala.model.document.BaseUnit
import amf.core.client.scala.validation.AMFValidationReport
import amf.core.internal.plugins.validation.ValidationOptions
import amf.shapes.internal.validation.shacl.BaseShaclModelValidationPlugin
import amf.validation.internal.shacl.custom.CustomShaclValidator
import amf.validation.internal.shacl.custom.CustomShaclValidator.CustomShaclFunctions

import scala.concurrent.{ExecutionContext, Future}

case class GraphQLShaclModelValidationPlugin(override val profile: ProfileName = ProfileNames.GRAPHQL)
extends BaseShaclModelValidationPlugin
with BaseApiValidationPlugin {

override val id: String = this.getClass.getSimpleName

override def priority: PluginPriority = HighPriority

override protected def getValidator: CustomShaclValidator = {
new CustomShaclValidator(
functions,
profile.messageStyle,
strategy = GraphQLIteratorStrategy
)
}

override protected def specificValidate(unit: BaseUnit, options: ValidationOptions)(implicit
executionContext: ExecutionContext
): Future[AMFValidationReport] = Future(validateWithShacl(unit, options: ValidationOptions))

override protected val functions: CustomShaclFunctions = APICustomShaclFunctions.functions
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package amf.apicontract.internal.validation.shacl.graphql

import amf.core.client.scala.model.domain.Shape
import amf.core.client.scala.model.domain.{AmfObject, Shape}
import amf.shapes.client.scala.model.domain._

import scala.annotation.tailrec
Expand All @@ -16,6 +16,7 @@ object GraphQLUtils {
}
}

@tailrec
def isValidInputType(schema: Shape): Boolean = {
schema match {
case a: ArrayShape => isValidInputType(a.items)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import scala.annotation.tailrec

object GraphQLValidator {

def checkValidPath(path: Seq[PropertyShape], key: Key): Seq[ValidationInfo] = {
private def checkValidPath(path: Seq[PropertyShape], key: Key): Seq[ValidationInfo] = {
path.flatMap { propertyShape =>
propertyShape.range match {
case n: NodeShape =>
Expand Down Expand Up @@ -69,7 +69,7 @@ object GraphQLValidator {
}
}

case class RequiredField(interface: String, field: GraphQLField)
private case class RequiredField(interface: String, field: GraphQLField)

def validateRequiredFields(obj: GraphQLObject): Seq[ValidationInfo] = {

Expand Down Expand Up @@ -216,15 +216,8 @@ object GraphQLValidator {
} else None
}

val operationValidations = obj.operations.flatMap { op =>
if (!op.isValidOutputType) {
validationInfo(
NodeShapeModel.Properties,
s"Type '${getShapeName(op.payload.get.schema)}' from field '${op.name}' must be an output type",
op.annotations
)
} else None
}
val operations = obj.operations
val operationValidations = validateOperations(operations)

operationValidations ++ propertiesValidations
} else {
Expand All @@ -234,7 +227,12 @@ object GraphQLValidator {
}

def validateOutputTypes(endpoint: GraphQLEndpoint): Seq[ValidationInfo] = {
endpoint.operations.flatMap { op =>
val operations = endpoint.operations
validateOperations(operations)
}

private def validateOperations(operations: Seq[GraphQLOperation]): Seq[ValidationInfo] = {
operations.flatMap { op =>
if (!op.isValidOutputType) {
validationInfo(
NodeShapeModel.Properties,
Expand All @@ -247,20 +245,18 @@ object GraphQLValidator {

def validateInputTypes(obj: GraphQLObject): Seq[ValidationInfo] = {
// fields arguments cannot be output types
val operationValidations = obj.operations.flatMap { op =>
op.parameters.flatMap { param =>
if (!param.isValidInputType) {
validationInfo(
NodeShapeModel.Properties,
s"Type '${getShapeName(param.schema)}' from argument '${param.name}' must be an input type",
param.annotations
)
} else None
}
}
val parameters: Seq[GraphQLParameter] = obj.operations.flatMap(_.parameters)
val operationValidations = validateParameters(parameters)

// input type fields or directive arguments cannot be output types
val propertiesValidations = obj.properties.flatMap { prop =>
val properties: Seq[GraphQLProperty] = obj.properties
val propertiesValidations = validateProperties(properties, obj)

operationValidations ++ propertiesValidations
}

private def validateProperties(properties: Seq[GraphQLProperty], obj: GraphQLObject): Seq[ValidationInfo] = {
properties.flatMap { prop =>
if (!prop.isValidInputType && obj.isInput) {
val message = s"Type '${getShapeName(prop.range)}' from field '${prop.name}' must be an input type"
validationInfo(
Expand All @@ -270,12 +266,14 @@ object GraphQLValidator {
)
} else None
}

operationValidations ++ propertiesValidations
}

def validateInputTypes(endpoint: GraphQLEndpoint): Seq[ValidationInfo] = {
endpoint.parameters.flatMap { param =>
validateParameters(endpoint.parameters)
}

private def validateParameters(parameters: Seq[GraphQLParameter]): Seq[ValidationInfo] = {
parameters.flatMap { param =>
if (!param.isValidInputType) {
validationInfo(
NodeShapeModel.Properties,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ object EnumValueValidator extends ValueValidator[ScalarShape] {
case o: ObjectNode => Seq(typeError("scalar", "object", o.annotations))
}
}

private def validateDataType(value: ScalarNode)(implicit targetField: Field): Seq[ValidationInfo] = {
value.dataType.value() match {
case DataTypes.Any => Nil // enum values are 'Any' explicitly
case otherDT => Seq(typeError("enum", friendlyName(otherDT), value.annotations))
}
}

private def validateValueIsMember(shape: ScalarShape, value: ScalarNode)(implicit targetField: Field): Seq[ValidationInfo] = {
private def validateValueIsMember(shape: ScalarShape, value: ScalarNode)(implicit
targetField: Field
): Seq[ValidationInfo] = {
val acceptedValues = shape.values
val actualValue = value.value.value()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ object NullableValueValidator extends ValueValidator[UnionShape] {
}
}

private def validateNonNullValue(shape: UnionShape, other: DataNode)(implicit targetField: Field): Seq[ValidationInfo] = {
private def validateNonNullValue(shape: UnionShape, other: DataNode)(implicit
targetField: Field
): Seq[ValidationInfo] = {
val concreteShape = shape.anyOf.filter(!_.isInstanceOf[NilShape]).head
ValueValidator.validate(concreteShape, other)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ object ObjectValueValidator extends ValueValidator[NodeShape] {

sealed case class ReportingInfo(shapeName: String, annotations: Annotations)

private def validateProperties(shape: NodeShape, value: ObjectNode)(implicit targetField: Field): Seq[ValidationInfo] = {
private def validateProperties(shape: NodeShape, value: ObjectNode)(implicit
targetField: Field
): Seq[ValidationInfo] = {
val actual: Map[String, DataNode] = value.allPropertiesWithName()
val expected = shape.properties
implicit val info: ReportingInfo = ReportingInfo(shape.name.value(), value.annotations)
Expand All @@ -29,7 +31,8 @@ object ObjectValueValidator extends ValueValidator[NodeShape] {
}

private def validateExpectedProperties(expected: Seq[PropertyShape], actual: Map[String, DataNode])(implicit
info: ReportingInfo, targetField: Field
info: ReportingInfo,
targetField: Field
): Seq[ValidationInfo] = {
expected.flatMap { expectedProperty => validateExpectedProperty(expectedProperty, actual) }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ object ScalarValueValidator extends ValueValidator[ScalarShape] {
}
}

private def validateDataType(shape: ScalarShape, value: ScalarNode)(implicit targetField: Field): Seq[ValidationInfo] = {
private def validateDataType(shape: ScalarShape, value: ScalarNode)(implicit
targetField: Field
): Seq[ValidationInfo] = {
val shapeDT = shape.dataType.value()
val valueDT = value.dataType.value()
shapeDT match {
Expand Down
Loading