diff --git a/bundle/src/main/java/dev/cel/bundle/CelBuilder.java b/bundle/src/main/java/dev/cel/bundle/CelBuilder.java index 2e4b535ae..7857d4140 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelBuilder.java +++ b/bundle/src/main/java/dev/cel/bundle/CelBuilder.java @@ -20,6 +20,7 @@ import com.google.protobuf.DescriptorProtos.FileDescriptorSet; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import dev.cel.checker.ProtoTypeMask; import dev.cel.checker.TypeProvider; @@ -282,6 +283,13 @@ public interface CelBuilder { @CanIgnoreReturnValue CelBuilder addRuntimeLibraries(Iterable libraries); + /** + * Sets a proto ExtensionRegistry to assist with unpacking Any messages containing a proto2 + extension field. + */ + @CanIgnoreReturnValue + CelBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry); + /** Construct a new {@code Cel} instance from the provided configuration. */ Cel build(); } diff --git a/bundle/src/main/java/dev/cel/bundle/CelImpl.java b/bundle/src/main/java/dev/cel/bundle/CelImpl.java index 81680c86a..3d65ada89 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelImpl.java +++ b/bundle/src/main/java/dev/cel/bundle/CelImpl.java @@ -25,6 +25,7 @@ import com.google.protobuf.DescriptorProtos.FileDescriptorSet; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import dev.cel.checker.CelCheckerBuilder; import dev.cel.checker.ProtoTypeMask; @@ -339,6 +340,13 @@ public Builder addRuntimeLibraries(Iterable libraries) { return this; } + @Override + public CelBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry) { + checkNotNull(extensionRegistry); + runtimeBuilder.setExtensionRegistry(extensionRegistry); + return this; + } + @Override public Cel build() { return new CelImpl( diff --git a/common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java b/common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java index 66dbf2e1e..fc703c905 100644 --- a/common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java +++ b/common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java @@ -46,7 +46,10 @@ public final class DefaultDescriptorPool implements CelDescriptorPool { /** A DefaultDescriptorPool instance with just well known types loaded. */ public static final DefaultDescriptorPool INSTANCE = - new DefaultDescriptorPool(WELL_KNOWN_TYPE_DESCRIPTORS, ImmutableMultimap.of()); + new DefaultDescriptorPool( + WELL_KNOWN_TYPE_DESCRIPTORS, + ImmutableMultimap.of(), + ExtensionRegistry.getEmptyRegistry()); // K: Fully qualified message type name, V: Message descriptor private final ImmutableMap descriptorMap; @@ -55,7 +58,15 @@ public final class DefaultDescriptorPool implements CelDescriptorPool { // V: Field descriptor for the extension message private final ImmutableMultimap extensionDescriptorMap; + @SuppressWarnings("Immutable") // ExtensionRegistry is immutable, just not marked as such. + private final ExtensionRegistry extensionRegistry; + public static DefaultDescriptorPool create(CelDescriptors celDescriptors) { + return create(celDescriptors, ExtensionRegistry.getEmptyRegistry()); + } + + public static DefaultDescriptorPool create( + CelDescriptors celDescriptors, ExtensionRegistry extensionRegistry) { Map descriptorMap = new HashMap<>(); // Using a hashmap to allow deduping stream(WellKnownProto.values()).forEach(d -> descriptorMap.put(d.typeName(), d.descriptor())); @@ -64,7 +75,9 @@ public static DefaultDescriptorPool create(CelDescriptors celDescriptors) { } return new DefaultDescriptorPool( - ImmutableMap.copyOf(descriptorMap), celDescriptors.extensionDescriptors()); + ImmutableMap.copyOf(descriptorMap), + celDescriptors.extensionDescriptors(), + extensionRegistry); } @Override @@ -83,14 +96,15 @@ public Optional findExtensionDescriptor( @Override public ExtensionRegistry getExtensionRegistry() { - // TODO: Populate one from runtime builder. - return ExtensionRegistry.getEmptyRegistry(); + return extensionRegistry; } private DefaultDescriptorPool( ImmutableMap descriptorMap, - ImmutableMultimap extensionDescriptorMap) { + ImmutableMultimap extensionDescriptorMap, + ExtensionRegistry extensionRegistry) { this.descriptorMap = checkNotNull(descriptorMap); this.extensionDescriptorMap = checkNotNull(extensionDescriptorMap); + this.extensionRegistry = checkNotNull(extensionRegistry); } } diff --git a/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java index eea889f94..d7b66cb1d 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java @@ -19,10 +19,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Any; import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOverloadDecl; @@ -277,6 +281,29 @@ public void getExt_nonProtoNamespace_success(String expr) throws Exception { assertThat(result).isTrue(); } + @Test + public void getExt_onAnyPackedExtensionField_success() throws Exception { + ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); + MessagesProto2Extensions.registerAllExtensions(extensionRegistry); + Cel cel = + CelFactory.standardCelBuilder() + .addCompilerLibraries(CelExtensions.protos()) + .addFileTypes(MessagesProto2Extensions.getDescriptor()) + .setExtensionRegistry(extensionRegistry) + .addVar( + "msg", StructTypeReference.create("dev.cel.testing.testdata.proto2.Proto2Message")) + .build(); + CelAbstractSyntaxTree ast = + cel.compile("proto.getExt(msg, dev.cel.testing.testdata.proto2.int32_ext)").getAst(); + Any anyMsg = + Any.pack( + Proto2Message.newBuilder().setExtension(MessagesProto2Extensions.int32Ext, 1).build()); + + Long result = (Long) cel.createProgram(ast).eval(ImmutableMap.of("msg", anyMsg)); + + assertThat(result).isEqualTo(1); + } + private enum ParseErrorTestCase { FIELD_NOT_FULLY_QUALIFIED( "proto.getExt(Proto2ExtensionScopedMessage{}, int64_ext)", diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java index b24ef188a..1408a2069 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java @@ -19,6 +19,7 @@ import com.google.protobuf.DescriptorProtos.FileDescriptorSet; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import dev.cel.common.CelOptions; import java.util.function.Function; @@ -149,6 +150,13 @@ public interface CelRuntimeBuilder { @CanIgnoreReturnValue CelRuntimeBuilder addLibraries(Iterable libraries); + /** + * Sets a proto ExtensionRegistry to assist with unpacking Any messages containing a proto2 + extension field. + */ + @CanIgnoreReturnValue + CelRuntimeBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry); + /** Build a new instance of the {@code CelRuntime}. */ @CheckReturnValue CelRuntime build(); diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java index e318c2036..45adca17b 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java @@ -25,6 +25,7 @@ import com.google.protobuf.DescriptorProtos.FileDescriptorSet; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelDescriptorUtil; @@ -79,6 +80,7 @@ public static final class Builder implements CelRuntimeBuilder { private boolean standardEnvironmentEnabled; private Function customTypeFactory; + private ExtensionRegistry extensionRegistry; @Override @CanIgnoreReturnValue @@ -161,6 +163,14 @@ public Builder addLibraries(Iterable libraries) { return this; } + @Override + @CanIgnoreReturnValue + public Builder setExtensionRegistry(ExtensionRegistry extensionRegistry) { + checkNotNull(extensionRegistry); + this.extensionRegistry = extensionRegistry.getUnmodifiable(); + return this; + } + /** Build a new {@code CelRuntimeLegacyImpl} instance from the builder config. */ @Override @CanIgnoreReturnValue @@ -171,6 +181,7 @@ public CelRuntimeLegacyImpl build() { CelDescriptorPool celDescriptorPool = newDescriptorPool( fileTypes.build(), + extensionRegistry, options); @SuppressWarnings("Immutable") @@ -214,6 +225,7 @@ public CelRuntimeLegacyImpl build() { private static CelDescriptorPool newDescriptorPool( ImmutableSet fileTypeSet, + ExtensionRegistry extensionRegistry, CelOptions celOptions) { CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( @@ -221,7 +233,7 @@ private static CelDescriptorPool newDescriptorPool( ImmutableList.Builder descriptorPools = new ImmutableList.Builder<>(); - descriptorPools.add(DefaultDescriptorPool.create(celDescriptors)); + descriptorPools.add(DefaultDescriptorPool.create(celDescriptors, extensionRegistry)); return CombinedDescriptorPool.create(descriptorPools.build()); } @@ -241,6 +253,7 @@ private Builder() { this.fileTypes = ImmutableSet.builder(); this.functionBindings = ImmutableMap.builder(); this.celRuntimeLibraries = ImmutableSet.builder(); + this.extensionRegistry = ExtensionRegistry.getEmptyRegistry(); this.customTypeFactory = null; } }