From 7b91cbe90c157d25eb48abc016f7f8df197fbc7d Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 19 Oct 2023 10:04:27 -0700 Subject: [PATCH] Internal Changes PiperOrigin-RevId: 574911225 --- .../main/java/dev/cel/bundle/CelBuilder.java | 8 + .../src/main/java/dev/cel/bundle/CelImpl.java | 8 + .../java/dev/cel/common/internal/BUILD.bazel | 7 +- .../internal/CombinedDescriptorPool.java | 3 +- .../internal/DefaultDescriptorPool.java | 24 ++- .../internal/DefaultMessageFactory.java | 4 + .../dev/cel/common/internal/DynamicProto.java | 142 +++--------------- .../dev/cel/common/internal/ProtoAdapter.java | 47 ------ .../cel/common/internal/ProtoEquality.java | 17 +-- .../internal/ProtoRegistryProvider.java | 38 ----- .../internal/DefaultMessageFactoryTest.java | 8 +- .../cel/common/internal/DynamicProtoTest.java | 88 ++++------- .../cel/common/internal/ProtoAdapterTest.java | 44 +++--- .../common/internal/ProtoEqualityTest.java | 6 +- .../extensions/CelProtoExtensionsTest.java | 27 ++++ .../main/java/dev/cel/runtime/Activation.java | 5 +- .../src/main/java/dev/cel/runtime/BUILD.bazel | 8 + .../dev/cel/runtime/CelRuntimeBuilder.java | 8 + .../dev/cel/runtime/CelRuntimeLegacyImpl.java | 89 ++++++----- .../dev/cel/runtime/DefaultDispatcher.java | 4 +- .../runtime/DescriptorMessageProvider.java | 48 +++--- .../cel/runtime/DynamicMessageFactory.java | 43 +++--- .../java/dev/cel/runtime/MessageFactory.java | 24 ++- .../java/dev/cel/runtime/RuntimeHelpers.java | 6 - .../dev/cel/runtime/StandardFunctions.java | 6 +- .../src/test/java/dev/cel/runtime/BUILD.bazel | 5 + .../java/dev/cel/runtime/CelRuntimeTest.java | 99 ++++++++++++ .../DescriptorMessageProviderTest.java | 92 +++++++++--- .../dev/cel/runtime/RuntimeEqualityTest.java | 12 +- .../dev/cel/runtime/RuntimeHelpersTest.java | 51 +++++-- .../src/main/java/dev/cel/testing/BUILD.bazel | 5 +- .../dev/cel/testing/BaseInterpreterTest.java | 4 +- .../main/java/dev/cel/testing/EvalSync.java | 8 +- 33 files changed, 524 insertions(+), 464 deletions(-) delete mode 100644 common/src/main/java/dev/cel/common/internal/ProtoRegistryProvider.java diff --git a/bundle/src/main/java/dev/cel/bundle/CelBuilder.java b/bundle/src/main/java/dev/cel/bundle/CelBuilder.java index 2e4b535ae..7857d4140 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelBuilder.java +++ b/bundle/src/main/java/dev/cel/bundle/CelBuilder.java @@ -20,6 +20,7 @@ import com.google.protobuf.DescriptorProtos.FileDescriptorSet; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import dev.cel.checker.ProtoTypeMask; import dev.cel.checker.TypeProvider; @@ -282,6 +283,13 @@ public interface CelBuilder { @CanIgnoreReturnValue CelBuilder addRuntimeLibraries(Iterable libraries); + /** + * Sets a proto ExtensionRegistry to assist with unpacking Any messages containing a proto2 + extension field. + */ + @CanIgnoreReturnValue + CelBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry); + /** Construct a new {@code Cel} instance from the provided configuration. */ Cel build(); } diff --git a/bundle/src/main/java/dev/cel/bundle/CelImpl.java b/bundle/src/main/java/dev/cel/bundle/CelImpl.java index 81680c86a..3d65ada89 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelImpl.java +++ b/bundle/src/main/java/dev/cel/bundle/CelImpl.java @@ -25,6 +25,7 @@ import com.google.protobuf.DescriptorProtos.FileDescriptorSet; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import dev.cel.checker.CelCheckerBuilder; import dev.cel.checker.ProtoTypeMask; @@ -339,6 +340,13 @@ public Builder addRuntimeLibraries(Iterable libraries) { return this; } + @Override + public CelBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry) { + checkNotNull(extensionRegistry); + runtimeBuilder.setExtensionRegistry(extensionRegistry); + return this; + } + @Override public Cel build() { return new CelImpl( diff --git a/common/src/main/java/dev/cel/common/internal/BUILD.bazel b/common/src/main/java/dev/cel/common/internal/BUILD.bazel index 104a82279..8219aa73e 100644 --- a/common/src/main/java/dev/cel/common/internal/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/internal/BUILD.bazel @@ -91,11 +91,9 @@ java_library( ], ) -# keep sorted DYNAMIC_PROTO_SOURCES = [ "DynamicProto.java", "ProtoAdapter.java", - "ProtoRegistryProvider.java", ] java_library( @@ -119,13 +117,12 @@ java_library( ], deps = [ ":converter", - ":default_instance_message_factory", + ":proto_message_factory", + ":well_known_proto", "//:auto_value", - "//common", "//common:error_codes", "//common:runtime_exception", "//common/annotations", - "//common/types:cel_types", "@cel_spec//proto/cel/expr:expr_java_proto", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", diff --git a/common/src/main/java/dev/cel/common/internal/CombinedDescriptorPool.java b/common/src/main/java/dev/cel/common/internal/CombinedDescriptorPool.java index 975e5bdbf..6227b9e7b 100644 --- a/common/src/main/java/dev/cel/common/internal/CombinedDescriptorPool.java +++ b/common/src/main/java/dev/cel/common/internal/CombinedDescriptorPool.java @@ -74,7 +74,8 @@ public ExtensionRegistry getExtensionRegistry() { private CombinedDescriptorPool(ImmutableList descriptorPools) { this.descriptorPools = descriptorPools; // TODO: Combine the extension registry. This will become necessary once we accept - // ExtensionRegistry through runtime builder. + // ExtensionRegistry through runtime builder. Ideally, proto team should open source this + // implementation but we may have to create our own. this.extensionRegistry = descriptorPools.stream() .map(CelDescriptorPool::getExtensionRegistry) diff --git a/common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java b/common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java index 66dbf2e1e..fc703c905 100644 --- a/common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java +++ b/common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java @@ -46,7 +46,10 @@ public final class DefaultDescriptorPool implements CelDescriptorPool { /** A DefaultDescriptorPool instance with just well known types loaded. */ public static final DefaultDescriptorPool INSTANCE = - new DefaultDescriptorPool(WELL_KNOWN_TYPE_DESCRIPTORS, ImmutableMultimap.of()); + new DefaultDescriptorPool( + WELL_KNOWN_TYPE_DESCRIPTORS, + ImmutableMultimap.of(), + ExtensionRegistry.getEmptyRegistry()); // K: Fully qualified message type name, V: Message descriptor private final ImmutableMap descriptorMap; @@ -55,7 +58,15 @@ public final class DefaultDescriptorPool implements CelDescriptorPool { // V: Field descriptor for the extension message private final ImmutableMultimap extensionDescriptorMap; + @SuppressWarnings("Immutable") // ExtensionRegistry is immutable, just not marked as such. + private final ExtensionRegistry extensionRegistry; + public static DefaultDescriptorPool create(CelDescriptors celDescriptors) { + return create(celDescriptors, ExtensionRegistry.getEmptyRegistry()); + } + + public static DefaultDescriptorPool create( + CelDescriptors celDescriptors, ExtensionRegistry extensionRegistry) { Map descriptorMap = new HashMap<>(); // Using a hashmap to allow deduping stream(WellKnownProto.values()).forEach(d -> descriptorMap.put(d.typeName(), d.descriptor())); @@ -64,7 +75,9 @@ public static DefaultDescriptorPool create(CelDescriptors celDescriptors) { } return new DefaultDescriptorPool( - ImmutableMap.copyOf(descriptorMap), celDescriptors.extensionDescriptors()); + ImmutableMap.copyOf(descriptorMap), + celDescriptors.extensionDescriptors(), + extensionRegistry); } @Override @@ -83,14 +96,15 @@ public Optional findExtensionDescriptor( @Override public ExtensionRegistry getExtensionRegistry() { - // TODO: Populate one from runtime builder. - return ExtensionRegistry.getEmptyRegistry(); + return extensionRegistry; } private DefaultDescriptorPool( ImmutableMap descriptorMap, - ImmutableMultimap extensionDescriptorMap) { + ImmutableMultimap extensionDescriptorMap, + ExtensionRegistry extensionRegistry) { this.descriptorMap = checkNotNull(descriptorMap); this.extensionDescriptorMap = checkNotNull(extensionDescriptorMap); + this.extensionRegistry = checkNotNull(extensionRegistry); } } diff --git a/common/src/main/java/dev/cel/common/internal/DefaultMessageFactory.java b/common/src/main/java/dev/cel/common/internal/DefaultMessageFactory.java index 381a5de9e..4a021cd90 100644 --- a/common/src/main/java/dev/cel/common/internal/DefaultMessageFactory.java +++ b/common/src/main/java/dev/cel/common/internal/DefaultMessageFactory.java @@ -23,6 +23,10 @@ /** DefaultMessageFactory produces {@link Message.Builder} instances by protobuf name. */ @Internal public final class DefaultMessageFactory implements ProtoMessageFactory { + + /** A default message factory instance that can construct well known typed messages. */ + public static final DefaultMessageFactory INSTANCE = create(DefaultDescriptorPool.INSTANCE); + private final CelDescriptorPool celDescriptorPool; public static DefaultMessageFactory create(CelDescriptorPool celDescriptorPool) { diff --git a/common/src/main/java/dev/cel/common/internal/DynamicProto.java b/common/src/main/java/dev/cel/common/internal/DynamicProto.java index 755a990e6..0cb9eb328 100644 --- a/common/src/main/java/dev/cel/common/internal/DynamicProto.java +++ b/common/src/main/java/dev/cel/common/internal/DynamicProto.java @@ -15,29 +15,16 @@ package dev.cel.common.internal; import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static java.util.Arrays.stream; -import com.google.auto.value.AutoBuilder; -import com.google.common.collect.ImmutableCollection; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; import com.google.errorprone.annotations.CheckReturnValue; import com.google.errorprone.annotations.Immutable; import com.google.protobuf.Any; import com.google.protobuf.ByteString; -import com.google.protobuf.Descriptors.Descriptor; -import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.DynamicMessage; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; -import dev.cel.common.CelDescriptorUtil; -import dev.cel.common.CelDescriptors; import dev.cel.common.annotations.Internal; -import dev.cel.common.types.CelTypes; -import java.util.Map.Entry; import java.util.Optional; -import org.jspecify.nullness.Nullable; /** * The {@code DynamicProto} class supports the conversion of {@link Any} values to concrete {@code @@ -49,59 +36,30 @@ @CheckReturnValue @Internal public final class DynamicProto { - - private static final ImmutableMap WELL_KNOWN_DESCRIPTORS = - stream(ProtoAdapter.WellKnownProto.values()) - .collect(toImmutableMap(d -> d.typeName(), d -> d.descriptor())); - - private final ImmutableMap dynamicDescriptors; - private final ImmutableMultimap dynamicExtensionDescriptors; private final ProtoMessageFactory protoMessageFactory; - /** {@code ProtoMessageFactory} provides a method to create a protobuf builder objects by name. */ - @Immutable - @FunctionalInterface - public interface ProtoMessageFactory { - Message.@Nullable Builder newBuilder(String messageName); + public static DynamicProto create(ProtoMessageFactory protoMessageFactory) { + return new DynamicProto(protoMessageFactory); } - /** Builder for configuring the {@link DynamicProto}. */ - @AutoBuilder(ofClass = DynamicProto.class) - public abstract static class Builder { - - /** Sets {@link CelDescriptors} to unpack any message types. */ - public abstract Builder setDynamicDescriptors(CelDescriptors celDescriptors); - - /** Sets a custom type factory to unpack any message types. */ - public abstract Builder setProtoMessageFactory(ProtoMessageFactory factory); - - /** Builds a new instance of {@link DynamicProto} */ - @CheckReturnValue - public abstract DynamicProto build(); - } - - public static Builder newBuilder() { - return new AutoBuilder_DynamicProto_Builder() - .setDynamicDescriptors(CelDescriptors.builder().build()) - .setProtoMessageFactory((typeName) -> null); + DynamicProto(ProtoMessageFactory protoMessageFactory) { + this.protoMessageFactory = checkNotNull(protoMessageFactory); } - DynamicProto( - CelDescriptors dynamicDescriptors, - ProtoMessageFactory protoMessageFactory) { - ImmutableMap messageTypeDescriptorMap = - CelDescriptorUtil.descriptorCollectionToMap(dynamicDescriptors.messageTypeDescriptors()); - ImmutableMap filteredDescriptors = - messageTypeDescriptorMap.entrySet().stream() - .filter(e -> !WELL_KNOWN_DESCRIPTORS.containsKey(e.getKey())) - .collect(toImmutableMap(Entry::getKey, Entry::getValue)); - this.dynamicDescriptors = - ImmutableMap.builder() - .putAll(WELL_KNOWN_DESCRIPTORS) - .putAll(filteredDescriptors) - .buildOrThrow(); - this.dynamicExtensionDescriptors = checkNotNull(dynamicDescriptors.extensionDescriptors()); - this.protoMessageFactory = checkNotNull(protoMessageFactory); + /** Attempts to unpack an Any message. */ + public Optional maybeUnpackAny(Message msg) { + try { + Any any = + msg instanceof Any + ? (Any) msg + : Any.parseFrom( + msg.toByteString(), + protoMessageFactory.getDescriptorPool().getExtensionRegistry()); + + return Optional.of(unpack(any)); + } catch (InvalidProtocolBufferException e) { + return Optional.empty(); + } } /** @@ -120,7 +78,8 @@ public Message unpack(Any any) throws InvalidProtocolBufferException { String.format("malformed type URL: %s", any.getTypeUrl()))); Message.Builder builder = - newMessageBuilder(messageTypeName) + protoMessageFactory + .newBuilder(messageTypeName) .orElseThrow( () -> new InvalidProtocolBufferException( @@ -136,7 +95,7 @@ public Message unpack(Any any) throws InvalidProtocolBufferException { */ public Message maybeAdaptDynamicMessage(DynamicMessage input) { Optional maybeBuilder = - newMessageBuilder(input.getDescriptorForType().getFullName()); + protoMessageFactory.newBuilder(input.getDescriptorForType().getFullName()); if (!maybeBuilder.isPresent() || maybeBuilder.get() instanceof DynamicMessage.Builder) { // Just return the same input if: // 1. We didn't get a builder back because there's no descriptor (nothing we can do) @@ -148,61 +107,6 @@ public Message maybeAdaptDynamicMessage(DynamicMessage input) { return merge(maybeBuilder.get(), input.toByteString()); } - /** - * This method instantiates a builder for the given {@code typeName} assuming one is configured - * within the descriptor set provided to the {@code DynamicProto} constructor. - * - *

When the {@code useLinkedTypes} flag is set, the {@code Message.Builder} returned will be - * the concrete builder instance linked into the binary if it is present; otherwise, the result - * will be a {@code DynamicMessageBuilder}. - */ - public Optional newMessageBuilder(String typeName) { - if (!CelTypes.isWellKnownType(typeName)) { - // Check if the message factory can produce a concrete message via custom type factory - // first. - Message.Builder builder = protoMessageFactory.newBuilder(typeName); - if (builder != null) { - return Optional.of(builder); - } - } - - Optional descriptor = maybeGetDescriptor(typeName); - if (!descriptor.isPresent()) { - return Optional.empty(); - } - // If the descriptor that's resolved does not match the descriptor instance in the message - // factory, the call to fetch the prototype will return null, and a dynamic proto message - // should be used as a fallback. - Optional message = - DefaultInstanceMessageFactory.getInstance().getPrototype(descriptor.get()); - if (message.isPresent()) { - return Optional.of(message.get().toBuilder()); - } - - // Fallback to a dynamic proto instance. - return Optional.of(DynamicMessage.newBuilder(descriptor.get())); - } - - private Optional maybeGetDescriptor(String typeName) { - - Descriptor descriptor = ProtoRegistryProvider.getTypeRegistry().find(typeName); - return Optional.ofNullable(descriptor != null ? descriptor : dynamicDescriptors.get(typeName)); - } - - /** Gets the corresponding field descriptor for an extension field on a message. */ - public Optional maybeGetExtensionDescriptor( - Descriptor containingDescriptor, String fieldName) { - - String typeName = containingDescriptor.getFullName(); - ImmutableCollection fieldDescriptors = - dynamicExtensionDescriptors.get(typeName); - if (fieldDescriptors.isEmpty()) { - return Optional.empty(); - } - - return fieldDescriptors.stream().filter(d -> d.getFullName().equals(fieldName)).findFirst(); - } - /** * Merge takes in a Message builder and merges another message bytes into the builder. Some * example usages are: @@ -214,7 +118,9 @@ public Optional maybeGetExtensionDescriptor( */ private Message merge(Message.Builder builder, ByteString inputBytes) { try { - return builder.mergeFrom(inputBytes, ProtoRegistryProvider.getExtensionRegistry()).build(); + return builder + .mergeFrom(inputBytes, protoMessageFactory.getDescriptorPool().getExtensionRegistry()) + .build(); } catch (InvalidProtocolBufferException e) { throw new AssertionError("Failed to merge input message into the message builder", e); } diff --git a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java index 822bce2b7..f66fe26a7 100644 --- a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java +++ b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java @@ -143,53 +143,6 @@ public final class ProtoAdapter { stream(WellKnownProto.values()) .collect(toImmutableMap(WellKnownProto::typeName, Function.identity())); - /** - * WellKnownProto types used throughout CEL. These types are specially handled to ensure that - * bidirectional conversion between CEL native values and these well-known types is performed - * consistently across runtimes. - */ - enum WellKnownProto { - JSON_VALUE(Value.getDescriptor()), - JSON_STRUCT_VALUE(Struct.getDescriptor()), - JSON_LIST_VALUE(ListValue.getDescriptor()), - ANY_VALUE(Any.getDescriptor()), - BOOL_VALUE(BoolValue.getDescriptor(), true), - BYTES_VALUE(BytesValue.getDescriptor(), true), - DOUBLE_VALUE(DoubleValue.getDescriptor(), true), - FLOAT_VALUE(FloatValue.getDescriptor(), true), - INT32_VALUE(Int32Value.getDescriptor(), true), - INT64_VALUE(Int64Value.getDescriptor(), true), - STRING_VALUE(StringValue.getDescriptor(), true), - UINT32_VALUE(UInt32Value.getDescriptor(), true), - UINT64_VALUE(UInt64Value.getDescriptor(), true), - DURATION_VALUE(Duration.getDescriptor()), - TIMESTAMP_VALUE(Timestamp.getDescriptor()); - - private final Descriptor descriptor; - private final boolean isWrapperType; - - WellKnownProto(Descriptor descriptor) { - this(descriptor, /* isWrapperType= */ false); - } - - WellKnownProto(Descriptor descriptor, boolean isWrapperType) { - this.descriptor = descriptor; - this.isWrapperType = isWrapperType; - } - - Descriptor descriptor() { - return descriptor; - } - - String typeName() { - return descriptor.getFullName(); - } - - boolean isWrapperType() { - return isWrapperType; - } - } - private final DynamicProto dynamicProto; private final boolean enableUnsignedLongs; diff --git a/common/src/main/java/dev/cel/common/internal/ProtoEquality.java b/common/src/main/java/dev/cel/common/internal/ProtoEquality.java index 0337894d5..57130b7b7 100644 --- a/common/src/main/java/dev/cel/common/internal/ProtoEquality.java +++ b/common/src/main/java/dev/cel/common/internal/ProtoEquality.java @@ -22,7 +22,6 @@ import com.google.protobuf.ByteString; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; -import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import dev.cel.common.annotations.Internal; import java.util.List; @@ -69,8 +68,8 @@ public boolean equals(Message message1, Message message2) { // Test whether the typeUrl values are the same, if not return false. // Use the dynamicProto.unpack(message1), dynamicProto.unpack(message2) // and assign the results to message1 and message2. - Optional unpackedAny1 = anyUnpack(message1); - Optional unpackedAny2 = anyUnpack(message2); + Optional unpackedAny1 = dynamicProto.maybeUnpackAny(message1); + Optional unpackedAny2 = dynamicProto.maybeUnpackAny(message2); if (unpackedAny1.isPresent() && unpackedAny2.isPresent()) { return equals(unpackedAny1.get(), unpackedAny2.get()); } @@ -165,18 +164,6 @@ private ByteString anyValue(Message msg) { return (ByteString) msg.getField(value); } - private Optional anyUnpack(Message msg) { - try { - Any any = - msg instanceof Any - ? (Any) msg - : Any.parseFrom(msg.toByteString(), ProtoRegistryProvider.getExtensionRegistry()); - return Optional.of(dynamicProto.unpack(any)); - } catch (InvalidProtocolBufferException e) { - return Optional.empty(); - } - } - private static ProtoMap protoMap(List entries) { ImmutableMap.Builder protoMap = ImmutableMap.builder(); FieldDescriptor keyField = null; diff --git a/common/src/main/java/dev/cel/common/internal/ProtoRegistryProvider.java b/common/src/main/java/dev/cel/common/internal/ProtoRegistryProvider.java deleted file mode 100644 index 72756f592..000000000 --- a/common/src/main/java/dev/cel/common/internal/ProtoRegistryProvider.java +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dev.cel.common.internal; - -import com.google.protobuf.ExtensionRegistry; -import com.google.protobuf.TypeRegistry; -import dev.cel.common.annotations.Internal; - -/** - * ProtoRegistryProvider provides Extension and Type registries for handling Protobuf.Any messages. - * - *

CEL Library Internals. Do Not Use. - */ -@Internal -public final class ProtoRegistryProvider { - - public static ExtensionRegistry getExtensionRegistry() { - return ExtensionRegistry.getEmptyRegistry(); - } - - static TypeRegistry getTypeRegistry() { - return TypeRegistry.getEmptyTypeRegistry(); - } - - private ProtoRegistryProvider() {} -} diff --git a/common/src/test/java/dev/cel/common/internal/DefaultMessageFactoryTest.java b/common/src/test/java/dev/cel/common/internal/DefaultMessageFactoryTest.java index 68f104685..593817d10 100644 --- a/common/src/test/java/dev/cel/common/internal/DefaultMessageFactoryTest.java +++ b/common/src/test/java/dev/cel/common/internal/DefaultMessageFactoryTest.java @@ -41,8 +41,7 @@ public final class DefaultMessageFactoryTest { @Test public void newBuilder_wellKnownType_producesNewMessage() { - DefaultMessageFactory messageFactory = - DefaultMessageFactory.create(DefaultDescriptorPool.INSTANCE); + DefaultMessageFactory messageFactory = DefaultMessageFactory.INSTANCE; Value.Builder valueBuilder = (Value.Builder) messageFactory.newBuilder("google.protobuf.Value").get(); @@ -69,8 +68,7 @@ public void newBuilder_withDescriptor_producesNewMessageBuilder() { @Test public void newBuilder_unknownMessage_returnsEmpty() { - DefaultMessageFactory messageFactory = - DefaultMessageFactory.create(DefaultDescriptorPool.INSTANCE); + DefaultMessageFactory messageFactory = DefaultMessageFactory.INSTANCE; assertThat(messageFactory.newBuilder("unknown_message")).isEmpty(); } @@ -108,7 +106,7 @@ public void combinedMessageFactoryTest() { CombinedMessageFactory messageFactory = new ProtoMessageFactory.CombinedMessageFactory( ImmutableList.of( - DefaultMessageFactory.create(DefaultDescriptorPool.INSTANCE), + DefaultMessageFactory.INSTANCE, (messageName) -> messageName.equals("test") ? Optional.of(TestAllTypes.newBuilder()) diff --git a/common/src/test/java/dev/cel/common/internal/DynamicProtoTest.java b/common/src/test/java/dev/cel/common/internal/DynamicProtoTest.java index d563d01a3..cc5ba5632 100644 --- a/common/src/test/java/dev/cel/common/internal/DynamicProtoTest.java +++ b/common/src/test/java/dev/cel/common/internal/DynamicProtoTest.java @@ -15,7 +15,6 @@ package dev.cel.common.internal; import static com.google.common.truth.Truth.assertThat; -import static com.google.common.truth.Truth8.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertThrows; @@ -55,7 +54,8 @@ public void unpackLinkedMessageType_withTypeRegistry() throws Exception { CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(Expr.getDescriptor().getFile()); DynamicProto dynamicProto = - DynamicProto.newBuilder().setDynamicDescriptors(celDescriptors).build(); + DynamicProto.create( + DefaultMessageFactory.create(DefaultDescriptorPool.create(celDescriptors))); Message unpacked = dynamicProto.unpack(packedExpr); @@ -74,7 +74,8 @@ public void unpackLinkedMessageType_withTypeRegistry_multiFileNested() throws Ex CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(MultiFile.getDescriptor().getFile()); DynamicProto dynamicProto = - DynamicProto.newBuilder().setDynamicDescriptors(celDescriptors).build(); + DynamicProto.create( + DefaultMessageFactory.create(DefaultDescriptorPool.create(celDescriptors))); Message unpacked = dynamicProto.unpack(packed); @@ -89,7 +90,8 @@ public void unpackLinkedMessageType_withTypeRegistry_singleFileNested() throws E CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(SingleFile.getDescriptor().getFile()); DynamicProto dynamicProto = - DynamicProto.newBuilder().setDynamicDescriptors(celDescriptors).build(); + DynamicProto.create( + DefaultMessageFactory.create(DefaultDescriptorPool.create(celDescriptors))); Message unpacked = dynamicProto.unpack(packed); @@ -105,7 +107,8 @@ public void unpackLinkedMessageType_withTypeRegistryCached() throws Exception { CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(Expr.getDescriptor().getFile()); DynamicProto dynamicProto = - DynamicProto.newBuilder().setDynamicDescriptors(celDescriptors).build(); + DynamicProto.create( + DefaultMessageFactory.create(DefaultDescriptorPool.create(celDescriptors))); // Order is important here. Message unpacked = dynamicProto.unpack(packedExpr); @@ -131,7 +134,8 @@ public void unpackLinkedMessageType_removeDescriptorLocalLinkedType() throws Exc CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptorBuilder.build()); DynamicProto dynamicProto = - DynamicProto.newBuilder().setDynamicDescriptors(celDescriptors).build(); + DynamicProto.create( + DefaultMessageFactory.create(DefaultDescriptorPool.create(celDescriptors))); Message unpacked = dynamicProto.unpack(packedStruct); @@ -141,9 +145,7 @@ public void unpackLinkedMessageType_removeDescriptorLocalLinkedType() throws Exc @Test public void unpackDynamicMessageType_noDescriptor() throws Exception { - DynamicProto dynamicProto = - DynamicProto.newBuilder() - .build(); + DynamicProto dynamicProto = DynamicProto.create(DefaultMessageFactory.INSTANCE); Any.Builder anyValue = Any.newBuilder(); TextFormat.merge(readFile("value.textproto"), anyValue); assertThat(anyValue.getTypeUrl()).isEqualTo("type.googleapis.com/google.api.expr.Value"); @@ -152,7 +154,7 @@ public void unpackDynamicMessageType_noDescriptor() throws Exception { @Test public void unpackDynamicMessageType_badDescriptor() throws Exception { - DynamicProto dynamicProto = DynamicProto.newBuilder().build(); + DynamicProto dynamicProto = DynamicProto.create(DefaultMessageFactory.INSTANCE); Any.Builder anyValue = Any.newBuilder(); TextFormat.merge(readFile("value.textproto"), anyValue); anyValue.setTypeUrl("google.api.expr.Value"); @@ -166,7 +168,8 @@ public void unpackDynamicMessageType() throws Exception { CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet(fds); CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(files); DynamicProto dynamicProto = - DynamicProto.newBuilder().setDynamicDescriptors(celDescriptors).build(); + DynamicProto.create( + DefaultMessageFactory.create(DefaultDescriptorPool.create(celDescriptors))); Any.Builder anyValue = Any.newBuilder(); TextFormat.merge(readFile("value.textproto"), anyValue); assertThat(anyValue.getTypeUrl()).isEqualTo("type.googleapis.com/google.api.expr.Value"); @@ -201,7 +204,8 @@ public void unpackDynamicMessageType_cached() throws Exception { CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet(fds); CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(files); DynamicProto dynamicProto = - DynamicProto.newBuilder().setDynamicDescriptors(celDescriptors).build(); + DynamicProto.create( + DefaultMessageFactory.create(DefaultDescriptorPool.create(celDescriptors))); Any.Builder anyValue = Any.newBuilder(); TextFormat.merge(readFile("value.textproto"), anyValue); assertThat(anyValue.getTypeUrl()).isEqualTo("type.googleapis.com/google.api.expr.Value"); @@ -212,7 +216,7 @@ public void unpackDynamicMessageType_cached() throws Exception { @Test public void maybeAdaptDynamicMessage() throws Exception { - DynamicProto dynamicProto = DynamicProto.newBuilder().build(); + DynamicProto dynamicProto = DynamicProto.create(DefaultMessageFactory.INSTANCE); Struct struct = Struct.newBuilder() .putFields("hello", Value.newBuilder().setStringValue("world").build()) @@ -220,7 +224,9 @@ public void maybeAdaptDynamicMessage() throws Exception { Any any = Any.pack(struct); DynamicMessage anyDyn = DynamicMessage.parseFrom( - Struct.getDescriptor(), any.getValue(), ProtoRegistryProvider.getExtensionRegistry()); + Struct.getDescriptor(), + any.getValue(), + DefaultDescriptorPool.INSTANCE.getExtensionRegistry()); Message adapted = dynamicProto.maybeAdaptDynamicMessage(anyDyn); assertThat(adapted).isEqualTo(struct); assertThat(adapted).isInstanceOf(Struct.class); @@ -228,7 +234,7 @@ public void maybeAdaptDynamicMessage() throws Exception { @Test public void maybeAdaptDynamicMessage_cached() throws Exception { - DynamicProto dynamicProto = DynamicProto.newBuilder().build(); + DynamicProto dynamicProto = DynamicProto.create(DefaultMessageFactory.INSTANCE); Struct struct = Struct.newBuilder() .putFields("hello", Value.newBuilder().setStringValue("world").build()) @@ -236,7 +242,9 @@ public void maybeAdaptDynamicMessage_cached() throws Exception { Any any = Any.pack(struct); DynamicMessage anyDyn = DynamicMessage.parseFrom( - Struct.getDescriptor(), any.getValue(), ProtoRegistryProvider.getExtensionRegistry()); + Struct.getDescriptor(), + any.getValue(), + DefaultDescriptorPool.INSTANCE.getExtensionRegistry()); Message adapted = dynamicProto.maybeAdaptDynamicMessage(anyDyn); Message adapted2 = dynamicProto.maybeAdaptDynamicMessage(anyDyn); assertThat(adapted).isEqualTo(adapted2); @@ -251,9 +259,8 @@ public void maybeAdaptDynamicMessage_notLinked() throws Exception { CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet(fds); CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(files); DynamicProto dynamicProto = - DynamicProto.newBuilder() - .setDynamicDescriptors(celDescriptors) - .build(); + DynamicProto.create( + DefaultMessageFactory.create(DefaultDescriptorPool.create(celDescriptors))); Any any = TextFormat.parse(readFile("value.textproto"), Any.class); Message unpacked = dynamicProto.unpack(any); assertThat(unpacked).isInstanceOf(DynamicMessage.class); @@ -261,49 +268,6 @@ public void maybeAdaptDynamicMessage_notLinked() throws Exception { .isSameInstanceAs(unpacked); } - @Test - public void newBuilder() throws Exception { - DynamicProto dynamicProto = - DynamicProto.newBuilder() - .setDynamicDescriptors( - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - Value.getDescriptor().getFile())) - .build(); - Value.Builder valueBuilder = - (Value.Builder) dynamicProto.newMessageBuilder("google.protobuf.Value").get(); - assertThat(valueBuilder.setStringValue("hello").build()) - .isEqualTo(Value.newBuilder().setStringValue("hello").build()); - } - - @Test - public void newBuilder_notLinked() throws Exception { - DynamicProto dynamicProto = - DynamicProto.newBuilder() - .setDynamicDescriptors( - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - Value.getDescriptor().getFile())) - .build(); - FieldDescriptor stringValueField = Value.getDescriptor().findFieldByName("string_value"); - Message.Builder valueBuilder = dynamicProto.newMessageBuilder("google.protobuf.Value").get(); - assertThat(valueBuilder.setField(stringValueField, "hello").build()) - .isEqualTo(Value.newBuilder().setStringValue("hello").build()); - } - - @Test - public void newBuilder_dynamic() throws Exception { - FileDescriptorSet fds = TextFormat.parse(readFile("value.fds"), FileDescriptorSet.class); - ImmutableSet files = - CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet(fds); - CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(files); - DynamicProto dynamicProto = - DynamicProto.newBuilder() - .setDynamicDescriptors(celDescriptors) - .build(); - assertThat(dynamicProto.newMessageBuilder("google.api.expr.Value")).isPresent(); - assertThat(dynamicProto.newMessageBuilder("google.api.expr.Value").get()) - .isInstanceOf(DynamicMessage.Builder.class); - } - private static String readFile(String path) throws IOException { return Resources.toString(Resources.getResource(Ascii.toLowerCase(path)), UTF_8); } diff --git a/common/src/test/java/dev/cel/common/internal/ProtoAdapterTest.java b/common/src/test/java/dev/cel/common/internal/ProtoAdapterTest.java index 65d114c97..83795b554 100644 --- a/common/src/test/java/dev/cel/common/internal/ProtoAdapterTest.java +++ b/common/src/test/java/dev/cel/common/internal/ProtoAdapterTest.java @@ -40,6 +40,7 @@ import dev.cel.common.CelOptions; import java.util.Arrays; import java.util.List; +import java.util.Optional; import org.junit.Test; import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; @@ -55,6 +56,9 @@ public final class ProtoAdapterTest { private static final CelOptions CURRENT = CelOptions.newBuilder().enableUnsignedLongs(true).build(); + private static final DynamicProto DYNAMIC_PROTO = + DynamicProto.create(DefaultMessageFactory.INSTANCE); + @RunWith(Parameterized.class) public static class BidirectionalConversionTest { @Parameter(0) @@ -156,8 +160,8 @@ public static List data() { @Test public void adaptValueToProto_bidirectionalConversion() { - ProtoAdapter protoAdapter = - new ProtoAdapter(DynamicProto.newBuilder().build(), options.enableUnsignedLongs()); + DynamicProto dynamicProto = DynamicProto.create(DefaultMessageFactory.INSTANCE); + ProtoAdapter protoAdapter = new ProtoAdapter(dynamicProto, options.enableUnsignedLongs()); assertThat(protoAdapter.adaptValueToProto(value, proto.getDescriptorForType().getFullName())) .hasValue(proto); assertThat(protoAdapter.adaptProtoToValue(proto)).isEqualTo(value); @@ -171,13 +175,11 @@ public void adaptAnyValue_hermeticTypes_bidirectionalConversion() { Expr expr = Expr.newBuilder().setExpression("test").build(); ProtoAdapter protoAdapter = new ProtoAdapter( - DynamicProto.newBuilder() - .setProtoMessageFactory( - (typeName) -> - typeName.equals(Expr.getDescriptor().getFullName()) - ? Expr.newBuilder() - : null) - .build(), + DynamicProto.create( + (typeName) -> + typeName.equals(Expr.getDescriptor().getFullName()) + ? Optional.of(Expr.newBuilder()) + : Optional.empty()), LEGACY.enableUnsignedLongs()); assertThat(protoAdapter.adaptValueToProto(expr, Any.getDescriptor().getFullName())) .hasValue(Any.pack(expr)); @@ -189,8 +191,7 @@ public void adaptAnyValue_hermeticTypes_bidirectionalConversion() { public static class AsymmetricConversionTest { @Test public void adaptValueToProto_asymmetricNullConversion() { - ProtoAdapter protoAdapter = - new ProtoAdapter(DynamicProto.newBuilder().build(), LEGACY.enableUnsignedLongs()); + ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); assertThat(protoAdapter.adaptValueToProto(null, Any.getDescriptor().getFullName())) .hasValue(Any.pack(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build())); assertThat( @@ -201,8 +202,7 @@ public void adaptValueToProto_asymmetricNullConversion() { @Test public void adaptValueToProto_asymmetricFloatConversion() { - ProtoAdapter protoAdapter = - new ProtoAdapter(DynamicProto.newBuilder().build(), LEGACY.enableUnsignedLongs()); + ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); assertThat(protoAdapter.adaptValueToProto(1.5F, Any.getDescriptor().getFullName())) .hasValue(Any.pack(FloatValue.of(1.5F))); assertThat(protoAdapter.adaptProtoToValue(Any.pack(FloatValue.of(1.5F)))).isEqualTo(1.5D); @@ -210,8 +210,7 @@ public void adaptValueToProto_asymmetricFloatConversion() { @Test public void adaptValueToProto_asymmetricDoubleFloatConversion() { - ProtoAdapter protoAdapter = - new ProtoAdapter(DynamicProto.newBuilder().build(), LEGACY.enableUnsignedLongs()); + ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); assertThat(protoAdapter.adaptValueToProto(1.5D, FloatValue.getDescriptor().getFullName())) .hasValue(FloatValue.of(1.5F)); assertThat(protoAdapter.adaptProtoToValue(FloatValue.of(1.5F))).isEqualTo(1.5D); @@ -219,16 +218,14 @@ public void adaptValueToProto_asymmetricDoubleFloatConversion() { @Test public void adaptValueToProto_asymmetricFloatDoubleConversion() { - ProtoAdapter protoAdapter = - new ProtoAdapter(DynamicProto.newBuilder().build(), LEGACY.enableUnsignedLongs()); + ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); assertThat(protoAdapter.adaptValueToProto(1.5F, DoubleValue.getDescriptor().getFullName())) .hasValue(DoubleValue.of(1.5D)); } @Test public void adaptValueToProto_asymmetricJsonConversion() { - ProtoAdapter protoAdapter = - new ProtoAdapter(DynamicProto.newBuilder().build(), CURRENT.enableUnsignedLongs()); + ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, CURRENT.enableUnsignedLongs()); assertThat( protoAdapter.adaptValueToProto( UnsignedLong.valueOf(1L), Value.getDescriptor().getFullName())) @@ -250,8 +247,7 @@ public void adaptValueToProto_asymmetricJsonConversion() { @Test public void adaptValueToProto_unsupportedJsonConversion() { - ProtoAdapter protoAdapter = - new ProtoAdapter(DynamicProto.newBuilder().build(), LEGACY.enableUnsignedLongs()); + ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); assertThat( protoAdapter.adaptValueToProto( ImmutableMap.of(1, 1), Any.getDescriptor().getFullName())) @@ -260,8 +256,7 @@ public void adaptValueToProto_unsupportedJsonConversion() { @Test public void adaptValueToProto_unsupportedJsonListConversion() { - ProtoAdapter protoAdapter = - new ProtoAdapter(DynamicProto.newBuilder().build(), LEGACY.enableUnsignedLongs()); + ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); assertThat( protoAdapter.adaptValueToProto( ImmutableMap.of(1, 1), ListValue.getDescriptor().getFullName())) @@ -270,8 +265,7 @@ public void adaptValueToProto_unsupportedJsonListConversion() { @Test public void adaptValueToProto_unsupportedConversion() { - ProtoAdapter protoAdapter = - new ProtoAdapter(DynamicProto.newBuilder().build(), LEGACY.enableUnsignedLongs()); + ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); assertThat(protoAdapter.adaptValueToProto("Hello", Expr.getDescriptor().getFullName())) .isEmpty(); } diff --git a/common/src/test/java/dev/cel/common/internal/ProtoEqualityTest.java b/common/src/test/java/dev/cel/common/internal/ProtoEqualityTest.java index c979175c8..7cbabf9bc 100644 --- a/common/src/test/java/dev/cel/common/internal/ProtoEqualityTest.java +++ b/common/src/test/java/dev/cel/common/internal/ProtoEqualityTest.java @@ -37,7 +37,7 @@ public final class ProtoEqualityTest { @Before public void setUp() { - this.dynamicProto = DynamicProto.newBuilder().build(); + this.dynamicProto = DynamicProto.create(DefaultMessageFactory.INSTANCE); this.protoEquality = new ProtoEquality(dynamicProto); } @@ -268,12 +268,12 @@ public void equalsMessageDynamicAnyFields() throws InvalidProtocolBufferExceptio DynamicMessage.parseFrom( Any.getDescriptor(), doublePackedStruct.getValue(), - ProtoRegistryProvider.getExtensionRegistry()); + DefaultDescriptorPool.INSTANCE.getExtensionRegistry()); DynamicMessage dynAny2 = DynamicMessage.parseFrom( Any.getDescriptor(), doublePackedStruct.getValue(), - ProtoRegistryProvider.getExtensionRegistry()); + DefaultDescriptorPool.INSTANCE.getExtensionRegistry()); assertThat(protoEquality.equals(dynAny, dynAny2)).isTrue(); } } diff --git a/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java index eea889f94..d7b66cb1d 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java @@ -19,10 +19,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Any; import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOverloadDecl; @@ -277,6 +281,29 @@ public void getExt_nonProtoNamespace_success(String expr) throws Exception { assertThat(result).isTrue(); } + @Test + public void getExt_onAnyPackedExtensionField_success() throws Exception { + ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); + MessagesProto2Extensions.registerAllExtensions(extensionRegistry); + Cel cel = + CelFactory.standardCelBuilder() + .addCompilerLibraries(CelExtensions.protos()) + .addFileTypes(MessagesProto2Extensions.getDescriptor()) + .setExtensionRegistry(extensionRegistry) + .addVar( + "msg", StructTypeReference.create("dev.cel.testing.testdata.proto2.Proto2Message")) + .build(); + CelAbstractSyntaxTree ast = + cel.compile("proto.getExt(msg, dev.cel.testing.testdata.proto2.int32_ext)").getAst(); + Any anyMsg = + Any.pack( + Proto2Message.newBuilder().setExtension(MessagesProto2Extensions.int32Ext, 1).build()); + + Long result = (Long) cel.createProgram(ast).eval(ImmutableMap.of("msg", anyMsg)); + + assertThat(result).isEqualTo(1); + } + private enum ParseErrorTestCase { FIELD_NOT_FULLY_QUALIFIED( "proto.getExt(Proto2ExtensionScopedMessage{}, int64_ext)", diff --git a/runtime/src/main/java/dev/cel/runtime/Activation.java b/runtime/src/main/java/dev/cel/runtime/Activation.java index d75d9f964..c51e77e0c 100644 --- a/runtime/src/main/java/dev/cel/runtime/Activation.java +++ b/runtime/src/main/java/dev/cel/runtime/Activation.java @@ -25,6 +25,7 @@ import dev.cel.common.CelOptions; import dev.cel.common.ExprFeatures; import dev.cel.common.annotations.Internal; +import dev.cel.common.internal.DefaultMessageFactory; import dev.cel.common.internal.DynamicProto; import dev.cel.common.internal.ProtoAdapter; import java.util.HashMap; @@ -166,8 +167,10 @@ public static Activation fromProto(Message message) { public static Activation fromProto(Message message, CelOptions celOptions) { Map variables = new HashMap<>(); Map msgFieldValues = message.getAllFields(); + ProtoAdapter protoAdapter = - new ProtoAdapter(DynamicProto.newBuilder().build(), celOptions.enableUnsignedLongs()); + new ProtoAdapter( + DynamicProto.create(DefaultMessageFactory.INSTANCE), celOptions.enableUnsignedLongs()); for (FieldDescriptor field : message.getDescriptorForType().getFields()) { // Get the value of the field set on the message, if present, otherwise use reflection to diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index 1c89d0996..646845723 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -52,6 +52,7 @@ java_library( "//common:runtime_exception", "//common/annotations", "//common/internal:comparison_functions", + "//common/internal:default_message_factory", "//common/internal:dynamic_proto", "//common/types", "//common/types:type_providers", @@ -84,7 +85,10 @@ java_library( "//common:runtime_exception", "//common/annotations", "//common/ast", + "//common/internal:cel_descriptor_pools", + "//common/internal:default_message_factory", "//common/internal:dynamic_proto", + "//common/internal:proto_message_factory", "//common/types:type_providers", "@cel_spec//proto/cel/expr:expr_java_proto", "@maven//:com_google_code_findbugs_annotations", @@ -150,7 +154,11 @@ java_library( "//common:error_codes", "//common:options", "//common/annotations", + "//common/internal:cel_descriptor_pools", + "//common/internal:default_message_factory", "//common/internal:dynamic_proto", + "//common/internal:proto_message_factory", + "//common/types:cel_types", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java index b24ef188a..1408a2069 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java @@ -19,6 +19,7 @@ import com.google.protobuf.DescriptorProtos.FileDescriptorSet; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import dev.cel.common.CelOptions; import java.util.function.Function; @@ -149,6 +150,13 @@ public interface CelRuntimeBuilder { @CanIgnoreReturnValue CelRuntimeBuilder addLibraries(Iterable libraries); + /** + * Sets a proto ExtensionRegistry to assist with unpacking Any messages containing a proto2 + extension field. + */ + @CanIgnoreReturnValue + CelRuntimeBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry); + /** Build a new instance of the {@code CelRuntime}. */ @CheckReturnValue CelRuntime build(); diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java index 4c052a03b..549977846 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java @@ -16,7 +16,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.ImmutableSet.toImmutableSet; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -26,14 +25,23 @@ import com.google.protobuf.DescriptorProtos.FileDescriptorSet; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelDescriptors; import dev.cel.common.CelOptions; import dev.cel.common.annotations.Internal; +import dev.cel.common.internal.CelDescriptorPool; +import dev.cel.common.internal.CombinedDescriptorPool; +import dev.cel.common.internal.DefaultDescriptorPool; +import dev.cel.common.internal.DefaultMessageFactory; import dev.cel.common.internal.DynamicProto; +// CEL-Internal-3 +import dev.cel.common.internal.ProtoMessageFactory; +import dev.cel.common.types.CelTypes; import java.util.Arrays; +import java.util.Optional; import java.util.function.Function; import org.jspecify.nullness.Nullable; @@ -64,7 +72,6 @@ public static Builder newBuilder() { /** Builder class for {@code CelRuntimeLegacyImpl}. */ public static final class Builder implements CelRuntimeBuilder { - private final ImmutableSet.Builder messageTypes; private final ImmutableSet.Builder fileTypes; private final ImmutableMap.Builder functionBindings; private final ImmutableSet.Builder celRuntimeLibraries; @@ -74,6 +81,7 @@ public static final class Builder implements CelRuntimeBuilder { private boolean standardEnvironmentEnabled; private Function customTypeFactory; + private ExtensionRegistry extensionRegistry; @Override @CanIgnoreReturnValue @@ -156,6 +164,14 @@ public Builder addLibraries(Iterable libraries) { return this; } + @Override + @CanIgnoreReturnValue + public Builder setExtensionRegistry(ExtensionRegistry extensionRegistry) { + checkNotNull(extensionRegistry); + this.extensionRegistry = extensionRegistry.getUnmodifiable(); + return this; + } + /** Build a new {@code CelRuntimeLegacyImpl} instance from the builder config. */ @Override @CanIgnoreReturnValue @@ -163,36 +179,26 @@ public CelRuntimeLegacyImpl build() { // Add libraries, such as extensions celRuntimeLibraries.build().forEach(celLibrary -> celLibrary.setRuntimeOptions(this)); - ImmutableSet fileTypeSet = fileTypes.build(); - ImmutableSet messageTypeSet = messageTypes.build(); - if (!messageTypeSet.isEmpty()) { - fileTypeSet = - new ImmutableSet.Builder() - .addAll(fileTypeSet) - .addAll(messageTypeSet.stream().map(Descriptor::getFile).collect(toImmutableSet())) - .build(); - } - - CelDescriptors celDescriptors = - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - fileTypeSet, options.resolveTypeDependencies()); + CelDescriptorPool celDescriptorPool = + newDescriptorPool( + fileTypes.build(), + extensionRegistry, + options); - // This lambda implements @Immutable interface 'MessageFactory', but 'Builder' has non-final - // field 'customTypeFactory' @SuppressWarnings("Immutable") - MessageFactory runtimeTypeFactory = - customTypeFactory != null ? typeName -> customTypeFactory.apply(typeName) : null; - + ProtoMessageFactory runtimeTypeFactory = + customTypeFactory != null + ? messageName -> + CelTypes.isWellKnownType( + messageName) // Let DefaultMessageFactory handle WKT constructions + ? Optional.empty() + : Optional.ofNullable(customTypeFactory.apply(messageName)) + : null; runtimeTypeFactory = - maybeCombineTypeFactory( - runtimeTypeFactory, - DynamicMessageFactory.typeFactory(celDescriptors)); + maybeCombineMessageFactory( + runtimeTypeFactory, DefaultMessageFactory.create(celDescriptorPool)); - DynamicProto dynamicProto = - DynamicProto.newBuilder() - .setDynamicDescriptors(celDescriptors) - .setProtoMessageFactory(runtimeTypeFactory::newBuilder) - .build(); + DynamicProto dynamicProto = DynamicProto.create(runtimeTypeFactory); DefaultDispatcher dispatcher = DefaultDispatcher.create(options, dynamicProto); if (standardEnvironmentEnabled) { @@ -218,28 +224,41 @@ public CelRuntimeLegacyImpl build() { return new CelRuntimeLegacyImpl( new DefaultInterpreter( - new DescriptorMessageProvider(runtimeTypeFactory, dynamicProto, options), - dispatcher, - options), + new DescriptorMessageProvider(runtimeTypeFactory, options), dispatcher, options), options); } + private static CelDescriptorPool newDescriptorPool( + ImmutableSet fileTypeSet, + ExtensionRegistry extensionRegistry, + CelOptions celOptions) { + CelDescriptors celDescriptors = + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( + fileTypeSet, celOptions.resolveTypeDependencies()); + + ImmutableList.Builder descriptorPools = new ImmutableList.Builder<>(); + + descriptorPools.add(DefaultDescriptorPool.create(celDescriptors, extensionRegistry)); + + return CombinedDescriptorPool.create(descriptorPools.build()); + } + @CanIgnoreReturnValue - private static MessageFactory maybeCombineTypeFactory( - @Nullable MessageFactory parentFactory, MessageFactory childFactory) { + private static ProtoMessageFactory maybeCombineMessageFactory( + @Nullable ProtoMessageFactory parentFactory, ProtoMessageFactory childFactory) { if (parentFactory == null) { return childFactory; } - return new MessageFactory.CombinedMessageFactory( + return new ProtoMessageFactory.CombinedMessageFactory( ImmutableList.of(parentFactory, childFactory)); } private Builder() { this.options = CelOptions.newBuilder().build(); this.fileTypes = ImmutableSet.builder(); - this.messageTypes = ImmutableSet.builder(); this.functionBindings = ImmutableMap.builder(); this.celRuntimeLibraries = ImmutableSet.builder(); + this.extensionRegistry = ExtensionRegistry.getEmptyRegistry(); this.customTypeFactory = null; } } diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java index 75c90cc42..09f70c52a 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java @@ -26,6 +26,7 @@ import dev.cel.common.CelOptions; import dev.cel.common.ExprFeatures; import dev.cel.common.annotations.Internal; +import dev.cel.common.internal.DefaultMessageFactory; import dev.cel.common.internal.DynamicProto; import java.util.ArrayList; import java.util.HashMap; @@ -68,7 +69,8 @@ public static DefaultDispatcher create(ImmutableSet features) { } public static DefaultDispatcher create(CelOptions celOptions) { - return create(celOptions, DynamicProto.newBuilder().build()); + DynamicProto dynamicProto = DynamicProto.create(DefaultMessageFactory.INSTANCE); + return create(celOptions, dynamicProto); } public static DefaultDispatcher create(CelOptions celOptions, DynamicProto dynamicProto) { diff --git a/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java b/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java index 8365bbcb2..97562bdbd 100644 --- a/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java +++ b/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java @@ -30,6 +30,7 @@ import dev.cel.common.annotations.Internal; import dev.cel.common.internal.DynamicProto; import dev.cel.common.internal.ProtoAdapter; +import dev.cel.common.internal.ProtoMessageFactory; import dev.cel.common.types.CelType; import java.util.Map; import java.util.Optional; @@ -47,38 +48,44 @@ @Immutable @Internal public final class DescriptorMessageProvider implements RuntimeTypeProvider { - private final MessageFactory messageFactory; - private final DynamicProto dynamicProto; + private final ProtoMessageFactory protoMessageFactory; private final TypeResolver typeResolver; @SuppressWarnings("Immutable") private final ProtoAdapter protoAdapter; - /** Creates a new message provider with the given message factory. */ + /** + * Creates a new message provider with the given message factory. + * + * @deprecated Migrate to the CEL-Java fluent APIs. See {@code CelRuntimeFactory}. + */ + @Deprecated public DescriptorMessageProvider(MessageFactory messageFactory) { - this(messageFactory, DynamicProto.newBuilder().build(), CelOptions.LEGACY); + this(messageFactory.toProtoMessageFactory(), CelOptions.LEGACY); } /** * Creates a new message provider with the given message factory and a set of customized {@code * features}. + * + * @deprecated Migrate to the CEL-Java fluent APIs. See {@code CelRuntimeFactory}. */ + @Deprecated public DescriptorMessageProvider( MessageFactory messageFactory, ImmutableSet features) { - this(messageFactory, DynamicProto.newBuilder().build(), CelOptions.fromExprFeatures(features)); + this(messageFactory.toProtoMessageFactory(), CelOptions.fromExprFeatures(features)); } /** * Create a new message provider with a given message factory and custom descriptor set to use * when adapting from proto to CEL and vice versa. */ - public DescriptorMessageProvider( - MessageFactory messageFactory, DynamicProto dynamicProto, CelOptions celOptions) { - this.dynamicProto = dynamicProto; - // Dedupe the descriptors while indexing by name. + public DescriptorMessageProvider(ProtoMessageFactory protoMessageFactory, CelOptions celOptions) { this.typeResolver = StandardTypeResolver.getInstance(celOptions); - this.messageFactory = messageFactory; - this.protoAdapter = new ProtoAdapter(dynamicProto, celOptions.enableUnsignedLongs()); + this.protoMessageFactory = protoMessageFactory; + this.protoAdapter = + new ProtoAdapter( + DynamicProto.create(protoMessageFactory), celOptions.enableUnsignedLongs()); } @Override @@ -104,13 +111,16 @@ public Value adaptType(@Nullable Type type) { @Nullable @Override public Object createMessage(String messageName, Map values) { - Message.Builder builder = messageFactory.newBuilder(messageName); - if (builder == null) { - throw new CelRuntimeException( - new IllegalArgumentException( - String.format("cannot resolve '%s' as a message", messageName)), - CelErrorCode.ATTRIBUTE_NOT_FOUND); - } + Message.Builder builder = + protoMessageFactory + .newBuilder(messageName) + .orElseThrow( + () -> + new CelRuntimeException( + new IllegalArgumentException( + String.format("cannot resolve '%s' as a message", messageName)), + CelErrorCode.ATTRIBUTE_NOT_FOUND)); + try { Descriptor descriptor = builder.getDescriptorForType(); for (Map.Entry entry : values.entrySet()) { @@ -199,7 +209,7 @@ private FieldDescriptor findField(Descriptor descriptor, String fieldName) { FieldDescriptor fieldDescriptor = descriptor.findFieldByName(fieldName); if (fieldDescriptor == null) { Optional maybeFieldDescriptor = - dynamicProto.maybeGetExtensionDescriptor(descriptor, fieldName); + protoMessageFactory.getDescriptorPool().findExtensionDescriptor(descriptor, fieldName); if (maybeFieldDescriptor.isPresent()) { fieldDescriptor = maybeFieldDescriptor.get(); } diff --git a/runtime/src/main/java/dev/cel/runtime/DynamicMessageFactory.java b/runtime/src/main/java/dev/cel/runtime/DynamicMessageFactory.java index 3a7e6bd92..60bfb6cd8 100644 --- a/runtime/src/main/java/dev/cel/runtime/DynamicMessageFactory.java +++ b/runtime/src/main/java/dev/cel/runtime/DynamicMessageFactory.java @@ -21,8 +21,10 @@ import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelDescriptors; import dev.cel.common.CelOptions; -import dev.cel.common.annotations.Internal; -import dev.cel.common.internal.DynamicProto; +import dev.cel.common.internal.CelDescriptorPool; +import dev.cel.common.internal.DefaultDescriptorPool; +import dev.cel.common.internal.DefaultMessageFactory; +import dev.cel.common.internal.ProtoMessageFactory; import java.util.Collection; import org.jspecify.nullness.Nullable; @@ -32,11 +34,13 @@ *

Creating message with {@code DynamicMessage} is significantly slower than instantiating * messages directly as it uses Java reflection. * - *

CEL Library Internals. Do Not Use. + * @deprecated Do not use. CEL-Java users should leverage the Fluent APIs instead. See {@code + * CelRuntimeFactory}. */ @Immutable -@Internal +@Deprecated public final class DynamicMessageFactory implements MessageFactory { + private final ProtoMessageFactory protoMessageFactory; /** * Create a {@link RuntimeTypeProvider} which can access only the types listed in the input {@code @@ -47,14 +51,8 @@ public final class DynamicMessageFactory implements MessageFactory { */ @Deprecated public static RuntimeTypeProvider typeProvider(Collection descriptors) { - CelDescriptors celDescriptors = - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - CelDescriptorUtil.getFileDescriptorsForDescriptors(descriptors)); - return new DescriptorMessageProvider( - typeFactory(descriptors), - DynamicProto.newBuilder().setDynamicDescriptors(celDescriptors).build(), - CelOptions.LEGACY); + typeFactory(descriptors).toProtoMessageFactory(), CelOptions.LEGACY); } /** @@ -69,28 +67,21 @@ public static MessageFactory typeFactory(Collection descriptors) { CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( CelDescriptorUtil.getFileDescriptorsForDescriptors(descriptors)); - return typeFactory(celDescriptors); - } - /** - * Create a {@code MessageFactory} which can produce any protobuf type linked in the binary, or - * present in the collection of {@code descriptors}. - */ - public static MessageFactory typeFactory(CelDescriptors celDescriptors) { - return new DynamicMessageFactory(celDescriptors); + return new DynamicMessageFactory(DefaultDescriptorPool.create(celDescriptors)); } - private final DynamicProto dynamicProto; + @Override + public ProtoMessageFactory toProtoMessageFactory() { + return protoMessageFactory; + } - private DynamicMessageFactory(CelDescriptors celDescriptors) { - this.dynamicProto = - DynamicProto.newBuilder() - .setDynamicDescriptors(celDescriptors) - .build(); + private DynamicMessageFactory(CelDescriptorPool celDescriptorPool) { + protoMessageFactory = DefaultMessageFactory.create(celDescriptorPool); } @Override public Message.@Nullable Builder newBuilder(String messageName) { - return dynamicProto.newMessageBuilder(messageName).orElse(null); + return protoMessageFactory.newBuilder(messageName).orElse(null); } } diff --git a/runtime/src/main/java/dev/cel/runtime/MessageFactory.java b/runtime/src/main/java/dev/cel/runtime/MessageFactory.java index f14b31ce1..e56825b4a 100644 --- a/runtime/src/main/java/dev/cel/runtime/MessageFactory.java +++ b/runtime/src/main/java/dev/cel/runtime/MessageFactory.java @@ -17,17 +17,19 @@ import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; import com.google.protobuf.Message; -import dev.cel.common.annotations.Internal; +import dev.cel.common.internal.DefaultMessageFactory; +import dev.cel.common.internal.ProtoMessageFactory; +import java.util.Optional; import org.jspecify.nullness.Nullable; /** * The {@code MessageFactory} provides a method to create a protobuf builder objects by name. * - *

CEL Library Internals. Do Not Use. + * @deprecated Do not use. Internally, {@link ProtoMessageFactory} should be used. */ @Immutable -@Internal @FunctionalInterface +@Deprecated public interface MessageFactory { /** @@ -37,6 +39,22 @@ public interface MessageFactory { */ Message.@Nullable Builder newBuilder(String messageName); + /** + * Exists only to maintain FunctionalInterface requirement and to make legacy Dynamic/Linked + * message factories compatible with the new ProtoMessageFactory. + */ + default ProtoMessageFactory toProtoMessageFactory() { + return msgName -> { + Optional msgBuilder = DefaultMessageFactory.INSTANCE.newBuilder(msgName); + if (msgBuilder.isPresent()) { + return msgBuilder; // Bypass custom factory and return well known type with our own + // descriptor. + } + + return Optional.ofNullable(newBuilder(msgName)); + }; + } + /** * The {@code CombinedMessageFactory} takes one or more {@code MessageFactory} instances and * attempts to create a {@code Message.Builder} instance for a given {@code messageName} by diff --git a/runtime/src/main/java/dev/cel/runtime/RuntimeHelpers.java b/runtime/src/main/java/dev/cel/runtime/RuntimeHelpers.java index d9a05abc7..ffb979842 100644 --- a/runtime/src/main/java/dev/cel/runtime/RuntimeHelpers.java +++ b/runtime/src/main/java/dev/cel/runtime/RuntimeHelpers.java @@ -50,8 +50,6 @@ public final class RuntimeHelpers { private static final java.time.Duration DURATION_MAX = java.time.Duration.ofDays(3652500); private static final java.time.Duration DURATION_MIN = DURATION_MAX.negated(); - private static final DynamicProto DYNAMIC_PROTO_INSTANCE = DynamicProto.newBuilder().build(); - // Functions // ========= @@ -362,10 +360,6 @@ public static Object maybeAdaptPrimitive(Object value) { return value; } - static Object adaptProtoToValue(MessageOrBuilder obj, CelOptions celOptions) { - return adaptProtoToValue(DYNAMIC_PROTO_INSTANCE, obj, celOptions); - } - /** * Adapts a {@code protobuf.Message} to a plain old Java object. * diff --git a/runtime/src/main/java/dev/cel/runtime/StandardFunctions.java b/runtime/src/main/java/dev/cel/runtime/StandardFunctions.java index 048d5a9fc..3b859aa1f 100644 --- a/runtime/src/main/java/dev/cel/runtime/StandardFunctions.java +++ b/runtime/src/main/java/dev/cel/runtime/StandardFunctions.java @@ -28,6 +28,7 @@ import dev.cel.common.CelOptions; import dev.cel.common.annotations.Internal; import dev.cel.common.internal.ComparisonFunctions; +import dev.cel.common.internal.DefaultMessageFactory; import dev.cel.common.internal.DynamicProto; import java.math.BigDecimal; import java.text.ParseException; @@ -86,7 +87,10 @@ public static void add(Registrar registrar, DynamicProto dynamicProto, CelOption * {@code FuturesInterpreter}. */ public static void addNonInlined(Registrar registrar, CelOptions celOptions) { - addNonInlined(registrar, new RuntimeEquality(DynamicProto.newBuilder().build()), celOptions); + addNonInlined( + registrar, + new RuntimeEquality(DynamicProto.create(DefaultMessageFactory.INSTANCE)), + celOptions); } /** diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index 17e4179e2..d50c8eab3 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -21,8 +21,12 @@ java_library( "//common:proto_v1alpha1_ast", "//common:runtime_exception", "//common/ast", + "//common/internal:cel_descriptor_pools", "//common/internal:converter", + "//common/internal:default_message_factory", "//common/internal:dynamic_proto", + "//common/internal:proto_message_factory", + "//common/internal:well_known_proto", "//common/resources/testdata/proto2:messages_extensions_proto2_java_proto", "//common/resources/testdata/proto2:messages_proto2_java_proto", "//common/resources/testdata/proto3:test_all_types_java_proto", @@ -43,6 +47,7 @@ java_library( "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_google_protobuf_protobuf_java_util", + "@maven//:com_google_testparameterinjector_test_parameter_injector", "@maven//:com_google_truth_extensions_truth_proto_extension", "@maven//:junit_junit", "@maven//:org_jspecify_jspecify", diff --git a/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java b/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java index 0c4605cef..72c318e37 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java @@ -21,7 +21,10 @@ import com.google.api.expr.v1alpha1.Type.PrimitiveType; import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; import com.google.protobuf.ByteString; +import com.google.protobuf.DescriptorProtos.FileDescriptorSet; +import com.google.protobuf.DynamicMessage; import com.google.rpc.context.AttributeContext; import dev.cel.bundle.Cel; import dev.cel.bundle.CelFactory; @@ -35,6 +38,8 @@ import dev.cel.common.types.CelV1AlphaTypes; import dev.cel.common.types.SimpleType; import dev.cel.common.types.StructTypeReference; +import dev.cel.compiler.CelCompiler; +import dev.cel.compiler.CelCompilerFactory; import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparserFactory; import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes; @@ -107,6 +112,100 @@ public void evaluate_v1alpha1CheckedExpr() throws Exception { assertThat(evaluatedResult).isEqualTo("Hello world!"); } + @Test + public void newWellKnownTypeMessage_withDifferentDescriptorInstance() throws Exception { + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addMessageTypes(BoolValue.getDescriptor()) + .build(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addFileTypes( + FileDescriptorSet.newBuilder() + .addFile( + BoolValue.getDescriptor().getFile().toProto()) // Copy the WKT descriptor + .build()) + .build(); + + CelAbstractSyntaxTree ast = + celCompiler.compile("google.protobuf.BoolValue{value: false}").getAst(); + + assertThat(celRuntime.createProgram(ast).eval()).isEqualTo(false); + } + + @Test + public void newWellKnownTypeMessage_inDynamicMessage_withSetTypeFactory() throws Exception { + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addMessageTypes(BoolValue.getDescriptor()) + .build(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .setTypeFactory( + (typeName) -> + typeName.equals("google.protobuf.BoolValue") + ? DynamicMessage.newBuilder(BoolValue.getDescriptor()) + : null) + .build(); + + CelAbstractSyntaxTree ast = + celCompiler.compile("google.protobuf.BoolValue{value: false}").getAst(); + + assertThat(celRuntime.createProgram(ast).eval()).isEqualTo(false); + } + + @Test + public void newWellKnownTypeMessage_inAnyMessage_withDifferentDescriptorInstance() + throws Exception { + FileDescriptorSet fds = + FileDescriptorSet.newBuilder() + // Copy the WKT descriptors + .addFile(Any.getDescriptor().getFile().toProto()) + .addFile(BoolValue.getDescriptor().getFile().toProto()) + .build(); + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder().addFileTypes(fds).build(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .build(); + + CelAbstractSyntaxTree ast = + celCompiler + .compile( + "google.protobuf.Any{type_url: 'types.googleapis.com/google.protobuf.DoubleValue'}") + .getAst(); + + assertThat(celRuntime.createProgram(ast).eval()).isEqualTo(0.0d); + } + + @Test + public void newWellKnownTypeMessage_inAnyMessage_withSetTypeFactory() throws Exception { + FileDescriptorSet fds = + FileDescriptorSet.newBuilder() + // Copy the WKT descriptors + .addFile(Any.getDescriptor().getFile().toProto()) + .addFile(BoolValue.getDescriptor().getFile().toProto()) + .build(); + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder().addFileTypes(fds).build(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .setTypeFactory( + (typeName) -> + typeName.equals("google.protobuf.Any") + ? Any.newBuilder().setTypeUrl("google.protobuf.DoubleValue") + : null) + .build(); + + CelAbstractSyntaxTree ast = + celCompiler + .compile( + "google.protobuf.Any{type_url: 'types.googleapis.com/google.protobuf.DoubleValue'}") + .getAst(); + + assertThat(celRuntime.createProgram(ast).eval()).isEqualTo(0.0d); + } + @Test public void trace_callExpr_identifyFalseBranch() throws Exception { AtomicReference capturedExpr = new AtomicReference<>(); diff --git a/runtime/src/test/java/dev/cel/runtime/DescriptorMessageProviderTest.java b/runtime/src/test/java/dev/cel/runtime/DescriptorMessageProviderTest.java index 37888d19c..2421d2daf 100644 --- a/runtime/src/test/java/dev/cel/runtime/DescriptorMessageProviderTest.java +++ b/runtime/src/test/java/dev/cel/runtime/DescriptorMessageProviderTest.java @@ -19,15 +19,25 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Any; import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.DynamicMessage; import com.google.protobuf.Message; import com.google.protobuf.NullValue; +import com.google.testing.junit.testparameterinjector.TestParameter; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelDescriptors; import dev.cel.common.CelErrorCode; import dev.cel.common.CelOptions; import dev.cel.common.CelRuntimeException; -import dev.cel.common.internal.DynamicProto; +import dev.cel.common.internal.CelDescriptorPool; +import dev.cel.common.internal.DefaultDescriptorPool; +import dev.cel.common.internal.DefaultMessageFactory; +// CEL-Internal-3 +import dev.cel.common.internal.ProtoMessageFactory; +import dev.cel.common.internal.WellKnownProto; import dev.cel.testing.testdata.proto2.MessagesProto2; import dev.cel.testing.testdata.proto2.MessagesProto2Extensions; import dev.cel.testing.testdata.proto2.Proto2Message; @@ -37,9 +47,8 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -@RunWith(JUnit4.class) +@RunWith(TestParameterInjector.class) public final class DescriptorMessageProviderTest { private RuntimeTypeProvider provider; @@ -47,16 +56,12 @@ public final class DescriptorMessageProviderTest { @Before public void setUp() { CelOptions options = CelOptions.current().build(); - ImmutableList descriptors = ImmutableList.of(TestAllTypes.getDescriptor()); - DynamicProto dynamicProto = - DynamicProto.newBuilder() - .setDynamicDescriptors( - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - TestAllTypes.getDescriptor().getFile())) - .build(); - provider = - new DescriptorMessageProvider( - DynamicMessageFactory.typeFactory(descriptors), dynamicProto, options); + CelDescriptors celDescriptors = + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( + TestAllTypes.getDescriptor().getFile()); + ProtoMessageFactory dynamicMessageFactory = + DefaultMessageFactory.create(DefaultDescriptorPool.create(celDescriptors)); + provider = new DescriptorMessageProvider(dynamicMessageFactory, options); } @Test @@ -169,15 +174,11 @@ public void selectField_extensionUsingDynamicTypes() { CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( ImmutableList.of(MessagesProto2Extensions.getDescriptor())); - DynamicProto dynamicProto = - DynamicProto.newBuilder() - .setDynamicDescriptors(celDescriptors) - .build(); + CelDescriptorPool pool = DefaultDescriptorPool.create(celDescriptors); + provider = new DescriptorMessageProvider( - DynamicMessageFactory.typeFactory(celDescriptors), - dynamicProto, - CelOptions.current().build()); + DefaultMessageFactory.create(pool), CelOptions.current().build()); long result = (long) @@ -189,4 +190,55 @@ public void selectField_extensionUsingDynamicTypes() { assertThat(result).isEqualTo(10); } + + @Test + public void createMessage_wellKnownType_withCustomMessageProvider( + @TestParameter WellKnownProto wellKnownProto) { + if (wellKnownProto.equals(WellKnownProto.ANY_VALUE)) { + return; + } + + Descriptor wellKnownDescriptor = wellKnownProto.descriptor(); + DescriptorMessageProvider messageProvider = + new DescriptorMessageProvider( + msgName -> + msgName.equals(wellKnownDescriptor.getFullName()) + ? DynamicMessage.newBuilder(wellKnownDescriptor) + : null); + + Object createdMessage = + messageProvider.createMessage(wellKnownDescriptor.getFullName(), ImmutableMap.of()); + + assertThat(createdMessage).isNotNull(); + } + + @Test + public void createMessage_anyType_withCustomMessageProvider() { + DescriptorMessageProvider messageProvider = + new DescriptorMessageProvider( + msgName -> msgName.equals(Any.getDescriptor().getFullName()) ? Any.newBuilder() : null); + + double createdMessage = + (double) + messageProvider.createMessage( + Any.getDescriptor().getFullName(), + ImmutableMap.of("type_url", "types.googleapis.com/google.protobuf.DoubleValue")); + + assertThat(createdMessage).isEqualTo(0.0d); + } + + @Test + public void createMessage_doubleValue_withCustomMessageProvider() { + DescriptorMessageProvider messageProvider = + new DescriptorMessageProvider( + msgName -> + msgName.equals("google.protobuf.DoubleValue") + ? DynamicMessage.newBuilder(DoubleValue.getDescriptor()) + : null); + + double value = + (double) messageProvider.createMessage("google.protobuf.DoubleValue", ImmutableMap.of()); + + assertThat(value).isEqualTo(0.0d); + } } diff --git a/runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java b/runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java index 74cadc81f..250584f48 100644 --- a/runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java +++ b/runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java @@ -47,6 +47,8 @@ import dev.cel.common.CelRuntimeException; import dev.cel.common.internal.AdaptingTypes; import dev.cel.common.internal.BidiConverter; +import dev.cel.common.internal.DefaultDescriptorPool; +import dev.cel.common.internal.DefaultMessageFactory; import dev.cel.common.internal.DynamicProto; import java.util.Arrays; import java.util.List; @@ -78,11 +80,11 @@ public final class RuntimeEqualityTest { private static final RuntimeEquality RUNTIME_EQUALITY = new RuntimeEquality( - DynamicProto.newBuilder() - .setDynamicDescriptors( - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - AttributeContext.getDescriptor().getFile())) - .build()); + DynamicProto.create( + DefaultMessageFactory.create( + DefaultDescriptorPool.create( + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( + AttributeContext.getDescriptor().getFile()))))); @Test public void inMap() throws Exception { diff --git a/runtime/src/test/java/dev/cel/runtime/RuntimeHelpersTest.java b/runtime/src/test/java/dev/cel/runtime/RuntimeHelpersTest.java index 32ab52209..b22c3f14e 100644 --- a/runtime/src/test/java/dev/cel/runtime/RuntimeHelpersTest.java +++ b/runtime/src/test/java/dev/cel/runtime/RuntimeHelpersTest.java @@ -37,6 +37,8 @@ import com.google.protobuf.Value; import dev.cel.common.CelOptions; import dev.cel.common.CelRuntimeException; +import dev.cel.common.internal.DefaultMessageFactory; +import dev.cel.common.internal.DynamicProto; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -47,6 +49,8 @@ @RunWith(JUnit4.class) public final class RuntimeHelpersTest { + private static final DynamicProto DYNAMIC_PROTO = + DynamicProto.create(DefaultMessageFactory.INSTANCE); @Test public void createDurationFromString() throws Exception { @@ -310,26 +314,38 @@ public void maybeAdaptPrimitive_optionalValues() { @Test public void adaptProtoToValue_wrapperValues() throws Exception { CelOptions celOptions = CelOptions.LEGACY; - assertThat(RuntimeHelpers.adaptProtoToValue(BoolValue.of(true), celOptions)).isEqualTo(true); - assertThat(RuntimeHelpers.adaptProtoToValue(BytesValue.of(ByteString.EMPTY), celOptions)) + assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, BoolValue.of(true), celOptions)) + .isEqualTo(true); + assertThat( + RuntimeHelpers.adaptProtoToValue( + DYNAMIC_PROTO, BytesValue.of(ByteString.EMPTY), celOptions)) .isEqualTo(ByteString.EMPTY); - assertThat(RuntimeHelpers.adaptProtoToValue(DoubleValue.of(1.5d), celOptions)).isEqualTo(1.5d); - assertThat(RuntimeHelpers.adaptProtoToValue(FloatValue.of(1.5f), celOptions)).isEqualTo(1.5d); - assertThat(RuntimeHelpers.adaptProtoToValue(Int32Value.of(12), celOptions)).isEqualTo(12L); - assertThat(RuntimeHelpers.adaptProtoToValue(Int64Value.of(-12L), celOptions)).isEqualTo(-12L); - assertThat(RuntimeHelpers.adaptProtoToValue(UInt32Value.of(123), celOptions)).isEqualTo(123L); - assertThat(RuntimeHelpers.adaptProtoToValue(UInt64Value.of(1234L), celOptions)) + assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, DoubleValue.of(1.5d), celOptions)) + .isEqualTo(1.5d); + assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, FloatValue.of(1.5f), celOptions)) + .isEqualTo(1.5d); + assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, Int32Value.of(12), celOptions)) + .isEqualTo(12L); + assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, Int64Value.of(-12L), celOptions)) + .isEqualTo(-12L); + assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, UInt32Value.of(123), celOptions)) + .isEqualTo(123L); + assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, UInt64Value.of(1234L), celOptions)) .isEqualTo(1234L); - assertThat(RuntimeHelpers.adaptProtoToValue(StringValue.of("hello"), celOptions)) + assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, StringValue.of("hello"), celOptions)) .isEqualTo("hello"); assertThat( RuntimeHelpers.adaptProtoToValue( - UInt32Value.of(123), CelOptions.newBuilder().enableUnsignedLongs(true).build())) + DYNAMIC_PROTO, + UInt32Value.of(123), + CelOptions.newBuilder().enableUnsignedLongs(true).build())) .isEqualTo(UnsignedLong.valueOf(123L)); assertThat( RuntimeHelpers.adaptProtoToValue( - UInt64Value.of(1234L), CelOptions.newBuilder().enableUnsignedLongs(true).build())) + DYNAMIC_PROTO, + UInt64Value.of(1234L), + CelOptions.newBuilder().enableUnsignedLongs(true).build())) .isEqualTo(UnsignedLong.valueOf(1234L)); } @@ -337,11 +353,14 @@ public void adaptProtoToValue_wrapperValues() throws Exception { public void adaptProtoToValue_jsonValues() throws Exception { assertThat( RuntimeHelpers.adaptProtoToValue( - Value.newBuilder().setStringValue("json").build(), CelOptions.LEGACY)) + DYNAMIC_PROTO, + Value.newBuilder().setStringValue("json").build(), + CelOptions.LEGACY)) .isEqualTo("json"); assertThat( RuntimeHelpers.adaptProtoToValue( + DYNAMIC_PROTO, Value.newBuilder() .setListValue( ListValue.newBuilder() @@ -354,6 +373,7 @@ public void adaptProtoToValue_jsonValues() throws Exception { mp.put("list_value", ImmutableList.of(false, NullValue.NULL_VALUE)); assertThat( RuntimeHelpers.adaptProtoToValue( + DYNAMIC_PROTO, Struct.newBuilder() .putFields( "list_value", @@ -385,13 +405,16 @@ public void adaptProtoToValue_anyValues() throws Exception { .build(); Any anyJsonValue = Any.pack(jsonValue); mp.put("list_value", ImmutableList.of(false, NullValue.NULL_VALUE)); - assertThat(RuntimeHelpers.adaptProtoToValue(anyJsonValue, CelOptions.LEGACY)).isEqualTo(mp); + assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, anyJsonValue, CelOptions.LEGACY)) + .isEqualTo(mp); } @Test public void adaptProtoToValue_builderValue() throws Exception { CelOptions celOptions = CelOptions.LEGACY; - assertThat(RuntimeHelpers.adaptProtoToValue(BoolValue.newBuilder().setValue(true), celOptions)) + assertThat( + RuntimeHelpers.adaptProtoToValue( + DYNAMIC_PROTO, BoolValue.newBuilder().setValue(true), celOptions)) .isEqualTo(true); } diff --git a/testing/src/main/java/dev/cel/testing/BUILD.bazel b/testing/src/main/java/dev/cel/testing/BUILD.bazel index b754368bd..1ffb77b47 100644 --- a/testing/src/main/java/dev/cel/testing/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/BUILD.bazel @@ -89,7 +89,8 @@ java_library( ":eval", "//common", "//common:options", - "//common/internal:dynamic_proto", + "//common/internal:cel_descriptor_pools", + "//common/internal:default_message_factory", "//runtime:interpreter", "@cel_spec//proto/cel/expr:expr_java_proto", "@maven//:com_google_guava_guava", @@ -126,7 +127,7 @@ java_library( ":eval", "//:java_truth", "//common", - "//common/internal:dynamic_proto", + "//common/internal:cel_descriptor_pools", "//common/resources/testdata/proto3:standalone_global_enum_java_proto", "//common/resources/testdata/proto3:test_all_types_java_proto", "//common/types:cel_types", diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index 34ae763b2..4f7ba5b3e 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -50,7 +50,7 @@ import com.google.protobuf.util.Timestamps; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelProtoAbstractSyntaxTree; -import dev.cel.common.internal.ProtoRegistryProvider; +import dev.cel.common.internal.DefaultDescriptorPool; import dev.cel.common.types.CelTypes; import dev.cel.runtime.Activation; import dev.cel.runtime.InterpreterException; @@ -1916,7 +1916,7 @@ public void dynamicMessage() throws Exception { DynamicMessage.parseFrom( TestAllTypes.getDescriptor(), wrapperBindings.toByteArray(), - ProtoRegistryProvider.getExtensionRegistry())); + DefaultDescriptorPool.INSTANCE.getExtensionRegistry())); declareVariable("msg", CelTypes.createMessage(TestAllTypes.getDescriptor().getFullName())); diff --git a/testing/src/main/java/dev/cel/testing/EvalSync.java b/testing/src/main/java/dev/cel/testing/EvalSync.java index b72badd13..403134454 100644 --- a/testing/src/main/java/dev/cel/testing/EvalSync.java +++ b/testing/src/main/java/dev/cel/testing/EvalSync.java @@ -20,12 +20,12 @@ import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelDescriptors; import dev.cel.common.CelOptions; -import dev.cel.common.internal.DynamicProto; +import dev.cel.common.internal.DefaultDescriptorPool; +import dev.cel.common.internal.DefaultMessageFactory; import dev.cel.runtime.Activation; import dev.cel.runtime.DefaultDispatcher; import dev.cel.runtime.DefaultInterpreter; import dev.cel.runtime.DescriptorMessageProvider; -import dev.cel.runtime.DynamicMessageFactory; import dev.cel.runtime.Interpreter; import dev.cel.runtime.InterpreterException; import dev.cel.runtime.Registrar; @@ -47,9 +47,7 @@ public EvalSync(ImmutableList fileDescriptors, CelOptions celOpt CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors); this.typeProvider = new DescriptorMessageProvider( - DynamicMessageFactory.typeFactory(celDescriptors), - DynamicProto.newBuilder().setDynamicDescriptors(celDescriptors).build(), - celOptions); + DefaultMessageFactory.create(DefaultDescriptorPool.create(celDescriptors)), celOptions); this.interpreter = new DefaultInterpreter(typeProvider, dispatcher, celOptions); this.celOptions = celOptions; }