Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewparmet committed May 13, 2024
1 parent 2b78bba commit 048c965
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,12 @@ import build.buf.protovalidate.internal.celext.ValidateLibrary
import build.buf.protovalidate.internal.evaluator.Evaluator
import build.buf.protovalidate.internal.evaluator.EvaluatorBuilder
import build.buf.protovalidate.internal.evaluator.MessageValue
import com.google.protobuf.DescriptorProtos
import com.google.protobuf.Descriptors
import com.google.protobuf.Descriptors.Descriptor
import org.projectnessie.cel.Env
import org.projectnessie.cel.Library
import protokt.v1.Beta
import protokt.v1.GeneratedMessage
import protokt.v1.Message
import protokt.v1.google.protobuf.FileDescriptor
import protokt.v1.google.protobuf.RuntimeContext
import protokt.v1.google.protobuf.toDynamicMessage
import java.util.concurrent.ConcurrentHashMap
Expand All @@ -50,20 +47,6 @@ class ProtoktValidator @JvmOverloads constructor(
private val evaluatorsByFullTypeName = ConcurrentHashMap<String, Evaluator>()
private val runtimeContext = RuntimeContext(emptyList())

fun load(descriptor: FileDescriptor) {
descriptor
.toProtobufJavaDescriptor()
.messageTypes
.forEach(::load)
}

private fun FileDescriptor.toProtobufJavaDescriptor(): Descriptors.FileDescriptor =
Descriptors.FileDescriptor.buildFrom(
DescriptorProtos.FileDescriptorProto.parseFrom(proto.serialize()),
dependencies.map { it.toProtobufJavaDescriptor() }.toTypedArray(),
true
)

fun load(descriptor: Descriptor) {
runtimeContext.add(descriptor)
evaluatorsByFullTypeName[descriptor.fullName] = evaluatorBuilder.load(descriptor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import buf.validate.conformance.cases.Strings
import build.buf.protovalidate.ValidationResult
import com.google.common.truth.Truth.assertThat
import com.google.protobuf.ByteString
import com.google.protobuf.DescriptorProtos
import com.google.protobuf.Descriptors
import org.junit.jupiter.api.Test
import protokt.v1.AbstractDeserializer
import protokt.v1.AbstractMessage
Expand All @@ -41,15 +43,32 @@ import protokt.v1.buf.validate.conformance.cases.numbers_file_descriptor
import protokt.v1.buf.validate.conformance.cases.oneofs_file_descriptor
import protokt.v1.buf.validate.conformance.cases.repeated_file_descriptor
import protokt.v1.buf.validate.conformance.cases.strings_file_descriptor
import protokt.v1.google.protobuf.FileDescriptor

abstract class AbstractProtoktValidatorTest {
protected val validator = ProtoktValidator()

abstract fun validate(message: Message): ValidationResult

private fun load(descriptor: FileDescriptor) {
descriptor
.toProtobufJavaDescriptor()
.messageTypes
.forEach {
runCatching { validator.load(it) }
}
}

private fun FileDescriptor.toProtobufJavaDescriptor(): Descriptors.FileDescriptor =
Descriptors.FileDescriptor.buildFrom(
DescriptorProtos.FileDescriptorProto.parseFrom(proto.serialize()),
dependencies.map { it.toProtobufJavaDescriptor() }.toTypedArray(),
true
)

@Test
fun `test required oneof constraint`() {
validator.load(messages_file_descriptor.descriptor)
load(messages_file_descriptor.descriptor)

val result =
validate(
Expand All @@ -58,9 +77,9 @@ abstract class AbstractProtoktValidatorTest {
MessageRequiredOneof.One.Val(
TestMsg {
const = "foo"
},
}
)
},
}
)

assertThat(result.violations).isEmpty()
Expand All @@ -69,13 +88,13 @@ abstract class AbstractProtoktValidatorTest {

@Test
fun `test oneof constraint`() {
validator.load(oneofs_file_descriptor.descriptor)
load(oneofs_file_descriptor.descriptor)

val result =
validate(
Oneof {
o = Oneof.O.X("foobar")
},
}
)

assertThat(result.violations).isEmpty()
Expand All @@ -84,21 +103,21 @@ abstract class AbstractProtoktValidatorTest {

@Test
fun `test uint64 in constraint`() {
validator.load(numbers_file_descriptor.descriptor)
load(numbers_file_descriptor.descriptor)

val result =
validate(
UInt64In {
`val` = 4u
},
}
)

assertThat(result.isSuccess).isFalse()
}

@Test
fun `test message with varint non-uint64 encoded purely as unknown fields (dynamic message without a dedicated type)`() {
validator.load(numbers_file_descriptor.descriptor)
load(numbers_file_descriptor.descriptor)

val result =
validate(
Expand All @@ -107,8 +126,8 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.setVal(4)
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result.isSuccess).isFalse()
Expand All @@ -120,16 +139,16 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.setVal(3)
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result2.isSuccess).isTrue()
}

@Test
fun `test message with varint uint64 encoded purely as unknown fields (dynamic message without a dedicated type)`() {
validator.load(numbers_file_descriptor.descriptor)
load(numbers_file_descriptor.descriptor)

val result =
validate(
Expand All @@ -138,8 +157,8 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.setVal(4)
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result.isSuccess).isFalse()
Expand All @@ -151,16 +170,16 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.setVal(3)
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result2.isSuccess).isTrue()
}

@Test
fun `test message with fixed32 encoded purely as unknown fields (dynamic message without a dedicated type)`() {
validator.load(numbers_file_descriptor.descriptor)
load(numbers_file_descriptor.descriptor)

val result =
validate(
Expand All @@ -169,8 +188,8 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.setVal(4)
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result.isSuccess).isFalse()
Expand All @@ -182,16 +201,16 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.setVal(3)
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result2.isSuccess).isTrue()
}

@Test
fun `test message with fixed64 encoded purely as unknown fields (dynamic message without a dedicated type)`() {
validator.load(numbers_file_descriptor.descriptor)
load(numbers_file_descriptor.descriptor)

val result =
validate(
Expand All @@ -200,8 +219,8 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.setVal(4)
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result.isSuccess).isFalse()
Expand All @@ -213,16 +232,16 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.setVal(3)
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result2.isSuccess).isTrue()
}

@Test
fun `test message with length delimited string encoded purely as unknown fields (dynamic message without a dedicated type)`() {
validator.load(strings_file_descriptor.descriptor)
load(strings_file_descriptor.descriptor)

val result =
validate(
Expand All @@ -231,8 +250,8 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.setVal("foo")
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result.isSuccess).isFalse()
Expand All @@ -244,16 +263,16 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.setVal("bar")
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result2.isSuccess).isTrue()
}

@Test
fun `test message with length delimited bytes encoded purely as unknown fields (dynamic message without a dedicated type)`() {
validator.load(bytes_file_descriptor.descriptor)
load(bytes_file_descriptor.descriptor)

val result =
validate(
Expand All @@ -262,8 +281,8 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.setVal(ByteString.copyFromUtf8("foo"))
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result.isSuccess).isFalse()
Expand All @@ -275,16 +294,16 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.setVal(ByteString.copyFromUtf8("bar"))
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result2.isSuccess).isTrue()
}

@Test
fun `test message with repeated values encoded purely as unknown fields (dynamic message without a dedicated type)`() {
validator.load(repeated_file_descriptor.descriptor)
load(repeated_file_descriptor.descriptor)

val result =
validate(
Expand All @@ -293,8 +312,8 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.addAllVal(listOf("foo", "foo"))
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result.isSuccess).isFalse()
Expand All @@ -306,8 +325,8 @@ abstract class AbstractProtoktValidatorTest {
.newBuilder()
.addAllVal(listOf("foo", "bar"))
.build()
.toByteArray(),
),
.toByteArray()
)
)

assertThat(result2.isSuccess).isTrue()
Expand Down

0 comments on commit 048c965

Please sign in to comment.